From ea7c14f8ef64924f2d0ff80df3cdabf2c7299848 Mon Sep 17 00:00:00 2001
From: Titus von Koeller
Date: Mon, 1 Aug 2022 09:32:47 -0700
Subject: reran black with linelength 80 for greater readability

---
 tests/test_optim.py | 71 +++++++++++++++++++++++++++++++++++++++++------------
 1 file changed, 55 insertions(+), 16 deletions(-)

(limited to 'tests/test_optim.py')

diff --git a/tests/test_optim.py b/tests/test_optim.py
index b84425e..8e12761 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -103,20 +103,26 @@ str2statenames["adam8bit_blockwise"] = [
     ("exp_avg", "state1", "qmap1", "absmax1"),
     ("exp_avg_sq", "state2", "qmap2", "absmax2"),
 ]
-str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
+str2statenames["momentum8bit"] = [
+    ("momentum_buffer", "state1", "qmap1", "max1")
+]
 str2statenames["momentum8bit_blockwise"] = [
     ("momentum_buffer", "state1", "qmap1", "absmax1")
 ]
 str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
 str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
-str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
+str2statenames["rmsprop8bit_blockwise"] = [
+    ("square_avg", "state1", "qmap1", "absmax1")
+]
 
 dim1 = [1024]
 dim2 = [32, 1024, 4097, 1]
 gtype = [torch.float32, torch.float16]
 optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lamb"]
 values = list(product(dim1, dim2, gtype, optimizer_names))
-names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
+names = [
+    "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+]
 
 
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -203,9 +209,13 @@ def test_global_config(dim1, dim2, gtype):
     eps = 1e-8
 
     bnb.optim.GlobalOptimManager.get_instance().initialize()
-    bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8)
+    bnb.optim.GlobalOptimManager.get_instance().override_config(
+        p3, "optim_bits", 8
+    )
 
-    bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
+    bnb.optim.GlobalOptimManager.get_instance().register_parameters(
+        [p1, p2, p3]
+    )
     p1 = p1.cuda()
     p2 = p2.cuda()
     p3 = p3.cuda()
@@ -245,7 +255,9 @@ optimizer_names = [
     "rmsprop8bit_blockwise",
 ]
 values = list(product(dim1, dim2, gtype, optimizer_names))
-names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
+names = [
+    "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+]
 
 
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -329,8 +341,12 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
             bnb_optimizer = str2optimizers[optim_name][1]([p2])
             bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
             rm_path(path)
-            torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2])
-            torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap])
+            torch.testing.assert_allclose(
+                raws1cpy, bnb_optimizer.state[p2][name2]
+            )
+            torch.testing.assert_allclose(
+                qmap1, bnb_optimizer.state[p2][qmap]
+            )
 
             if "blockwise" in optim_name:
                 s1 = F.dequantize_blockwise(
@@ -349,12 +365,17 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
 
             num_not_close = (
                 torch.isclose(
-                    torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
+                    torch_optimizer.state[p1][name1],
+                    s1,
+                    atol=atol,
+                    rtol=rtol,
                 )
                 == 0
            )
             assert num_not_close.sum().item() < 20
-        torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
+        torch.testing.assert_allclose(
+            p1, p2.float(), atol=patol, rtol=prtol
+        )
 
         # the parameters diverge quickly. Here we keep them close
         # together so we can test against the Adam error
@@ -375,7 +396,10 @@ dim2 = [32, 1024, 4097]
 gtype = [torch.float32]
 optim_bits = [32, 8]
 values = list(product(dim1, dim2, gtype, optim_bits))
-names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals) for vals in values]
+names = [
+    "dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals)
+    for vals in values
+]
 
 
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
@@ -391,7 +415,12 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
     p2 = p1.clone()
     adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
     adam2 = bnb.optim.Adam(
-        [p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5
+        [p2],
+        lr,
+        (beta1, beta2),
+        eps,
+        optim_bits=optim_bits,
+        percentile_clipping=5,
     )
 
     gnorm_vec = torch.zeros(100).cuda()
@@ -399,7 +428,9 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
 
     for i in range(50):
         step += 1
-        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i)
+        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (
+            0.01 * i
+        )
         g2 = g1.clone()
         p2.grad = g2
 
@@ -430,10 +461,16 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
         elif optim_bits == 8:
             torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
             torch.testing.assert_allclose(
-                adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=2, rtol=1e-3
+                adam1.state[p1]["state1"],
+                adam2.state[p2]["state1"],
+                atol=2,
+                rtol=1e-3,
             )
             torch.testing.assert_allclose(
-                adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=2, rtol=1e-3
+                adam1.state[p1]["state2"],
+                adam2.state[p2]["state2"],
+                atol=2,
+                rtol=1e-3,
             )
             adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
             adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])
@@ -463,7 +500,9 @@ gtype = [torch.float32, torch.float16]
 # optimizer_names = ['lars_apex', 'lars8bit']
 optimizer_names = ["adam8bit_blockwise"]
 values = list(product(dim1, dim2, gtype, optimizer_names))
-names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
+names = [
+    "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+]
 
 
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
--
cgit v1.2.3
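
Note: the patch above records only the result of the formatting pass; the exact black invocation is not part of the commit. Below is a minimal sketch of how such a pass could be reproduced, assuming the black package is installed. Only the file path and the 80-column limit come from this commit; using black's Python API instead of its command line is an assumption for illustration.

    # Hypothetical reproduction of the reformatting pass; not part of the commit.
    # black.format_str() and black.Mode are part of black's public Python API.
    import black

    path = "tests/test_optim.py"
    with open(path) as f:
        source = f.read()

    # Reformat with a maximum line length of 80 columns,
    # matching the commit subject.
    formatted = black.format_str(source, mode=black.Mode(line_length=80))

    with open(path, "w") as f:
        f.write(formatted)

The equivalent command-line call would be: black --line-length 80 tests/test_optim.py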