From bb34fd50a1fec74e62beb6e23d51f0142c7d0ab6 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 20 Oct 2021 18:37:44 -0700 Subject: Initial plumbing for skip_zeros. --- tests/test_optim.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) (limited to 'tests') diff --git a/tests/test_optim.py b/tests/test_optim.py index 4d67b08..fc2456f 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -141,6 +141,7 @@ def test_global_config(dim1, dim2, gtype): eps = 1e-8 bnb.optim.GlobalOptimManager.get_instance().initialize() + bnb.optim.GlobalOptimManager.get_instance().override_config(p2, 'skip_zeros', True) bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8) bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3]) @@ -155,6 +156,8 @@ def test_global_config(dim1, dim2, gtype): else: atol, rtol = 1e-4, 1e-3 + original_p2 = p2[mask].clone() + for i in range(50): g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 @@ -163,11 +166,32 @@ def test_global_config(dim1, dim2, gtype): p2.grad = g2 p3.grad = g3 + if i > 30 and i % 10 == 0: + g1.data[mask] = 0.0 + g2.data[mask] = 0.0 + p1.grad = g1 + p2.grad = g2 + original_p1 = p1[mask].clone() + original_p2 = p2[mask].clone() + og_s1 = adam2.state[p2]['state1'][mask].clone() + og_s2 = adam2.state[p2]['state2'][mask].clone() + og_s11 = adam2.state[p1]['state1'][mask].clone() + og_s21 = adam2.state[p1]['state2'][mask].clone() + adam2.step() assert adam2.state[p3]['state1'].dtype == torch.uint8 assert adam2.state[p3]['state2'].dtype == torch.uint8 + if i > 30 and i % 10 == 0: + torch.testing.assert_allclose(original_p2, p2[mask]) + torch.testing.assert_allclose(adam2.state[p2]['state1'][mask], og_s1) + torch.testing.assert_allclose(adam2.state[p2]['state2'][mask], og_s2) + assert ((p1[mask]- original_p1)==0.0).sum() < p1.numel() + assert ((adam2.state[p1]['state1'][mask]- og_s11)==0.0).sum() == 0.0 + assert ((adam2.state[p1]['state2'][mask]- og_s21)==0.0).sum() == 0.0 + + dim1 = [1024] -- cgit v1.2.3