diff options
author | Tim Dettmers <dettmers@cs.washington.edu> | 2021-10-20 18:37:44 -0700 |
---|---|---|
committer | Tim Dettmers <dettmers@cs.washington.edu> | 2021-10-20 18:37:44 -0700 |
commit | bb34fd50a1fec74e62beb6e23d51f0142c7d0ab6 (patch) | |
tree | a01ed945c348027480a9d0cefb6698dfd7259fb1 /tests | |
parent | 8400b58cbbc06e0a434cfa71f76c2efd713473fc (diff) |
Initial plumbing for skip_zeros.
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_optim.py | 24 |
1 file changed, 24 insertions, 0 deletions
diff --git a/tests/test_optim.py b/tests/test_optim.py
index 4d67b08..fc2456f 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -141,6 +141,7 @@ def test_global_config(dim1, dim2, gtype):
     eps = 1e-8

     bnb.optim.GlobalOptimManager.get_instance().initialize()
+    bnb.optim.GlobalOptimManager.get_instance().override_config(p2, 'skip_zeros', True)
     bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8)

     bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
@@ -155,6 +156,8 @@ def test_global_config(dim1, dim2, gtype):
     else:
         atol, rtol = 1e-4, 1e-3

+    original_p2 = p2[mask].clone()
+
     for i in range(50):
         g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
         g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
@@ -163,11 +166,32 @@ def test_global_config(dim1, dim2, gtype):
         p2.grad = g2
         p3.grad = g3

+        if i > 30 and i % 10 == 0:
+            g1.data[mask] = 0.0
+            g2.data[mask] = 0.0
+            p1.grad = g1
+            p2.grad = g2
+            original_p1 = p1[mask].clone()
+            original_p2 = p2[mask].clone()
+            og_s1 = adam2.state[p2]['state1'][mask].clone()
+            og_s2 = adam2.state[p2]['state2'][mask].clone()
+            og_s11 = adam2.state[p1]['state1'][mask].clone()
+            og_s21 = adam2.state[p1]['state2'][mask].clone()
+
         adam2.step()

         assert adam2.state[p3]['state1'].dtype == torch.uint8
         assert adam2.state[p3]['state2'].dtype == torch.uint8

+        if i > 30 and i % 10 == 0:
+            torch.testing.assert_allclose(original_p2, p2[mask])
+            torch.testing.assert_allclose(adam2.state[p2]['state1'][mask], og_s1)
+            torch.testing.assert_allclose(adam2.state[p2]['state2'][mask], og_s2)
+            assert ((p1[mask]- original_p1)==0.0).sum() < p1.numel()
+            assert ((adam2.state[p1]['state1'][mask]- og_s11)==0.0).sum() == 0.0
+            assert ((adam2.state[p1]['state2'][mask]- og_s21)==0.0).sum() == 0.0
+
+

 dim1 = [1024]