summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/test_optim.py24
1 files changed, 24 insertions, 0 deletions
diff --git a/tests/test_optim.py b/tests/test_optim.py
index 4d67b08..fc2456f 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -141,6 +141,7 @@ def test_global_config(dim1, dim2, gtype):
eps = 1e-8
bnb.optim.GlobalOptimManager.get_instance().initialize()
+ bnb.optim.GlobalOptimManager.get_instance().override_config(p2, 'skip_zeros', True)
bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8)
bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
@@ -155,6 +156,8 @@ def test_global_config(dim1, dim2, gtype):
else:
atol, rtol = 1e-4, 1e-3
+ original_p2 = p2[mask].clone()
+
for i in range(50):
g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
@@ -163,11 +166,32 @@ def test_global_config(dim1, dim2, gtype):
p2.grad = g2
p3.grad = g3
+ if i > 30 and i % 10 == 0:
+ g1.data[mask] = 0.0
+ g2.data[mask] = 0.0
+ p1.grad = g1
+ p2.grad = g2
+ original_p1 = p1[mask].clone()
+ original_p2 = p2[mask].clone()
+ og_s1 = adam2.state[p2]['state1'][mask].clone()
+ og_s2 = adam2.state[p2]['state2'][mask].clone()
+ og_s11 = adam2.state[p1]['state1'][mask].clone()
+ og_s21 = adam2.state[p1]['state2'][mask].clone()
+
adam2.step()
assert adam2.state[p3]['state1'].dtype == torch.uint8
assert adam2.state[p3]['state2'].dtype == torch.uint8
+ if i > 30 and i % 10 == 0:
+ torch.testing.assert_allclose(original_p2, p2[mask])
+ torch.testing.assert_allclose(adam2.state[p2]['state1'][mask], og_s1)
+ torch.testing.assert_allclose(adam2.state[p2]['state2'][mask], og_s2)
+ assert ((p1[mask]- original_p1)==0.0).sum() < p1.numel()
+ assert ((adam2.state[p1]['state1'][mask]- og_s11)==0.0).sum() == 0.0
+ assert ((adam2.state[p1]['state2'][mask]- og_s21)==0.0).sum() == 0.0
+
+
dim1 = [1024]