diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_optim.py | 29 |
1 files changed, 26 insertions, 3 deletions
diff --git a/tests/test_optim.py b/tests/test_optim.py index fc2456f..5464043 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -34,11 +34,13 @@ str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam) str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam) +str2optimizers['adamw'] = (torch.optim.AdamW, bnb.optim.AdamW) str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam) str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False)) str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9)) str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB) str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False)) +str2optimizers['adagrad'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad(pxx, 0.01, block_wise=False)) str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False)) str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False)) str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False)) @@ -46,28 +48,34 @@ str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_ str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9)) str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True)) +str2optimizers['adamw8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.AdamW8bit(pxx, block_wise=True)) str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True)) str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True)) +str2optimizers['adagrad8bit_blockwise'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad8bit(pxx, 0.01, block_wise=True)) str2statenames = {} str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')] +str2statenames['adamw'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')] str2statenames['momentum'] = [('momentum_buffer', 'state1')] str2statenames['lars'] = [('momentum_buffer', 'state1')] str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')] str2statenames['rmsprop'] = [('square_avg', 'state1')] +str2statenames['adagrad'] = [('sum', 'state1')] str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')] str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')] str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')] +str2statenames['adamw8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')] str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')] str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')] str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')] str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')] str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')] +str2statenames['adagrad8bit_blockwise'] = [('sum', 'state1', 'qmap1', 'absmax1')] dim1 = [1024] dim2 = [32, 1024, 4097, 1] gtype = [torch.float32, torch.float16] -optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb'] +optimizer_names = ['adam', 'adamw', 'momentum', 'rmsprop', 'lars', 'lamb', 'adagrad'] values = list(product(dim1,dim2, gtype, optimizer_names)) names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @@ -82,7 +90,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): bnb_optimizer = str2optimizers[optim_name][1]([p2]) if gtype == torch.float32: - atol, rtol = 1e-6, 1e-5 + atol, rtol = 2e-6, 1e-5 else: atol, rtol = 1e-4, 1e-3 @@ -197,7 +205,7 @@ def test_global_config(dim1, dim2, gtype): dim1 = [1024] dim2 = [32, 1024, 4097] gtype = [torch.float32, torch.float16] -optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise'] +optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'adamw8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise', 'adagrad8bit_blockwise'] values = list(product(dim1,dim2, gtype, optimizer_names)) names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @@ -384,3 +392,18 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): #assert s < 3.9 + +def test_str_betas(): + betas = (0.80, 0.95) + strbetas = '(0.80, 0.95)' + + layer = torch.nn.Linear(10, 10) + + base = bnb.optim.Adam(layer.parameters(), betas=betas) + strbase = bnb.optim.Adam(layer.parameters(), betas=strbetas) + assert base.defaults['betas'][0] == 0.8 + assert base.defaults['betas'][1] == 0.95 + assert strbase.defaults['betas'][0] == 0.8 + assert strbase.defaults['betas'][1] == 0.95 + + |