From bfa0e33294f2b1dc25e65a33be2397f989824298 Mon Sep 17 00:00:00 2001 From: Titus von Koeller Date: Mon, 1 Aug 2022 03:31:48 -0700 Subject: ran black and isort for coherent code formatting --- tests/test_optim.py | 397 +++++++++++++++++++++++++++++++++++----------------- 1 file changed, 265 insertions(+), 132 deletions(-) (limited to 'tests/test_optim.py') diff --git a/tests/test_optim.py b/tests/test_optim.py index b173eaa..b84425e 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -1,81 +1,132 @@ +import ctypes import os -import time import shutil +import time import uuid +from itertools import product +from os.path import join + import pytest -import ctypes import torch + import bitsandbytes as bnb import bitsandbytes.functional as F -from os.path import join -from itertools import product - -#import apex +# import apex k = 20 + def get_temp_dir(): - path = '/tmp/autoswap/{0}'.format(str(uuid.uuid4())) + path = "/tmp/autoswap/{0}".format(str(uuid.uuid4())) os.makedirs(path, exist_ok=True) return path + def rm_path(path): shutil.rmtree(path) + str2optimizers = {} -str2optimizers['adam_pytorch'] = (None, torch.optim.Adam, bnb.optim.Adam) -#str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam) -#str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam) -str2optimizers['momentum_pytorch'] = (None, lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), bnb.optim.Adam) -#str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam) -#str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam) - -str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam) -#str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam) -str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False)) -str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9)) -#str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB) -str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False)) -str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False)) -str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False)) -str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False)) -#str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit) -str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9)) - -str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True)) -str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True)) -str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True)) +str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam) +# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam) +# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam) +str2optimizers["momentum_pytorch"] = ( + None, + lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), + bnb.optim.Adam, +) +# str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam) +# str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam) + +str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam) +# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam) +str2optimizers["momentum"] = ( + lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), + lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False), +) +str2optimizers["lars"] = ( + lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), + lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9), +) +# str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB) +str2optimizers["rmsprop"] = ( + lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), + lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False), +) +str2optimizers["adam8bit"] = ( + torch.optim.Adam, + lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False), +) +str2optimizers["momentum8bit"] = ( + lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), + lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False), +) +str2optimizers["rmsprop8bit"] = ( + lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), + lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False), +) +# str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit) +str2optimizers["lars8bit"] = ( + lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), + lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9), +) + +str2optimizers["adam8bit_blockwise"] = ( + torch.optim.Adam, + lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True), +) +str2optimizers["momentum8bit_blockwise"] = ( + lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), + lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True), +) +str2optimizers["rmsprop8bit_blockwise"] = ( + lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), + lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True), +) str2statenames = {} -str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')] -str2statenames['momentum'] = [('momentum_buffer', 'state1')] -str2statenames['lars'] = [('momentum_buffer', 'state1')] -str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')] -str2statenames['rmsprop'] = [('square_avg', 'state1')] -str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')] -str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')] -str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')] -str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')] -str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')] -str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')] -str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')] -str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')] +str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] +str2statenames["momentum"] = [("momentum_buffer", "state1")] +str2statenames["lars"] = [("momentum_buffer", "state1")] +str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] +str2statenames["rmsprop"] = [("square_avg", "state1")] +str2statenames["adam8bit"] = [ + ("exp_avg", "state1", "qmap1", "max1"), + ("exp_avg_sq", "state2", "qmap2", "max2"), +] +str2statenames["lamb8bit"] = [ + ("exp_avg", "state1", "qmap1", "max1"), + ("exp_avg_sq", "state2", "qmap2", "max2"), +] +str2statenames["adam8bit_blockwise"] = [ + ("exp_avg", "state1", "qmap1", "absmax1"), + ("exp_avg_sq", "state2", "qmap2", "absmax2"), +] +str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")] +str2statenames["momentum8bit_blockwise"] = [ + ("momentum_buffer", "state1", "qmap1", "absmax1") +] +str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")] +str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")] +str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")] dim1 = [1024] dim2 = [32, 1024, 4097, 1] gtype = [torch.float32, torch.float16] -optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb'] -values = list(product(dim1,dim2, gtype, optimizer_names)) -names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values] +optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lamb"] +values = list(product(dim1, dim2, gtype, optimizer_names)) +names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) def test_optimizer32bit(dim1, dim2, gtype, optim_name): - if dim1 == 1 and dim2 == 1: return - p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + if dim1 == 1 and dim2 == 1: + return + p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 p2 = p1.clone() p1 = p1.float() - torch_optimizer = str2optimizers[optim_name][0]([p1]) bnb_optimizer = str2optimizers[optim_name][1]([p2]) @@ -84,9 +135,8 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): else: atol, rtol = 1e-4, 1e-3 - for i in range(k): - g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01 + g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01 p1.grad = g.clone().float() p2.grad = g.clone() @@ -94,21 +144,31 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): torch_optimizer.step() for name1, name2 in str2statenames[optim_name]: - torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol) + torch.testing.assert_allclose( + torch_optimizer.state[p1][name1], + bnb_optimizer.state[p2][name2], + atol=atol, + rtol=rtol, + ) torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol) - if i % (k//5) == 0 and i > 0: + if i % (k // 5) == 0 and i > 0: path = get_temp_dir() - torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt')) + torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt")) del bnb_optimizer bnb_optimizer = None bnb_optimizer = str2optimizers[optim_name][1]([p2]) - bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt'))) + bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt"))) rm_path(path) torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol) for name1, name2 in str2statenames[optim_name]: - torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol) + torch.testing.assert_allclose( + torch_optimizer.state[p1][name1], + bnb_optimizer.state[p2][name2], + atol=atol, + rtol=rtol, + ) if gtype == torch.float16: # the adam buffers should also be close because they are 32-bit @@ -118,20 +178,24 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): p1.data = p1.data.half().float() p2.copy_(p1.data) torch.testing.assert_allclose(p1.half(), p2) - if optim_name in ['lars', 'lamb']: - assert bnb_optimizer.state[p2]['unorm_vec'] > 0.0 + if optim_name in ["lars", "lamb"]: + assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0 + dim1 = [1024] dim2 = [32, 1024, 4097] gtype = [torch.float32, torch.float16] -values = list(product(dim1,dim2, gtype)) -names = ['dim1_{0}_dim2_{1}_gtype_{2}'.format(*vals) for vals in values] +values = list(product(dim1, dim2, gtype)) +names = ["dim1_{0}_dim2_{1}_gtype_{2}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names) def test_global_config(dim1, dim2, gtype): - if dim1 == 1 and dim2 == 1: return - p1 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1 - p2 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1 - p3 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1 + if dim1 == 1 and dim2 == 1: + return + p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1 + p2 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1 + p3 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1 mask = torch.rand_like(p2) < 0.1 beta1 = 0.9 beta2 = 0.999 @@ -139,7 +203,7 @@ def test_global_config(dim1, dim2, gtype): eps = 1e-8 bnb.optim.GlobalOptimManager.get_instance().initialize() - bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8) + bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8) bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3]) p1 = p1.cuda() @@ -154,30 +218,41 @@ def test_global_config(dim1, dim2, gtype): atol, rtol = 1e-4, 1e-3 for i in range(50): - g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 - g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 - g3 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 + g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001 + g2 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001 + g3 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001 p1.grad = g1 p2.grad = g2 p3.grad = g3 adam2.step() - assert adam2.state[p3]['state1'].dtype == torch.uint8 - assert adam2.state[p3]['state2'].dtype == torch.uint8 - + assert adam2.state[p3]["state1"].dtype == torch.uint8 + assert adam2.state[p3]["state2"].dtype == torch.uint8 dim1 = [1024] dim2 = [32, 1024, 4097] gtype = [torch.float32, torch.float16] -optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise'] -values = list(product(dim1,dim2, gtype, optimizer_names)) -names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values] +optimizer_names = [ + "adam8bit", + "momentum8bit", + "rmsprop8bit", + "adam8bit_blockwise", + "lamb8bit", + "lars8bit", + "momentum8bit_blockwise", + "rmsprop8bit_blockwise", +] +values = list(product(dim1, dim2, gtype, optimizer_names)) +names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) def test_optimizer8bit(dim1, dim2, gtype, optim_name): - if dim1 == 1 and dim2 == 1: return - p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + if dim1 == 1 and dim2 == 1: + return + p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 p2 = p1.clone() p1 = p1.float() blocksize = 2048 @@ -197,7 +272,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): relerrors = [] for i in range(50): - g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01 + g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01 p1.grad = g.clone().float() p2.grad = g.clone() @@ -208,17 +283,31 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): dequant_states = [] for name1, name2, qmap, max_val in str2statenames[optim_name]: - #print(bnb_optimizer.state[p2][max_val], name1) - if 'blockwise' in optim_name: - s1 = F.dequantize_blockwise(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], blocksize=blocksize) + # print(bnb_optimizer.state[p2][max_val], name1) + if "blockwise" in optim_name: + s1 = F.dequantize_blockwise( + code=bnb_optimizer.state[p2][qmap], + absmax=bnb_optimizer.state[p2][max_val], + A=bnb_optimizer.state[p2][name2], + blocksize=blocksize, + ) else: - s1 = F.dequantize(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2]) - num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol)==0 + s1 = F.dequantize( + code=bnb_optimizer.state[p2][qmap], + absmax=bnb_optimizer.state[p2][max_val], + A=bnb_optimizer.state[p2][name2], + ) + num_not_close = ( + torch.isclose( + torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol + ) + == 0 + ) assert num_not_close.sum().item() < 20 dequant_states.append(s1.clone()) - err = torch.abs(p1-p2) - relerr = err/torch.abs(p1) + err = torch.abs(p1 - p2) + relerr = err / torch.abs(p1) assert err.mean() < 0.0001 assert relerr.mean() < 0.001 @@ -226,28 +315,44 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): relerrors.append(relerr.mean().item()) if i % 10 == 0 and i > 0: - for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states): + for (name1, name2, qmap, max_val), s in zip( + str2statenames[optim_name], dequant_states + ): s1cpy = s.clone() raws1cpy = bnb_optimizer.state[p2][name2].clone() qmap1 = bnb_optimizer.state[p2][qmap].clone() path = get_temp_dir() - torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt')) + torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt")) del bnb_optimizer bnb_optimizer = None bnb_optimizer = str2optimizers[optim_name][1]([p2]) - bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt'))) + bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt"))) rm_path(path) torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2]) torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap]) - if 'blockwise' in optim_name: - s1 = F.dequantize_blockwise(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], blocksize=blocksize) + if "blockwise" in optim_name: + s1 = F.dequantize_blockwise( + code=bnb_optimizer.state[p2][qmap], + absmax=bnb_optimizer.state[p2][max_val], + A=bnb_optimizer.state[p2][name2], + blocksize=blocksize, + ) else: - s1 = F.dequantize(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2]) + s1 = F.dequantize( + code=bnb_optimizer.state[p2][qmap], + absmax=bnb_optimizer.state[p2][max_val], + A=bnb_optimizer.state[p2][name2], + ) torch.testing.assert_allclose(s1cpy, s1) - num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol)==0 + num_not_close = ( + torch.isclose( + torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol + ) + == 0 + ) assert num_not_close.sum().item() < 20 torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol) @@ -256,24 +361,28 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): p1.data = p1.data.to(gtype).float() p2.copy_(p1.data) torch.testing.assert_allclose(p1.to(gtype), p2) - for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states): + for (name1, name2, qmap, max_val), s in zip( + str2statenames[optim_name], dequant_states + ): torch_optimizer.state[p1][name1].copy_(s.data) - #print(sum(errors)/len(errors)) - #print(sum(relerrors)/len(relerrors)) - + # print(sum(errors)/len(errors)) + # print(sum(relerrors)/len(relerrors)) dim1 = [1024] dim2 = [32, 1024, 4097] gtype = [torch.float32] optim_bits = [32, 8] -values = list(product(dim1,dim2, gtype, optim_bits)) -names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}'.format(*vals) for vals in values] +values = list(product(dim1, dim2, gtype, optim_bits)) +names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names) def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): - if dim1 == 1 and dim2 == 1: return - p1 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1 + if dim1 == 1 and dim2 == 1: + return + p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1 beta1 = 0.9 beta2 = 0.999 lr = 0.001 @@ -281,19 +390,23 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): p1 = p1.cuda() p2 = p1.clone() adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits) - adam2 = bnb.optim.Adam([p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5) + adam2 = bnb.optim.Adam( + [p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5 + ) gnorm_vec = torch.zeros(100).cuda() step = 0 for i in range(50): step += 1 - g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + (0.01*i) + g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i) g2 = g1.clone() p2.grad = g2 - current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5) - g1 = (g1.float()*gnorm_scale).to(gtype) + current_gnorm, clip_val, gnorm_scale = F.percentile_clipping( + g1, gnorm_vec, step, 5 + ) + g1 = (g1.float() * gnorm_scale).to(gtype) p1.grad = g1 adam1.step() @@ -302,47 +415,69 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state if optim_bits == 32: torch.testing.assert_allclose(p1, p2) - torch.testing.assert_allclose(adam1.state[p1]['state1'], adam2.state[p2]['state1'], atol=5e-5, rtol=1e-4) - torch.testing.assert_allclose(adam1.state[p1]['state2'], adam2.state[p2]['state2'], atol=5e-5, rtol=1e-4) + torch.testing.assert_allclose( + adam1.state[p1]["state1"], + adam2.state[p2]["state1"], + atol=5e-5, + rtol=1e-4, + ) + torch.testing.assert_allclose( + adam1.state[p1]["state2"], + adam2.state[p2]["state2"], + atol=5e-5, + rtol=1e-4, + ) elif optim_bits == 8: torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3) - torch.testing.assert_allclose(adam1.state[p1]['state1'], adam2.state[p2]['state1'], atol=2, rtol=1e-3) - torch.testing.assert_allclose(adam1.state[p1]['state2'], adam2.state[p2]['state2'], atol=2, rtol=1e-3) - adam1.state[p1]['state1'].copy_(adam2.state[p2]['state1']) - adam1.state[p1]['state2'].copy_(adam2.state[p2]['state2']) + torch.testing.assert_allclose( + adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=2, rtol=1e-3 + ) + torch.testing.assert_allclose( + adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=2, rtol=1e-3 + ) + adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"]) + adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"]) if i % 10 == 0 and i > 0: path = get_temp_dir() - torch.save(adam2.state_dict(),join(path, 'opt.pt')) + torch.save(adam2.state_dict(), join(path, "opt.pt")) del adam2 adam2 = None - adam2 = bnb.optim.Adam([p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5) - adam2.load_state_dict(torch.load(join(path, 'opt.pt'))) - - + adam2 = bnb.optim.Adam( + [p2], + lr, + (beta1, beta2), + eps, + optim_bits=optim_bits, + percentile_clipping=5, + ) + adam2.load_state_dict(torch.load(join(path, "opt.pt"))) dim1 = [4096] dim2 = [4096] gtype = [torch.float32, torch.float16] -#optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit'] -#optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch'] -#optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch'] -#optimizer_names = ['lamb_apex', 'lamb8bit'] -#optimizer_names = ['lars_apex', 'lars8bit'] -optimizer_names = ['adam8bit_blockwise'] -values = list(product(dim1,dim2, gtype, optimizer_names)) -names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values] +# optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit'] +# optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch'] +# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch'] +# optimizer_names = ['lamb_apex', 'lamb8bit'] +# optimizer_names = ['lars_apex', 'lars8bit'] +optimizer_names = ["adam8bit_blockwise"] +values = list(product(dim1, dim2, gtype, optimizer_names)) +names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): - if dim1 == 1 and dim2 == 1: return - p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + if dim1 == 1 and dim2 == 1: + return + p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 bnb_optimizer = str2optimizers[optim_name][1]([p1]) - g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01 + g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01 p1.grad = g for i in range(k): - if i == k//5: + if i == k // 5: # 100 iterations for burn-in torch.cuda.synchronize() t0 = time.time() @@ -350,10 +485,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): bnb_optimizer.step() torch.cuda.synchronize() - s = time.time()-t0 - print('') - params = (k-k//5)*dim1*dim2 - print(optim_name, gtype, s/params) - #assert s < 3.9 - - + s = time.time() - t0 + print("") + params = (k - k // 5) * dim1 * dim2 + print(optim_name, gtype, s / params) + # assert s < 3.9 -- cgit v1.2.3