diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2021-10-05 19:16:20 -0700 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2021-10-05 19:16:20 -0700 |
commit | 7439924891496025edf60c9da6a782f362a50c70 (patch) | |
tree | 90476984d2c267f89232577a2ea40eb172387475 /tests |
Initial commit
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_functional.py | 213 | ||||
-rw-r--r-- | tests/test_optim.py | 362 |
2 files changed, 575 insertions, 0 deletions
diff --git a/tests/test_functional.py b/tests/test_functional.py new file mode 100644 index 0000000..2a7d308 --- /dev/null +++ b/tests/test_functional.py @@ -0,0 +1,213 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import pytest +import torch +import bitsandbytes as bnb + +from itertools import product + +from bitsandbytes import functional as F + +def setup(): + pass + +def teardown(): + pass + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['float', 'half']) +def test_estimate_quantiles(dtype): + A = torch.rand(1024, 1024, device='cuda') + A = A.to(dtype) + code = F.estimate_quantiles(A) + + percs = torch.linspace(1/512, 511/512, 256, device=A.device) + torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2) + + A = torch.randn(1024, 1024, device='cuda') + A = A.to(dtype) + code = F.estimate_quantiles(A) + + quantiles = torch.quantile(A.float(), percs) + diff = torch.abs(code-quantiles) + assert (diff > 5e-02).sum().item() == 0 + + +def test_quantile_quantization(): + for i in range(100): + A1 = torch.randn(1024, 1024, device='cuda') + code = F.estimate_quantiles(A1) + C = F.quantize_no_absmax(A1, code) + A2 = F.dequantize_no_absmax(C, code) + diff = torch.abs(A1-A2).mean().item() + assert diff < 0.0075 + + A1 = torch.rand(1024, 1024, device='cuda') + code = F.estimate_quantiles(A1) + C = F.quantize_no_absmax(A1, code) + A2 = F.dequantize_no_absmax(C, code) + diff = torch.abs(A1-A2).mean().item() + torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0) + assert diff < 0.001 + + +def test_dynamic_quantization(): + diffs = [] + reldiffs = [] + for i in range(100): + A1 = torch.randn(1024, 1024, device='cuda') + C, S = F.quantize(A1) + A2 = F.dequantize(C, S) + diff = torch.abs(A1-A2) + reldiff = diff/torch.abs(A1+1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + assert diff.mean().item() < 0.0135 + print(sum(diffs)/len(diffs)) + print(sum(reldiffs)/len(reldiffs)) + + for i in range(100): + A1 = torch.rand(1024, 1024, device='cuda') + C, S = F.quantize(A1) + A2 = F.dequantize(C, S) + diff = torch.abs(A1-A2).mean().item() + torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) + assert diff < 0.004 + + +def test_dynamic_blockwise_quantization(): + diffs = [] + reldiffs = [] + for i in range(100): + A1 = torch.randn(1024, 1024, device='cuda') + C, S = F.quantize_blockwise(A1) + A2 = F.dequantize_blockwise(C, S) + diff = torch.abs(A1-A2) + reldiff = diff/torch.abs(A1+1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + assert diffs[-1] < 0.011 + print(sum(diffs)/len(diffs)) + print(sum(reldiffs)/len(reldiffs)) + + diffs = [] + for i in range(100): + A1 = torch.rand(1024, 1024, device='cuda') + C, S = F.quantize_blockwise(A1) + A2 = F.dequantize_blockwise(C, S) + diff = torch.abs(A1-A2).mean().item() + assert diff < 0.0033 + diffs.append(diff) + torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) + #print(sum(diffs)/len(diffs)) + +def test_dynamic_blockwise_stochastic_quantization(): + diffs = [] + reldiffs = [] + rand = torch.rand(1024).cuda() + for i in range(100): + A1 = torch.randn(1024, 1024, device='cuda') + C1, S1 = F.quantize_blockwise(A1, rand=rand) + C2, S2 = F.quantize_blockwise(A1) + # a maximunm distance of quantized values of 1 + torch.testing.assert_allclose(C1, C2, atol=1, rtol=0) + fraction_smaller = (C1<C2).float().sum()/C1.numel() + fraction_larger = (C1>C2).float().sum()/C1.numel() + torch.testing.assert_allclose(fraction_larger, fraction_smaller, atol=0.01, rtol=0) + + + +@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=['float', 'half']) +def test_percentile_clipping(gtype): + gnorm_vec1 = torch.zeros(100, device='cuda') + gnorm_vec2 = torch.zeros(100, device='cuda') + n = 4 + step = 0 + percentile=5 + for i in range(1000): + step += 1 + g = torch.randn(n, n, dtype=gtype, device='cuda') + gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile) + assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2/gnorm1 + + gnorm2 = torch.norm(g.float()) + if step == 1: + gnorm_vec1[:] = gnorm2 + else: + gnorm_vec1[step % 100] = gnorm2 + + vals, idx = torch.sort(gnorm_vec1) + clip1 = vals[percentile] + + torch.testing.assert_allclose(gnorm_vec1, torch.sqrt(gnorm_vec2)) + torch.testing.assert_allclose(clip1, clip2) + torch.testing.assert_allclose(gnorm1, gnorm2) + + +def test_stable_embedding(): + layer = bnb.nn.StableEmbedding(1024, 1024) + layer.reset_parameters() + + +def test_dynamic_blockwise_quantization_cpu(): + #A1 = torch.randn(1024, 1024, device='cpu') + #code = F.create_dynamic_map() + #for i in range(1000): + # C, S = F.quantize_blockwise(A1, code=code) + # A2 = F.dequantize_blockwise(C, S) + + for i in range(10): + # equivalence with GPU blockwise quantization + A1 = torch.randn(1024, 1024, device='cpu') + C1, S1 = F.quantize_blockwise(A1) + C2, S2 = F.quantize_blockwise(A1.cuda()) + torch.testing.assert_allclose(S1[0], S2[0].cpu()) + # there seems to be some issues with precision in CUDA vs CPU + # not all elements are usually close, with couple off elements in a million + idx = torch.isclose(C1, C2.cpu()) + assert (idx==0).sum().item() < 15 + + + diffs = [] + reldiffs = [] + for i in range(10): + A1 = torch.randn(1024, 1024, device='cpu') + C, S = F.quantize_blockwise(A1) + A2 = F.dequantize_blockwise(C, S) + diff = torch.abs(A1-A2) + reldiff = diff/torch.abs(A1+1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + assert diffs[-1] < 0.011 + #print(sum(diffs)/len(diffs)) + #print(sum(reldiffs)/len(reldiffs)) + + diffs = [] + for i in range(10): + A1 = torch.rand(1024, 1024, device='cpu') + C, S = F.quantize_blockwise(A1) + A2 = F.dequantize_blockwise(C, S) + diff = torch.abs(A1-A2).mean().item() + assert diff < 0.0033 + diffs.append(diff) + torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) + #print(sum(diffs)/len(diffs)) + + +def test_histogram(): + dim1, dim2 = 32, 32 + source = torch.rand(dim1, dim2, device='cuda') + idx1 = torch.randint(0, 255, size=(dim1, dim2), device='cuda').int() + idx2 = torch.randint(0, 255, size=(dim1, dim2), device='cuda').int() + histogram1 = torch.zeros((256, 256)).cuda() + histogram2 = torch.zeros((256, 256)).cuda() + + F.histogram_scatter_add_2d(histogram2, idx1, idx2, source) + + for i in range(dim1): + for j in range(dim2): + histogram1[idx1[i, j].item(), idx2[i, j].item()] += source[i, j] + + torch.testing.assert_allclose(histogram1, histogram2) + torch.testing.assert_allclose(histogram1.sum(), source.sum()) diff --git a/tests/test_optim.py b/tests/test_optim.py new file mode 100644 index 0000000..4d67b08 --- /dev/null +++ b/tests/test_optim.py @@ -0,0 +1,362 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import os +import time +import shutil +import uuid +import pytest +import ctypes +import torch +import bitsandbytes as bnb +import bitsandbytes.functional as F + +from os.path import join +from itertools import product + +import apex + +def get_temp_dir(): + path = '/tmp/autoswap/{0}'.format(str(uuid.uuid4())) + os.makedirs(path, exist_ok=True) + return path + +def rm_path(path): + shutil.rmtree(path) + +str2optimizers = {} +str2optimizers['adam_pytorch'] = (None, torch.optim.Adam, bnb.optim.Adam) +str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam) +str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam) +str2optimizers['momentum_pytorch'] = (None, lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), bnb.optim.Adam) +str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam) +str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam) + +str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam) +str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam) +str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False)) +str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9)) +str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB) +str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False)) +str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False)) +str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False)) +str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False)) +str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit) +str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9)) + +str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True)) +str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True)) +str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True)) + +str2statenames = {} +str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')] +str2statenames['momentum'] = [('momentum_buffer', 'state1')] +str2statenames['lars'] = [('momentum_buffer', 'state1')] +str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')] +str2statenames['rmsprop'] = [('square_avg', 'state1')] +str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')] +str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')] +str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')] +str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')] +str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')] +str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')] +str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')] +str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')] + +dim1 = [1024] +dim2 = [32, 1024, 4097, 1] +gtype = [torch.float32, torch.float16] +optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb'] +values = list(product(dim1,dim2, gtype, optimizer_names)) +names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) +def test_optimizer32bit(dim1, dim2, gtype, optim_name): + if dim1 == 1 and dim2 == 1: return + p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + p2 = p1.clone() + p1 = p1.float() + + + torch_optimizer = str2optimizers[optim_name][0]([p1]) + bnb_optimizer = str2optimizers[optim_name][1]([p2]) + + if gtype == torch.float32: + atol, rtol = 1e-6, 1e-5 + else: + atol, rtol = 1e-4, 1e-3 + + + for i in range(50): + g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01 + p1.grad = g.clone().float() + p2.grad = g.clone() + + bnb_optimizer.step() + torch_optimizer.step() + + for name1, name2 in str2statenames[optim_name]: + torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol) + + torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol) + + if i % 10 == 0 and i > 0: + path = get_temp_dir() + torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt')) + del bnb_optimizer + bnb_optimizer = None + bnb_optimizer = str2optimizers[optim_name][1]([p2]) + bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt'))) + rm_path(path) + torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol) + for name1, name2 in str2statenames[optim_name]: + torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol) + + if gtype == torch.float16: + # the adam buffers should also be close because they are 32-bit + # but the paramters can diverge because they are 16-bit + # the difference grow larger and larger with each update + # --> copy the state to keep weights close + p1.data = p1.data.half().float() + p2.copy_(p1.data) + torch.testing.assert_allclose(p1.half(), p2) + if optim_name in ['lars', 'lamb']: + assert bnb_optimizer.state[p2]['unorm_vec'] > 0.0 + +dim1 = [1024] +dim2 = [32, 1024, 4097] +gtype = [torch.float32, torch.float16] +values = list(product(dim1,dim2, gtype)) +names = ['dim1_{0}_dim2_{1}_gtype_{2}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names) +def test_global_config(dim1, dim2, gtype): + if dim1 == 1 and dim2 == 1: return + p1 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1 + p2 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1 + p3 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1 + mask = torch.rand_like(p2) < 0.1 + beta1 = 0.9 + beta2 = 0.999 + lr = 0.001 + eps = 1e-8 + + bnb.optim.GlobalOptimManager.get_instance().initialize() + bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8) + + bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3]) + p1 = p1.cuda() + p2 = p2.cuda() + p3 = p3.cuda() + + adam2 = bnb.optim.Adam([p1, p2, p3], lr, (beta1, beta2), eps) + + if gtype == torch.float32: + atol, rtol = 1e-6, 1e-5 + else: + atol, rtol = 1e-4, 1e-3 + + for i in range(50): + g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 + g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 + g3 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 + p1.grad = g1 + p2.grad = g2 + p3.grad = g3 + + adam2.step() + + assert adam2.state[p3]['state1'].dtype == torch.uint8 + assert adam2.state[p3]['state2'].dtype == torch.uint8 + + + +dim1 = [1024] +dim2 = [32, 1024, 4097] +gtype = [torch.float32, torch.float16] +optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise'] +values = list(product(dim1,dim2, gtype, optimizer_names)) +names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) +def test_optimizer8bit(dim1, dim2, gtype, optim_name): + if dim1 == 1 and dim2 == 1: return + p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + p2 = p1.clone() + p1 = p1.float() + blocksize = 2048 + + torch_optimizer = str2optimizers[optim_name][0]([p1]) + bnb_optimizer = str2optimizers[optim_name][1]([p2]) + + if gtype == torch.float32: + atol, rtol = 3e-3, 1e-3 + patol, prtol = 1e-5, 1e-3 + + else: + atol, rtol = 3e-3, 1e-3 + patol, prtol = 1e-5, 1e-3 + + errors = [] + relerrors = [] + + for i in range(50): + g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01 + p1.grad = g.clone().float() + p2.grad = g.clone() + + bnb_optimizer.step() + torch_optimizer.step() + + torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol) + + dequant_states = [] + for name1, name2, qmap, max_val in str2statenames[optim_name]: + #print(bnb_optimizer.state[p2][max_val], name1) + if 'blockwise' in optim_name: + s1 = F.dequantize_blockwise(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], blocksize=blocksize) + else: + s1 = F.dequantize(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2]) + num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol)==0 + assert num_not_close.sum().item() < 20 + dequant_states.append(s1.clone()) + + err = torch.abs(p1-p2) + relerr = err/torch.abs(p1) + assert err.mean() < 0.0001 + assert relerr.mean() < 0.001 + + errors.append(err.mean().item()) + relerrors.append(relerr.mean().item()) + + if i % 10 == 0 and i > 0: + for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states): + s1cpy = s.clone() + raws1cpy = bnb_optimizer.state[p2][name2].clone() + qmap1 = bnb_optimizer.state[p2][qmap].clone() + + path = get_temp_dir() + torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt')) + del bnb_optimizer + bnb_optimizer = None + bnb_optimizer = str2optimizers[optim_name][1]([p2]) + bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt'))) + rm_path(path) + torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2]) + torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap]) + + if 'blockwise' in optim_name: + s1 = F.dequantize_blockwise(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], blocksize=blocksize) + else: + s1 = F.dequantize(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2]) + torch.testing.assert_allclose(s1cpy, s1) + + num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol)==0 + assert num_not_close.sum().item() < 20 + torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol) + + # the parameters diverge quickly. Here we keep them close + # together so we can test against the Adam error + p1.data = p1.data.to(gtype).float() + p2.copy_(p1.data) + torch.testing.assert_allclose(p1.to(gtype), p2) + for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states): + torch_optimizer.state[p1][name1].copy_(s.data) + + #print(sum(errors)/len(errors)) + #print(sum(relerrors)/len(relerrors)) + + + +dim1 = [1024] +dim2 = [32, 1024, 4097] +gtype = [torch.float32] +optim_bits = [32, 8] +values = list(product(dim1,dim2, gtype, optim_bits)) +names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names) +def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): + if dim1 == 1 and dim2 == 1: return + p1 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1 + beta1 = 0.9 + beta2 = 0.999 + lr = 0.001 + eps = 1e-8 + p1 = p1.cuda() + p2 = p1.clone() + adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits) + adam2 = bnb.optim.Adam([p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5) + + gnorm_vec = torch.zeros(100).cuda() + step = 0 + + for i in range(50): + step += 1 + g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + (0.01*i) + g2 = g1.clone() + p2.grad = g2 + + current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5) + g1 = (g1.float()*gnorm_scale).to(gtype) + p1.grad = g1 + + adam1.step() + adam2.step() + + # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state + if optim_bits == 32: + torch.testing.assert_allclose(p1, p2) + torch.testing.assert_allclose(adam1.state[p1]['state1'], adam2.state[p2]['state1'], atol=5e-5, rtol=1e-4) + torch.testing.assert_allclose(adam1.state[p1]['state2'], adam2.state[p2]['state2'], atol=5e-5, rtol=1e-4) + elif optim_bits == 8: + torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3) + torch.testing.assert_allclose(adam1.state[p1]['state1'], adam2.state[p2]['state1'], atol=2, rtol=1e-3) + torch.testing.assert_allclose(adam1.state[p1]['state2'], adam2.state[p2]['state2'], atol=2, rtol=1e-3) + adam1.state[p1]['state1'].copy_(adam2.state[p2]['state1']) + adam1.state[p1]['state2'].copy_(adam2.state[p2]['state2']) + if i % 10 == 0 and i > 0: + path = get_temp_dir() + torch.save(adam2.state_dict(),join(path, 'opt.pt')) + del adam2 + adam2 = None + adam2 = bnb.optim.Adam([p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5) + adam2.load_state_dict(torch.load(join(path, 'opt.pt'))) + + + + +dim1 = [4096] +dim2 = [4096] +gtype = [torch.float32, torch.float16] +#optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit'] +#optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch'] +#optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch'] +#optimizer_names = ['lamb_apex', 'lamb8bit'] +#optimizer_names = ['lars_apex', 'lars8bit'] +optimizer_names = ['adam8bit_blockwise'] +values = list(product(dim1,dim2, gtype, optimizer_names)) +names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) +def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): + if dim1 == 1 and dim2 == 1: return + p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + + + bnb_optimizer = str2optimizers[optim_name][1]([p1]) + + g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01 + p1.grad = g + for i in range(5000): + if i == 500: + # 100 iterations for burn-in + torch.cuda.synchronize() + t0 = time.time() + + bnb_optimizer.step() + + torch.cuda.synchronize() + s = time.time()-t0 + print('') + params = 4500*4096*4096 + print(optim_name, gtype, s/params) + #assert s < 3.9 + + |