From c771b3a75a6ebbfbfc398a028a477246b0799cf0 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Fri, 22 Jul 2022 14:41:05 -0700 Subject: Most tests passing. --- tests/test_autograd.py | 270 +++++++ tests/test_functional.py | 1763 ++++++++++++++++++++++++++++++++++++++++++++-- tests/test_modules.py | 478 ++++++++++++- tests/test_optim.py | 87 +-- 4 files changed, 2446 insertions(+), 152 deletions(-) create mode 100644 tests/test_autograd.py (limited to 'tests') diff --git a/tests/test_autograd.py b/tests/test_autograd.py new file mode 100644 index 0000000..d2b5d59 --- /dev/null +++ b/tests/test_autograd.py @@ -0,0 +1,270 @@ +import pytest + +import torch +import bitsandbytes as bnb + +from itertools import product + +n = 1 +k = 25 +dim1 = torch.randint(16,64, size=(n,)).tolist() +dim2 = torch.randint(32,96, size=(n,)).tolist() +dim3 = torch.randint(32,96, size=(n,)).tolist() +dim4 = torch.randint(32,96, size=(n,)).tolist() +funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)] +str_funcs = ['bmm', 'matmul'] +req_grad = [(False, False), (True, False), (True, True), (False, True)] +req_grad_str = ['FF', 'TF', 'TT', 'FT'] +transpose = [(False, False), (False, True), (True, True), (True, False)] +str_transpose = ['FF', 'FT', 'TT', 'TF'] +dtype = [torch.float32, torch.float16] +values = list(product(dim1,dim2,dim3,dim4,funcs, dtype, req_grad, transpose)) +str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, str_transpose)) +names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}'.format(*vals) for vals in str_values] +@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names) +def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): + dim2 = dim2 - (dim2 % 16) + dim3 = dim3 - (dim3 % 16) + dim4 = dim4 - (dim4 % 16) + for i in range(k): + + # normal multiply + if funcs[0] in [torch.mm, torch.matmul]: + dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) + dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) + A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0]) + B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1]) + target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1]) + torch.nn.init.xavier_uniform_(B) + + if not transpose[0] and not transpose[1]: + out_torch = funcs[0](A, B) + out_bnb = funcs[1](A, B) + elif not transpose[0] and transpose[1]: + out_torch = funcs[0](A, B.t()) + out_bnb = funcs[1](A, B.t()) + elif transpose[0] and not transpose[1]: + out_torch = funcs[0](A.t(), B) + out_bnb = funcs[1](A.t(), B) + elif transpose[0] and transpose[1]: + out_torch = funcs[0](A.t(), B.t()) + out_bnb = funcs[1](A.t(), B.t()) + + n = out_bnb.numel() + idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) + assert (idx==0).sum().item() < n*0.0175 + idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) + assert (idx==0).sum().item() < n*0.001 + + if any(req_grad): + out_bnb.data.copy_(out_torch) + torch.cuda.synchronize() + loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() + loss_bnb.backward() + gradA1 = A.grad + gradB1 = B.grad + A.grad = None + B.grad = None + + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + loss_torch.backward() + gradA2 = A.grad + gradB2 = B.grad + A.grad = None + B.grad = None + + if req_grad[0]: + torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1) + if req_grad[1]: + n = gradB1.numel() + idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) + assert (idx==0).sum().item() < n*0.1 + idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) + assert (idx==0).sum().item() < n*0.02 + torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3) + + # batched matrix multiply + if funcs[0] in [torch.bmm, torch.matmul]: + A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0]) + B = torch.randn(size=(dim1, dim3, dim4), device='cuda', requires_grad=req_grad[1]) + target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1]) + torch.nn.init.xavier_uniform_(B) + + out_torch = funcs[0](A, B) + out_bnb = funcs[1](A, B) + + n = out_bnb.numel() + idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) + assert (idx==0).sum().item() < n*0.01 + torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2) + + if any(req_grad): + out_bnb.data.copy_(out_torch) + torch.cuda.synchronize() + loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() + loss_bnb.backward() + gradA1 = A.grad + gradB1 = B.grad + A.grad = None + B.grad = None + + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + loss_torch.backward() + gradA2 = A.grad + gradB2 = B.grad + A.grad = None + B.grad = None + + if req_grad[0]: + torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1) + if req_grad[1]: + n = gradB1.numel() + idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) + assert (idx==0).sum().item() < n*0.1 + idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) + assert (idx==0).sum().item() < n*0.02 + + if funcs[0] in [torch.matmul]: + dim1 = dim1 - (dim1 % 16) + A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0]) + dimB = (dim4, dim3) if transpose[1] else (dim3, dim4) + B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1]) + target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1]) + torch.nn.init.xavier_uniform_(B) + + if transpose[1]: + out_torch = funcs[0](A, B.t()) + out_bnb = funcs[1](A, B.t()) + else: + out_torch = funcs[0](A, B) + out_bnb = funcs[1](A, B) + + n = out_bnb.numel() + idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) + assert (idx==0).sum().item() < n*0.0175 + idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) + assert (idx==0).sum().item() < n*0.001 + + if any(req_grad): + out_bnb.data.copy_(out_torch) + torch.cuda.synchronize() + loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() + loss_bnb.backward() + gradA1 = A.grad + gradB1 = B.grad + A.grad = None + B.grad = None + + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + loss_torch.backward() + gradA2 = A.grad + gradB2 = B.grad + A.grad = None + B.grad = None + + if req_grad[0]: + torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1) + if req_grad[1]: + n = gradB1.numel() + idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) + assert (idx==0).sum().item() < n*0.1 + idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) + assert (idx==0).sum().item() < n*0.02 + + +n = 1 +k = 3 +dim1 = torch.randint(16,64, size=(n,)).tolist() +dim2 = torch.randint(32,96, size=(n,)).tolist() +dim3 = torch.randint(32,96, size=(n,)).tolist() +dim4 = torch.randint(32,96, size=(n,)).tolist() + +#dim1 = (17,) +#dim2 = (7,) +#dim3 = (37,) +#dim4 = (23,) + +decomp = [0.0, 6.0] +funcs = [(torch.matmul, bnb.matmul)] +str_funcs = ['matmul'] +req_grad = [(False, False), (True, False), (True, True), (False, True)] +req_grad_str = ['FF', 'TF', 'TT', 'FT'] +transpose = [(False, True), (False, False)] +str_transpose = ['NT', 'NN'] +dtype = [torch.float16] +has_fp16_weights = [True, False] +values = list(product(dim1,dim2,dim3,dim4,funcs, dtype, req_grad, transpose, decomp, has_fp16_weights)) +str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, str_transpose, decomp, has_fp16_weights)) +names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}'.format(*vals) for vals in str_values] +@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights", values, ids=names) +def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights): + dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) + dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) + outlier_dim = torch.randint(0, dimA[1], size=(dimA[1]//8,), device='cuda') + + for i in range(k): + + # normal multiply + if funcs[0] in [torch.mm, torch.matmul]: + A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0], dtype=dtype) + if decomp == 6.0: + with torch.no_grad(): + A[:, outlier_dim] = 6.0 + B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1], dtype=dtype) + target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1], dtype=dtype) + torch.nn.init.xavier_uniform_(B) + B2 = B.clone() + + state = bnb.MatmulLtState() + state.threshold = decomp + state.has_fp16_weights = has_fp16_weights + if not has_fp16_weights: + if not transpose[0] and not transpose[1]: B2 = B2.t().contiguous() + state.CB, CBt, state.SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B2) + B2 = state.CB + + if not transpose[0] and transpose[1]: + out_torch = funcs[0](A, B.t()) + out_bnb = funcs[1](A, B2, state=state) + elif not transpose[0] and not transpose[1]: + out_torch = funcs[0](A, B) + out_bnb = funcs[1](A, B2.t(), state=state) + + n = out_bnb.numel() + err = torch.abs(out_bnb-out_torch).mean().item() + #print(f'abs error {err:.4f}') + idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) + assert (idx==0).sum().item() < n*0.0175 + idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) + assert (idx==0).sum().item() < n*0.001 + + if has_fp16_weights: + if any(req_grad): + out_bnb.data.copy_(out_torch) + torch.cuda.synchronize() + loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() + loss_bnb.backward() + gradA1 = A.grad + gradB1 = B.grad + A.grad = None + B.grad = None + + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + loss_torch.backward() + gradA2 = A.grad + gradB2 = B.grad + A.grad = None + B.grad = None + + if req_grad[0]: + torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1) + if req_grad[1]: + n = gradB1.numel() + assert torch.abs(gradB1).sum() > 0.0 + assert torch.abs(gradB2).sum() > 0.0 + idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) + assert (idx==0).sum().item() < n*0.1 + idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) + assert (idx==0).sum().item() < n*0.02 + torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3) + diff --git a/tests/test_functional.py b/tests/test_functional.py index 2a7d308..6cbe58f 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,15 +1,76 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. import pytest +import math +import random +import time import torch import bitsandbytes as bnb +import einops from itertools import product from bitsandbytes import functional as F +torch.set_printoptions(precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000) +k = 20 + +def assert_all_approx_close(a, b, rtol, atol, count): + idx = torch.isclose(a, b, rtol, atol) + sumval = (idx==0).sum().item() + if sumval > count: + print(f'Too many values not close: assert {sumval} < {count}') + torch.testing.assert_allclose(a, b, rtol, atol) + +class FFN(torch.nn.Module): + def __init__(self, input_features, hidden_size, bias=True): + super(FFN, self).__init__() + self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias) + self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias) + + with torch.no_grad(): + torch.nn.init.xavier_uniform_(self.fc1.weight) + torch.nn.init.xavier_uniform_(self.fc2.weight) + + def forward(self, x): + x = torch.relu(self.fc1(x)) + x = self.fc2(x) + return x + +class Timer(object): + def __init__(self): + self.starts = {} + self.ends = {} + self.agg = {} + + def tick(self, name='default'): + if name not in self.starts: + self.starts[name] = torch.cuda.Event(enable_timing=True) + self.ends[name] = torch.cuda.Event(enable_timing=True) + self.starts[name].record() + else: + ms = self.tock(name, evict=True, print_ms=False) + + def tock(self, name='default', evict=True, print_ms=True): + if name in self.ends: + self.ends[name].record() + torch.cuda.synchronize() + ms = self.starts[name].elapsed_time(self.ends[name]) + if name not in self.agg: self.agg[name] = 0.0 + self.agg[name] += ms + if evict: + self.starts.pop(name) + self.ends.pop(name) + + if print_ms and name in self.agg: + print('{0} took: {1:.5f}s'.format(name, self.agg[name]/1000.0)) + + return self.agg[name] + + def reset(self): + self.starts = {} + self.ends = {} + self.agg = {} + print('Resetting benchmark data') + def setup(): pass @@ -64,8 +125,8 @@ def test_dynamic_quantization(): diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) assert diff.mean().item() < 0.0135 - print(sum(diffs)/len(diffs)) - print(sum(reldiffs)/len(reldiffs)) + #print(sum(diffs)/len(diffs)) + #print(sum(reldiffs)/len(reldiffs)) for i in range(100): A1 = torch.rand(1024, 1024, device='cuda') @@ -88,8 +149,8 @@ def test_dynamic_blockwise_quantization(): diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) assert diffs[-1] < 0.011 - print(sum(diffs)/len(diffs)) - print(sum(reldiffs)/len(reldiffs)) + #print(sum(diffs)/len(diffs)) + #print(sum(reldiffs)/len(reldiffs)) diffs = [] for i in range(100): @@ -125,7 +186,7 @@ def test_percentile_clipping(gtype): n = 4 step = 0 percentile=5 - for i in range(1000): + for i in range(k): step += 1 g = torch.randn(n, n, dtype=gtype, device='cuda') gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile) @@ -145,69 +206,1653 @@ def test_percentile_clipping(gtype): torch.testing.assert_allclose(gnorm1, gnorm2) +def quant(x): + max1 = torch.abs(x).max() + x = torch.round(x/max1*127) + return max1, x.to(torch.int8) + +def dequant(c, maxC): + return c.float()*(maxC/127) + +def mm_dequant(maxA, maxB, C): + return C.float()*(maxA/127)*(maxB/127) + +def quant_multi(x, dim): + max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) + max1[max1==0] = 1.0 + x = torch.round(x/max1*127) + return max1, x.to(torch.int8) + +def quant_multi_chunk(x, dim, chunk_size=32): + if dim==1: + x_chunked = einops.rearrange(x, '(c a) b -> c a b', c=chunk_size) + max1 = torch.amax(torch.abs(x_chunked), dim=dim+1, keepdim=True) + max1 = torch.tile(max1, (1, 1, x.shape[1])) + max1 = max1.view(x.shape) + elif dim==0: + x_chunked = einops.rearrange(x, 'a (b c) -> a b c', c=chunk_size) + max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True) + max1 = torch.tile(max1, (x.shape[0], 1, 1)) + max1 = max1.view(x.shape) + max1[max1==0] = 1.0 + x = torch.round(x/max1*127) + return max1, x.to(torch.int8) + +def quant_minmax(A): + minA = A.min() + maxA = A.max() + +def mean(xx): + return sum(xx)/float(len(xx)) + +#dim1 = torch.randint(1,1024*4, size=(4,)).tolist() +#dim2 = torch.randint(1,1024*4, size=(4,)).tolist() +dim1 = [1024*2] +dim2 = [1024*16] +methods = [(lambda x, dim: quant(x), lambda x, dim: quant(x), dequant, dequant, mm_dequant)] +methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant)) +#methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant)) +method_names = ['linear', 'vectorwise'] +batched = [False, True] +values = list(product(dim1,dim2, methods, batched)) +values_names = list(product(dim1,dim2, method_names, batched)) +names = ['dim1_{0}_dim2_{1}_quant_{2}_batched_{3}'.format(*vals) for vals in values_names] +@pytest.mark.parametrize("dim1, dim2, quant_methods, batched", values, ids=names) +def test_approx_igemm(dim1, dim2, quant_methods, batched): + dim1 = dim1 - (dim1 % 32) + dim2 = dim2 - (dim2 % 32) + errors = [] + relerrors = [] + print('') + for i in range(5): + if batched: + A = torch.normal(0, 0.5, size=(32, dim1, dim2//32), device='cuda') + B = torch.normal(0, 0.5, size=(32, dim2//32, dim1), device='cuda') + maxA, Ac = quant_methods[0](A, 2) + maxB, Bc = quant_methods[1](B, 1) + else: + A = torch.normal(0, 0.5, size=(dim1, dim2), device='cuda') + B = torch.normal(0, 0.5, size=(dim2, dim1), device='cuda') + maxA, Ac = quant_methods[0](A, 1) + maxB, Bc = quant_methods[1](B, 0) + torch.testing.assert_allclose(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05) + if batched: + out2 = torch.bmm(A, B) + C = torch.bmm(Ac.float(), Bc.float()) + else: + out2 = torch.mm(A, B) + C = F.igemm(Ac, Bc) + out = quant_methods[4](maxA, maxB, C) + std = out2.std() + out/= std + out2/= std + err = torch.abs(out-out2) + relerr = err/torch.abs(out2) + errors.append(err.mean().item()) + relerrors.append(relerr.mean().item()) + print(mean(errors)) + print(mean(relerrors)) + + + + + + def test_stable_embedding(): layer = bnb.nn.StableEmbedding(1024, 1024) layer.reset_parameters() -def test_dynamic_blockwise_quantization_cpu(): - #A1 = torch.randn(1024, 1024, device='cpu') - #code = F.create_dynamic_map() - #for i in range(1000): - # C, S = F.quantize_blockwise(A1, code=code) - # A2 = F.dequantize_blockwise(C, S) - for i in range(10): - # equivalence with GPU blockwise quantization - A1 = torch.randn(1024, 1024, device='cpu') - C1, S1 = F.quantize_blockwise(A1) - C2, S2 = F.quantize_blockwise(A1.cuda()) - torch.testing.assert_allclose(S1[0], S2[0].cpu()) - # there seems to be some issues with precision in CUDA vs CPU - # not all elements are usually close, with couple off elements in a million - idx = torch.isclose(C1, C2.cpu()) - assert (idx==0).sum().item() < 15 +n = 2 +hidden_dim = torch.randint(32,256, size=(n,)).tolist() +batch_dim = torch.randint(16,256, size=(n,)).tolist() +seq_dim = torch.randint(16,256, size=(n,)).tolist() +transpose = [(False, False), (False, True), (True, False), (True, True)] +values = list(product(hidden_dim,batch_dim, transpose, seq_dim)) +names = ['hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}'.format(*vals) for vals in values] +@pytest.mark.parametrize("hidden_dim, batch_dim, transpose, seq_dim", values, ids=names) +def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): + hidden_dim = hidden_dim - (hidden_dim % 32) + batch_dim = batch_dim - (batch_dim % 16) + seq_dim = seq_dim - (seq_dim % 16) + for i in range(k): + shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim) + shapeB = ((32*random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32*random.randint(1, 4))) + A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) + B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8) + if not transpose[0] and not transpose[1]: + out2 = torch.matmul(A.float(), B.float()) + out = F.igemm(A, B) + elif not transpose[0] and transpose[1]: + out2 = torch.matmul(A.float(), B.t().float()) + out = F.igemm(A, B.t()) + elif transpose[0] and not transpose[1]: + out2 = torch.matmul(A.t().float(), B.float()) + out = F.igemm(A.t(), B) + elif transpose[0] and transpose[1]: + out2 = torch.matmul(A.t().float(), B.t().float()) + out = F.igemm(A.t(), B.t()) + torch.testing.assert_allclose(out.float(), out2) - diffs = [] - reldiffs = [] + for i in range(k): + shapeA = (batch_dim, seq_dim, hidden_dim) + shapeB = ((32*random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32*random.randint(1, 4))) + A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) + B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8) + if not transpose[0] and not transpose[1]: + out2 = torch.matmul(A.float(), B.float()) + out = F.igemm(A, B) + elif not transpose[0] and transpose[1]: + out2 = torch.matmul(A.float(), B.t().float()) + out = F.igemm(A, B.t()) + + torch.testing.assert_allclose(out.float(), out2) + + +n = 3 +seq_dim = torch.randint(32,512, size=(n,)).tolist() +hidden_dim = torch.randint(32,1024*4, size=(n,)).tolist() +batch_dim = torch.randint(2,16, size=(n,)).tolist() +values = list(product(seq_dim,hidden_dim,batch_dim)) +names = ['seq_dim{0}_hidden_dim{1}_batch_dim{2}'.format(*vals) for vals in values] +@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names) +def test_dim3_igemm(seq_dim, hidden_dim, batch_dim): + seq_dim = seq_dim - (seq_dim % 32) + hidden_dim = hidden_dim - (hidden_dim % 32) + batch_dim = batch_dim - (batch_dim % 2) + for i in range(25): + A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device='cuda').to(torch.int8) + B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device='cuda').to(torch.int8) + out2 = torch.einsum('bsi, bso->io', A.float(), B.float()) + iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device) + out = F.igemm(A, B, out=iout) + + torch.testing.assert_allclose(out.float(), out2) + +n = 2 +seq_dim = torch.randint(32,512, size=(n,)).tolist() +hidden_dim = torch.randint(32,1024*4, size=(n,)).tolist() +batch_dim = torch.randint(2,16, size=(n,)).tolist() +transpose = [False, True] +values = list(product(seq_dim,hidden_dim,batch_dim, transpose)) +names = ['seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}'.format(*vals) for vals in values] +@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim, transpose", values, ids=names) +def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose): + + def min_max(x): + maxA = torch.amax(x, dim=2, keepdim=True) + minA = torch.amin(x, dim=2, keepdim=True) + scale = (maxA-minA)/2.0 + return (127*(x-minA-scale)/scale).to(torch.int8), minA, scale + + seq_dim = seq_dim - (seq_dim % 16) + hidden_dim = hidden_dim - (hidden_dim % 16) + batch_dim = batch_dim - (batch_dim % 2) + errs = [] + relerrs = [] + errs2 = [] + relerrs2 = [] + for i in range(k): + A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device='cuda') + if transpose: + B = torch.normal(0, 0.5, size=(256, hidden_dim), device='cuda') + else: + B = torch.normal(0, 0.5, size=(hidden_dim, 256), device='cuda') + Ac, minA, scale = min_max(A) + if transpose: + maxB, Bc = quant_multi(B, dim=(1 if transpose else 0)) + out = F.igemm(Ac, Bc.t()) + out2 = torch.matmul(A,B.t()) + offset = B.t().sum(0)*(minA+scale) + out = out.float() + out = (out*maxB.t()*scale/(127*127))+offset + + maxA, Ac = quant_multi(A, dim=2) + out3 = F.igemm(Ac, Bc.t()) + out3 = mm_dequant(maxA, maxB.t(), out3) + else: + maxB, Bc = quant_multi(B, dim=0) + offset = B.sum(0)*(minA+scale) + out = F.igemm(Ac, Bc) + out2 = torch.matmul(A,B) + out = out.float() + out = (out*maxB*scale/(127*127))+offset + + maxA, Ac = quant_multi(A, dim=2) + out3 = F.igemm(Ac, Bc) + out3 = mm_dequant(maxA, maxB, out3) + + std = out2.std() + out2 /= std + out /= std + out3 /= std + + err = torch.abs(out-out2) + relerr = err/(torch.abs(out2)+1e-7) + + err2 = torch.abs(out3-out2) + relerr2 = err2/(torch.abs(out2)+1e-7) + + errs.append(err.mean().item()) + relerrs.append(relerr.mean().item()) + errs2.append(err2.mean().item()) + relerrs2.append(relerr2.mean().item()) + #print(mean(errs)) + #print(mean(relerrs)) + #print(mean(errs2)) + #print(mean(relerrs2)) + assert mean(errs) < 0.015 + assert mean(relerrs) < 0.3 + +n = 2 +dim1 = torch.randint(1,64, size=(n,)).tolist() +dim2 = torch.randint(32,128, size=(n,)).tolist() +dim3 = torch.randint(32,256, size=(n,)).tolist() +dim4 = torch.randint(32,256, size=(n,)).tolist() +transpose = [(False, False), (True, False), (False, True), (True, True)] +values = list(product(dim1,dim2,dim3,dim4,transpose)) +names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names) +def test_ibmm(dim1, dim2, dim3, dim4, transpose): + dim2 = dim2 - (dim2 % 16) + dim3 = dim3 - (dim3 % 16) + dim4 = dim4 - (dim4 % 16) + for i in range(k): + shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3) + shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4) + A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) + B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8) + + if not transpose[0] and not transpose[1]: + out2 = torch.bmm(A.float(), B.float()) + out = F.igemm(A, B) + elif not transpose[0] and transpose[1]: + out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float()) + out = F.igemm(A, B.permute([0, 2, 1])) + elif transpose[0] and not transpose[1]: + out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float()) + out = F.igemm(A.permute([0, 2, 1]), B) + elif transpose[0] and transpose[1]: + out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()) + out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1])) + torch.testing.assert_allclose(out.float(), out2.float()) + +n = 1 +dim1 = torch.randint(1,64, size=(n,)).tolist() +dim2 = torch.randint(32,128, size=(n,)).tolist() +dim3 = torch.randint(32,256, size=(n,)).tolist() +values = list(product(dim1,dim2,dim3)) +names = ['dim1_{0}_dim2_{1}_dim3_{2}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names) +def test_vector_quant(dim1, dim2, dim3): + dim2 = dim2 - (dim2 % 16) + dim3 = dim3 - (dim3 % 16) + for i in range(k): + A = torch.randn(size=(dim2, dim3), device='cuda') + qA, SA = F.vectorwise_quant(A, dim=0) + A1 = F.vectorwise_dequant(qA, SA) + torch.testing.assert_allclose(A1, A, atol=0.01, rtol=0.1) + + + +n = 2 +dim1 = torch.randint(2,256, size=(n,)).tolist() +dim2 = torch.randint(2,256, size=(n,)).tolist() +dim3 = torch.randint(2,256, size=(n,)).tolist() +#dim1, dim2 = (256,), (256,) +dtype = [torch.int8, torch.int32] +a_order = ['row'] +out_order = ['col', 'row', 'col32'] +transpose = [False] +dims = [2, 3] +values = list(product(dim1,dim2,dim3, dims,dtype, a_order, out_order, transpose)) + +names = ['dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names) +def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): + if dims == 3 and out_order != 'col32': return + if dtype == torch.int32 and out_order != 'col32': return + func = F.get_transform_func(dtype, orderA, orderOut, transpose) + + if dims == 2: + A = torch.randint(-128, 127, size=(dim1, dim2), device='cuda').to(dtype) + elif dims == 3: + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(dtype) + + out, S = F.nvidia_transform(A, to_order=orderOut) + + if orderOut == 'row': + torch.testing.assert_allclose(A.flatten(), out.flatten()) + elif orderOut == 'col': + torch.testing.assert_allclose(A.t().flatten(), out.flatten()) + elif orderOut == 'col32': + if dims == 2: + n = A.shape[0]*(A.shape[1] + (32 - (A.shape[1]%32))) + elif dims == 3: + n = A.shape[0]*A.shape[1]*(A.shape[2] + (32 - (A.shape[2]%32))) + assert out.numel() == n + elif orderOut == 'col_turing': + # 32 col 8 row tiles + n = (A.shape[0]+(8- A.shape[0]%8))*(A.shape[1] + (32 - (A.shape[1]%32))) + assert out.numel() == n + total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0) + for row in range(A.shape[0]): + for col in range(A.shape[1]): + i = row*A.shape[1] + j = col + + coltile = (col // 32) + (1 if col % 32 != 0 else 0) + rowtile = ((row // 8) + (1 if row % 8 != 0 else 0))*total_coltile + offset = 32*8*(rowtile+coltile) + col2 = col % 32 + row2 = (row%8)*32 + + + assert A.flatten()[i+j] == A[row, col] + #assert A.flatten()[i+j] == out.flatten()[row2+col2] + #torch.testing.assert_allclose(A.flatten()[i+j], A[row, col]) + #torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset]) + + if orderOut == 'col32': + out2, S = F.nvidia_transform(out, from_order=orderOut, to_order='row', state=S) + torch.testing.assert_allclose(A, out2) + + +n = 1 +dim1 = torch.randint(1,256, size=(n,)).tolist() +dim2 = torch.randint(32,512, size=(n,)).tolist() +dim3 = torch.randint(32,1024, size=(n,)).tolist() +dim4 = torch.randint(32,1024, size=(n,)).tolist() + +#dim1 = [2] +#dim2 = [2] +#dim3 = [2] +#dim4 = [2] + +dims = (2,3) +ldb = [0] +#ldb = list(range(256, 1*1024, 256)) +values = list(product(dim1,dim2,dim3,dim4,dims, ldb)) +names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names) +def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): + for i in range(k): + if dims == 2: + A = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8) + elif dims == 3: + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8) + B = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8) + C1 = torch.matmul(A.float(), B.t().float()) + + A2, SA = F.transform(A, 'col32') + B2, SB = F.transform(B, 'col_turing') + C2, SC = F.igemmlt(A2, B2, SA, SB) + C3, S = F.nvidia_transform(C2, 'row', state=SC) + torch.testing.assert_allclose(C1, C3.float()) + + # transpose + B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8) + C1 = torch.matmul(A.float(), B.float()) + + B2t, SBt = F.transform(B, 'col_turing', transpose=True) + C2, SC = F.igemmlt(A2, B2t, SA, SBt) + C3, S = F.nvidia_transform(C2, 'row', state=SC) + torch.testing.assert_allclose(C1, C3.float()) + +dim1 = [32] +dim2 = [32] +dim3 = [32] +dim4 = [32] + +dims = (2,) +#ldb = list(range(256, 1*1024, 256)) +values = list(product(dim1,dim2,dim3,dim4,dims)) +names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names) +def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): + formatB = F.get_special_format_str() + for i in range(k): + if dims == 2: + A = torch.normal(0, 0.5, size=(dim1, dim3), device='cuda').half() + elif dims == 3: + A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device='cuda').half() + B = torch.randn((dim4, dim3), device='cuda').half() + torch.nn.init.xavier_uniform_(B) + C1 = torch.matmul(A, B.t()) + C2 = bnb.matmul(A, B.t()) + + A = A.view(-1, A.shape[-1]) + + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B) + C32A, SA = F.transform(CA, 'col32') + CxB, SB = F.transform(CB, to_order=formatB) + out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB) + output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt) + + #print('') + #print(output.flatten()[:10]) + #print(C1.flatten()[:10]) + #print(C2.flatten()[:10]) + + + #torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) + + # transpose + #B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8) + #C1 = torch.matmul(A.float(), B.float()) + + #B2t, SBt = F.transform2(B, 'col_turing', transpose=True) + #C2, SC = F.igemmlt(A2, B2t, SA, SBt) + #C3, S = F.transform(C2, 'row', state=SC) + #torch.testing.assert_allclose(C1, C3.float()) + +batch_size = 2 +seqdim = 512 +#values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)] +values = [(batch_size, seqdim, 4*1024, 3*4*1024),(batch_size, seqdim, 5120, 3*5120),(batch_size, seqdim, 12*1024, 4*12*1024)] + + +#values = list(product(batch, seq, model, hidden)) +names = ['batch_{0}_seq_{1}_model_{2}_hidden_{3}'.format(*vals) for vals in values] +@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) +def test_bench_8bit_training(batch, seq, model, hidden): + formatB = F.get_special_format_str() + A = torch.randn(batch, seq, model, device='cuda').half() + grad = torch.randn(batch, seq, model, device='cuda').half() + w1 = torch.randint(-128, 127, size=(hidden, model), device='cuda').half() + w2 = torch.randint(-128, 127, size=(model, hidden), device='cuda').half() + print('') + + #torch.cuda.synchronize() + ## warmup + #for i in range(100): + # torch.matmul(A, w1.t()) + #torch.cuda.synchronize() + + dtype = torch.int8 + A = A.view(-1, A.shape[-1]).contiguous() + grad = grad.view(-1, grad.shape[-1]).contiguous() + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + + out1 = torch.matmul(A, w1.t()) # fc1 + #out2 = torch.matmul(out1, w2.t())# fc2 + + #d1 = torch.matmul(grad, w2) # delta1 + #d2 = torch.matmul(d1, w1) # delta2 + + #grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2 + #grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1 + + torch.cuda.synchronize() + t16 = time.time() - t0 + print(t16) + + #torch.cuda.empty_cache() + + #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) + #Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) + + #CTw1, Sw1 = F.transform2(Cw1, formatB) + #CTw2, Sw2 = F.transform2(Cw2, formatB) + #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True) + #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True) + + #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + #C32A, SA = F.transform2(CA, 'col32') + ## fc1 + #out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype) + ##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t) + + ## fc2 + #Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1) + #C32out1, Sout1 = F.transform2(Cout1, 'col32') + #out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype) + ##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t) + + ## delta1 + #Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad) + #C32grad, Sgrad = F.transform2(Cgrad, 'col32') + ##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype) + ##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2) + + ## delta2 + #Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1) + #C32d1, Sd1 = F.transform2(Cd1, 'col32') + ##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype) + ##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1) + + ## grad1 + #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True) + #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True) + ##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype) + ##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad) + + ## grad2 + #C32At, SAt = F.transform2(CAt, 'col32', transpose=True) + #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True) + ##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype) + ##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1) + + #Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) + + #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) + #Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) + + #CTw1, Sw1 = F.transform2(Cw1, formatB) + #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True) + #CTw2, Sw2 = F.transform2(Cw2, formatB) + #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(k): + # #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) + # #CTw1, Sw1 = F.transform2(Cw1, formatB) + # #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) + # #CTw1, Sw1 = F.transform2(Cw1, formatB) + + # #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=3.5) + # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + # #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True) + # #CTw2, Sw2 = F.transform2(Cw2, formatB) + # #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True) + + # C32A, SA = F.transform2(CA, 'col32') + + # # fc1 + # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype) + # #out1dn = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) + + # #print(coo_tensor.nnz) + # #out1sp = F.spmm_coo(coo_tensor, w1.t()) + # #print(w1.t().shape) + # #out1 = out1dn + out1sp + + # # fc2 + # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1) + # C32out1, Sout1 = F.transform2(Cout1, 'col32') + # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype) + # #out2 = F.mm_dequant(out2_32, Sout2_32, statsout1, statsw2) + + # # delta1 + # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad) + # C32grad, Sgrad = F.transform2(Cgrad, 'col32') + # d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype) + # #d1 = F.mm_dequant(d1_32, Sd1_32, statsgrad, statsw2t) + + # # delta2 + # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1) + # C32d1, Sd1 = F.transform2(Cd1, 'col32') + # d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype) + # #d2 = F.mm_dequant(d2_32, Sd2_32, statsd1, statsw1t) + + # # grad1 + # #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True) + # #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True) + # #grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype) + # #grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1t, statsgradt) + + # ## grad2 + # #C32At, SAt = F.transform2(CAt, 'col32', transpose=True) + # #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True) + # #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype) + # #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t) + + #torch.cuda.synchronize() + #t8 = time.time() - t0 + #print(t8) + + + + + +n = 2 +dim1 = torch.randint(64,256, size=(n,)).tolist() +dim4 = torch.randint(64,1024, size=(n,)).tolist() + +#dim1 = [2*1024] +#dim4 = [2*1024] + +#dim1 = [4] +#dim4 = [4] + +dims = (2,) +#ldb = list(range(256, 1*1024, 256)) +formatB = ['col_turing', 'col_ampere'] +values = list(product(dim1,dim4,dims, formatB)) +names = ['dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names) +def test_dequant_mm(dim1, dim4, dims, formatB): + inner = torch.randint(1, 128, size=(1,)).item() + formatB = F.get_special_format_str() + for i in range(k): + A = torch.randn(dim1, inner, device='cuda') + B = torch.randn(dim4, inner, device='cuda') + C1 = torch.matmul(A.half(), B.t().half()) + + A1, maxA = F.vectorwise_quant(A, dim=1) + B1, maxB = F.vectorwise_quant(B, dim=1) + + A2, SA = F.nvidia_transform(A1, 'col32') + B2, SB = F.nvidia_transform(B1, formatB) + C2, SC = F.igemmlt(A2, B2, SA, SB) + + C3, S = F.nvidia_transform(C2, 'row', state=SC) + C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t()) + + count = (torch.isclose(C1, C4, atol=0.01, rtol=0.1) == 0).sum().item() + n = C1.numel() + p = 0.06 + assert count/n < p, f'error in more than {p} of elements: {count}/{n}={count/n}' + + C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten()) + torch.testing.assert_allclose(C5, C4) + #print(C2) + + + +n = 2 +dim1 = [1*1024] +dim2 = [1*1024] +#dim1 = torch.randint(1,4*1024, size=(n,)).tolist() +#dim2 = torch.randint(1,4*1024, size=(n,)).tolist() + +dims = (2,) +#ldb = list(range(256, 1*1024, 256)) +values = list(product(dim1,dim2,dims)) +names = ['dim1_{0}_dim2_{1}_dims_{2}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names) +def test_colrow_absmax(dim1, dim2, dims): + for i in range(k): + threshold = 3.0 + A = torch.randn(dim1, dim2, device='cuda').half() + A_truncated = A.clone() + A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0 + if dims == 2: + row_stats1, _ = torch.abs(A.float()).max(1) + col_stats1, _ = torch.abs(A.float()).max(0) + row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1) + col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0) + else: + assert False + + row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold) + + A_blocked = einops.rearrange(torch.abs(A), '(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size', row_tiles=16, block_size=64*4) + nnz_rows1_counts = (torch.abs(A_blocked)>=threshold).sum(3).flatten() + nnz_block_ptr1 = torch.zeros(nnz_rows1_counts.shape[0]+1, dtype=nnz_rows1_counts.dtype, device=nnz_rows1_counts.device) + nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0) + + torch.testing.assert_allclose(col_stats1_trunc, col_stats2) + torch.testing.assert_allclose(row_stats1_trunc, row_stats2) + torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2) + + row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0) + + torch.testing.assert_allclose(col_stats1, col_stats2) + torch.testing.assert_allclose(row_stats1, row_stats2) + assert nnz_block_ptr2 is None + + + +n = 2 +#dim1 = [8*1024] +#dim2 = [4*1024] +dim1 = torch.randint(1,4*1024, size=(n,)).tolist() +dim2 = torch.randint(1,4*1024, size=(n,)).tolist() + +values = list(product(dim1,dim2)) +names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2", values, ids=names) +def test_double_quant(dim1, dim2): + for i in range(k): + A = torch.randn(dim1, dim2, device='cuda').half() + out_col1, Scol = F.vectorwise_quant(A, dim=0) + out_row1, Srow = F.vectorwise_quant(A, dim=1) + + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + + # max difference is 1 due to rounding differences + torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0) + torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0) + + + n = CAt.numel() + num_not_close_rows = (torch.isclose(CA, out_row1, atol=1)==0).sum().item() + num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1)==0).sum().item() + + # allow for 1:500 error due to rounding differences + min_error = 1/500 + if num_not_close_cols > (min_error*n): + print(f'Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}') + assert False + if num_not_close_rows > (min_error*n): + print(f'Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}') + assert False + + torch.testing.assert_allclose(Srow.flatten(), statsA) + torch.testing.assert_allclose(Scol.flatten(), statsAt) + + +n = 4 +dim1 = torch.randint(1,4*1024, size=(n,)).tolist() +dim4 = torch.randint(1,4*1024, size=(n,)).tolist() +inner = torch.randint(1,4*1024, size=(n,)).tolist() + +dim1 = [6] +dim4 = [4] +inner = [8] + +values = list(zip(dim1, dim4, inner)) +names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) +def test_integrated_igemmlt(dim1, dim4, inner): + for i in range(k): + A = torch.randn(dim1, inner, device='cuda').half() + B = torch.randn(dim4, inner, device='cuda').half() + + out1 = torch.matmul(A.half(), B.t().half()) + + C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A) + C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B) + A1, maxA = F.vectorwise_quant(A, dim=1) + B1, maxB = F.vectorwise_quant(B, dim=1) + + torch.testing.assert_allclose(maxA.flatten(), stats1a) + torch.testing.assert_allclose(maxB.flatten(), stats2a) + torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1) + torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1) + + A2, SA = F.nvidia_transform(C1a, 'col32') + B2, SB = F.nvidia_transform(C2a, 'col_turing') + outC32, SC = F.igemmlt(A2, B2, SA, SB) + out2 = F.mm_dequant(outC32, SC, stats1a, stats2a) + + A2, SA = F.nvidia_transform(A1, 'col32') + B2, SB = F.nvidia_transform(B1, 'col_turing') + C2, SC = F.igemmlt(A2, B2, SA, SB) + + C3, S = F.nvidia_transform(C2, 'row', state=SC) + out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t()) + + err1 = torch.abs(out1-out2).mean().item() + err2 = torch.abs(out1-out3).mean().item() + assert err2 <= err1*1.01 + + +n = 6 +dim1 = torch.randint(1,4*1024, size=(n,)).tolist() +dim4 = torch.randint(1,4*1024, size=(n,)).tolist() +inner = torch.randint(1,4*1024, size=(n,)).tolist() + +values = list(zip(dim1, dim4, inner)) +names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) +def test_igemmlt_row_scale(dim1, dim4, inner): + formatB = F.get_special_format_str() + err1, err2, err3 = [], [], [] + relerr1, relerr2 = [], [] + scale = 1 + for i in range(k): + A = torch.randn(dim1, inner, device='cuda').half() + B = torch.randn(dim4, inner, device='cuda').half() + torch.nn.init.xavier_uniform_(B) + C1 = torch.matmul(A, B.t()) + + out1 = torch.matmul(A.half(), B.t().half()) + + + C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A) + CB, absmaxB = F.vectorwise_quant(B, quant_type='linear') + A2, SA = F.nvidia_transform(C1a, 'col32') + B2, SB = F.nvidia_transform(CB, formatB) + A1, maxA = F.vectorwise_quant(A, dim=1) + + c = 10.0*inner*scale + row_scale = torch.ones_like(maxA)/c + outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) + C3, S = F.nvidia_transform(outC32, 'row', state=SC) + maxval = torch.abs(C3).max() + if maxval == 127: + scale = 1.5 + else: + scale = maxval/120 + out3 = C3*maxA*absmaxB*c/(127*127) + + C4 = torch.matmul(C1a.float(), CB.float().t()) + + + C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B) + B2, SB = F.nvidia_transform(C2a, formatB) + outC32, SC = F.igemmlt(A2, B2, SA, SB) + out2 = F.mm_dequant(outC32, SC, stats1a, stats2a) + + CA, SA = F.vectorwise_quant(A, dim=1, quant_type='vector') + CB, SB = F.vectorwise_quant(B, dim=1, quant_type='linear') + + C = torch.matmul(CA.float(), CB.t().float()) + out4 = C*SA*SB/(127*127) + #out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127) + + #print('='*80) + #print(out1) + #print(out2) + #print(out3) + + #print(out1) + #print(out2) + #print(out3) + err1.append(torch.abs(out1-out2).mean().item()) + err2.append(torch.abs(out1-out3).mean().item()) + err3.append(torch.abs(out1-out4).mean().item()) + + #assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10) + print('') + print(sum(err1)/len(err1)) + print(sum(err2)/len(err2)) + print(sum(err3)/len(err3)) + + +dim1 = [1024, 2048] +inner = [12288*4, 4096*4] +dim4 = [12288, 4096] + +values = list(zip(dim1, dim4, inner)) +names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) +def test_row_scale_bench(dim1, dim4, inner): + err1, err2, err3 = [], [], [] + relerr1, relerr2 = [], [] + scale = 1 + A = torch.randn(dim1, inner, device='cuda').half() + B = torch.randn(dim4, inner, device='cuda').half() + torch.nn.init.xavier_uniform_(B) + # warmpup + for i in range(k): + C1 = torch.matmul(A, B.t()) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + C1 = torch.matmul(A, B.t()) + torch.cuda.synchronize() + print('16', time.time()-t0) + + C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A) + CB, absmaxB = F.vectorwise_quant(B, quant_type='linear') + A2, SA = F.nvidia_transform(C1a, 'col32') + B2, SB = F.nvidia_transform(CB, formatB) + A1, maxA = F.vectorwise_quant(A, dim=1) + + c = 10.0*inner*scale + row_scale = maxA/c + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) + torch.cuda.synchronize() + print('row-wise', time.time()-t0) + + + C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B) + B2, SB = F.nvidia_transform(C2a, formatB) + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + outC32, SC = F.igemmlt(A2, B2, SA, SB) + torch.cuda.synchronize() + print('vector-wise', time.time()-t0) + + + + +n = 2 +dim1 = torch.randint(2,1024, size=(n,)).tolist() +dim2 = torch.randint(2,1024, size=(n,)).tolist() +#dim1 = [8*1024] +#dim2 = [4*1024] + +dim3 = [0] +dtype = [torch.int8] +a_order = ['row'] +out_order = ['col32', 'col_turing', 'col_ampere'] +transpose = [False, True] +dims = [2] +values = list(product(dim1,dim2,dim3, dims,dtype, a_order, out_order, transpose)) +names = ['dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names) +def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): + for i in range(k): + if dims == 2: + A = torch.randint(10, 99, size=(dim1, dim2), device='cuda').to(dtype) + elif dims == 3: + A = torch.randint(10, 99, size=(dim1, dim2, dim3), device='cuda').to(dtype) + + A.view(-1)[-1] = -1 + if transpose: + At = A.t().contiguous() + out1, S1 = F.nvidia_transform(At, to_order=orderOut) + else: + out1, S1 = F.nvidia_transform(A, to_order=orderOut) + out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose) + + assert S1[0][0] == S2[0][0] + assert S1[0][1] == S2[0][1] + #print(out1) + #print(out2) + + torch.testing.assert_allclose(out1, out2) + +n = 2 +#dim1 = torch.randint(2,1024, size=(n,)).tolist() +#dim2 = torch.randint(2,1024, size=(n,)).tolist() +dim1 = [1] +dim2 = [33] + +dtype = [torch.int8] +#a_order = ['col_turing', 'col_ampere'] +a_order = ['col_turing'] +out_order = ['row'] +values = list(product(dim1,dim2,dtype, a_order, out_order)) +names = ['dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, dtype, orderA, orderOut", values, ids=names) +def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut): + for i in range(1): + A = torch.randint(-127, 127, size=(dim1, dim2), device='cuda').to(dtype) + + out2, S2 = F.transform(A, to_order=orderA) + A2, S3 = F.transform(out2, from_order=orderA, to_order='row', state=S2) + assert A2.shape[0] == A.shape[0] + assert A2.shape[1] == A.shape[1] + + + print('') + print(A) + print(out2) + print(A2) + + + #torch.testing.assert_allclose(A, A2) + + + + +def test_overflow(): + formatB = F.get_special_format_str() + for i in range(2): + a = torch.arange(5, 15).cuda().to(torch.int8).view(-1,1 ) + b = torch.arange(5, 15).cuda().to(torch.int8).view(-1,1 ) + + Ca, Sa = F.nvidia_transform(a, 'col32') + Cb, Sb = F.nvidia_transform(b, formatB) + + c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8) + c2 = torch.matmul(a.float(), b.float().t()) + + +n = 2 +dim1 = torch.randint(1,4*1024, size=(n,)).tolist() +dim2 = torch.randint(1,4*1024, size=(n,)).tolist() +#dim1 = [4] +#dim2 = [5] + +values = list(product(dim1,dim2)) +names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2", values, ids=names) +def test_coo_double_quant(dim1, dim2): + threshold = 3.00 + for i in range(k): + A = torch.randn(dim1, dim2, device='cuda').half() + + idx = (torch.abs(A) >= threshold) + CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) + + if coo_tensor is not None: + A1 = A*idx + A2 = torch.zeros_like(A) + A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values + torch.testing.assert_allclose(A1, A2) + + A1 = A*(idx==0) + A2 = (CA.float()*statsA.unsqueeze(1)/127).half() + torch.testing.assert_allclose(A*(idx==0), A2, rtol=0.05, atol=1.5e-2) + +n = 2 +dim1 = torch.randint(1,1*1024, size=(n,)).tolist() +dim2 = torch.randint(1,1*1024, size=(n,)).tolist() +#dim1 = [7] +#dim2 = [11] +transposed_B = [False, True] +values = list(product(dim1,dim2, transposed_B)) +names = ['dim1_{0}_dim2_{1}_transposed_B_{2}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names) +def test_spmm_coo(dim1, dim2, transposed_B): + threshold = 1.5 + dim3 = torch.randint(32, 128, size=(1,)).item() + #dim3 = 17 + for i in range(k): + A = torch.randn(dim1, dim2).cuda().half() + if transposed_B: + B = torch.randn(dim3, dim2).cuda().half() + else: + B = torch.randn(dim2, dim3).cuda().half() + + idx = torch.abs(A) >= threshold + nnz = (idx == 1).sum().item() + rows, cols = torch.where(idx) + values = A[idx] + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) + A2 = A*idx + + if transposed_B: + out2 = F.spmm_coo(cooA, B.t()) + out1 = torch.matmul(A2, B.t()) + else: + out2 = F.spmm_coo(cooA, B) + out1 = torch.matmul(A2, B) + + assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30) + + + +def test_spmm_bench(): + batch = 2 + model = 1024*1 + hidden = model*4 + seq = 1024 + dim1 = batch*seq + dim2 = model + dim3 = hidden + threshold = 4 + A = torch.randn(dim1, dim2, device='cuda').half() + B = torch.randn(dim2, dim3, device='cuda').half() for i in range(10): - A1 = torch.randn(1024, 1024, device='cpu') - C, S = F.quantize_blockwise(A1) - A2 = F.dequantize_blockwise(C, S) - diff = torch.abs(A1-A2) - reldiff = diff/torch.abs(A1+1e-8) - diffs.append(diff.mean().item()) - reldiffs.append(reldiff.mean().item()) - assert diffs[-1] < 0.011 - #print(sum(diffs)/len(diffs)) - #print(sum(reldiffs)/len(reldiffs)) + C1 = bnb.matmul(A, B) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + C1 = bnb.matmul(A, B) + torch.cuda.synchronize() + t8 = time.time()-t0 + + idx = torch.abs(A) >= threshold + nnz = (idx == 1).sum().item() + print(nnz/idx.numel()) + rows, cols = torch.where(idx) + values = A[idx] + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) - diffs = [] for i in range(10): - A1 = torch.rand(1024, 1024, device='cpu') - C, S = F.quantize_blockwise(A1) - A2 = F.dequantize_blockwise(C, S) - diff = torch.abs(A1-A2).mean().item() - assert diff < 0.0033 - diffs.append(diff) - torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) - #print(sum(diffs)/len(diffs)) + out2 = F.spmm_coo(cooA, B) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + out2 = F.spmm_coo(cooA, B) + torch.cuda.synchronize() + tsp = time.time()-t0 + print(tsp, t8) + print(tsp/t8) + + +n = 2 +dim1 = torch.randint(256,1*1024, size=(n,)).tolist() +dim2 = torch.randint(256,1*1024, size=(n,)).tolist() +values = list(product(dim1,dim2)) +names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2", values, ids=names) +def test_integrated_sparse_decomp(dim1, dim2): + threshold = 3.0 + formatB = 'col_turing' + for i in range(k): + A = torch.randn(dim1, dim2).cuda().half() + w1 = torch.randn(dim1, dim2).cuda().half() + out1 = torch.matmul(A, w1.t()) + + Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) + CTw1, Sw1 = F.transform(Cw1, formatB) + + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + C32A, SA = F.transform(CA, 'col32') + + out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) + out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) + + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) + C32A, SA = F.transform(CA, 'col32') + + out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) + out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) + + assert coo_tensor is not None + + out4 = F.spmm_coo(coo_tensor, w1.t()) + out5 = out3 + out4 + + err1 = torch.abs(out1-out2).mean().item() + err2 = torch.abs(out1-out5).mean().item() + assert err2 < err1 + + +def test_matmuls(): + a = torch.randn(256, 256).half().cuda() + b = torch.randn(256, 256).half().cuda() + c1 = torch.matmul(a, b) + c2 = bnb.matmul(a, b) + c3 = bnb.matmul(a, b) + + err1 = torch.abs(c1-c2).mean().item() + err2 = torch.abs(c1-c3).mean().item() + assert err1 < 0.2 + assert err2 < 0.2 + + + +n = 2 +#dim1 = torch.randint(1,1*1024, size=(n,)).tolist() +#dim2 = torch.randint(1,4*1024, size=(n,)).tolist() +dim1 = [1*2048] +dim2 = [12288] +#dim1 = [32] +#dim2 = [32] +#dtype = [torch.float16, torch.int8] +dtype = [torch.float16] +out_function = ['zeros', 'ones'] +values = list(product(dim1,dim2, dtype, out_function)) +names = ['dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names) +def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): + out_func = getattr(torch, out_func) + + threshold = 3.3 + #threshold = 2.8 + #threshold = 0.0 + A = torch.randn(dim1, dim2, device='cuda').half() + if dtype == torch.float16: + B = torch.randn(dim2, dim2*4, device='cuda').half() + torch.nn.init.xavier_uniform_(B) + else: + B = torch.randn(dim2, dim2*4, device='cuda').half() + torch.nn.init.xavier_uniform_(B) + B, SB = F.vectorwise_quant(B, quant_type='linear') + #B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8) + + print('') + idx = torch.abs(A) >= threshold + nnz = (idx == 1).sum().item() + rows, cols = torch.where(idx) + values = A[idx] + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) + A2 = A*idx + out1 = torch.matmul(A2.half(), B.half()) + out = out_func(out1.shape, dtype=torch.float16, device=out1.device) + out1 += out.clone() + out2 = F.spmm_coo_very_sparse(cooA, B, out=out) + #print(B) + #print(out1) + #print(out2) + p = 200/(2048*12288*4) + n = out1.numel() + count = math.ceil(p*n) + std = out1.std() + out1 /= std + out2 /= std + assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count) + #assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count) + + idx_col = torch.randint(0, A2.shape[-1], size=(15,)) + + #torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001) + + #Bt = torch.randn(dim2*4, dim2, device='cuda').half() + #torch.cuda.synchronize() + #t0 = time.time() + #print(A2.shape, B.shape) + #for i in range(100): + # #out3 = F.spmm_coo(cooA, Bt.t()) + # #out2 = F.spmm_coo(cooA, B) + # #out2 = F.spmm_coo_very_sparse(cooA, B) + # #out1 = torch.matmul(A, Bt.t()) + + #torch.cuda.synchronize() + #print(time.time() - t0) + +def test_layout(): + a1 = torch.rand(16, 64, device='cuda', dtype=torch.float16) + a1 = torch.arange(16* 64, device='cuda').reshape(16, 64).byte() + a2, s2 = F.transform(a1, 'col_turing') + print(a2.shape) + + print(a1.flatten()[8*64:8*64+32]) + for i in range(4): + print(a2.flatten()[i*8*32:i*8*32+32], 0) + + +def test_coo2csr(): + threshold = 1 + A = torch.randn(128, 128).half().cuda() + idx = torch.abs(A) >= threshold + nnz = (idx == 1).sum().item() + rows, cols = torch.where(idx) + values = A[idx] + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) + A2 = A*idx + csrA = F.coo2csr(cooA) + counts = csrA.rowptr[1:] - csrA.rowptr[:-1] + assert counts.numel() == A.shape[0] + + torch.testing.assert_allclose(counts, (A2!=0).sum(1)) + idx = (A2!=0) + torch.testing.assert_allclose(A2[idx], csrA.values) + + +def test_coo2csc(): + threshold = 1 + A = torch.randn(128, 128).half().cuda() + idx = torch.abs(A) >= threshold + nnz = (idx == 1).sum().item() + rows, cols = torch.where(idx) + values = A[idx] + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) + A2 = A*idx + cscA = F.coo2csc(cooA) + counts = cscA.colptr[1:] - cscA.colptr[:-1] + assert counts.numel() == A.shape[1] + + torch.testing.assert_allclose(counts, (A2!=0).sum(0)) + # torch uses row-major -> use transpose to transfer to col-major + idx = (A2.t()!=0) + torch.testing.assert_allclose(A2.t()[idx], cscA.values) + + + +n = 2 +#dim1 = torch.randint(1,1*1024, size=(n,)).tolist() +#dim2 = torch.randint(1,4*1024, size=(n,)).tolist() +dim1 = [1*2048] +#dim2 = [12288] +dim2 = [2048] +#dim1 = [2] +#dim2 = [2] +dtype = [torch.int8] +values = list(product(dim1,dim2, dtype)) +names = ['dim1_{0}_dim2_{1}_dtype_{2}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names) +def test_spmm_coo_dequant(dim1, dim2, dtype): + threshold = 6.0 + #threshold = 2.8 + #threshold = 0.0 + A = torch.randn(dim1, dim2, device='cuda').half() + B = torch.empty(dim2, dim2*4, device='cuda', dtype=torch.float16) + torch.nn.init.xavier_uniform_(B) + Bt = B.t().contiguous() + + + CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B) + + rowidx = torch.randint(0, A.shape[-1], size=(15,)) + + A[:, rowidx] = 8.0 + + idx = torch.abs(A) >= threshold + nnz = (idx == 1).sum().item() + rows, cols = torch.where(idx) + values = A[idx] + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) + A2 = A*idx + out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) + out1 = torch.matmul(A2, B.half()) + out3 = F.spmm_coo_very_sparse(cooA, CBt.half()) + out3 = out3*statsBt.half()/127 + + values, counts = torch.unique(cooA.rowidx, return_counts=True) + offset = counts.cumsum(0).int() + max_count, max_idx = torch.sort(counts, descending=True) + print(torch.median(max_count.float())) + + torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001) + + p = 200/(2048*12288*4) + n = out1.numel() + count = math.ceil(p*n) + assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count) + + + + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(100): + # out2 = F.spmm_coo_very_sparse(cooA, B) + #torch.cuda.synchronize() + #print('fp16', time.time() - t0) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out2 = F.spmm_coo(cooA, B) + torch.cuda.synchronize() + print('cusparse fp16', time.time() - t0) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out2 = F.spmm_coo_very_sparse(cooA, CBt) + torch.cuda.synchronize() + print('int8', time.time() - t0) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) + torch.cuda.synchronize() + print('int8+dequant', time.time() - t0) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out2 = torch.matmul(A, B) + torch.cuda.synchronize() + print('matmul', time.time() - t0) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out1 = bnb.matmul(A, Bt) + out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) + out = out1+out2 + torch.cuda.synchronize() + print('sparse+ matmul', time.time() - t0) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out1 = bnb.matmul(A, Bt) + torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1) + torch.cuda.synchronize() + print('partial matmul', time.time() - t0) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out1 = bnb.matmul(A, Bt) + torch.cuda.synchronize() + print('partial matmul', time.time() - t0) + +batch_size = 1 +seqdim = 2048 +values = [] +values.append((batch_size, seqdim, 768, 4*768)) +#values.append((batch_size, seqdim, 1024, 4*1024)) +#values.append((batch_size, seqdim, 1536, 4*1536)) +#values.append((batch_size, seqdim, 2048, 4*2048)) +#values.append((batch_size, seqdim, 2560, 4*2560)) +#values.append((batch_size, seqdim, 4096, 4*4096)) +#values.append((batch_size, seqdim, 5140, 4*5140)) +#values.append((batch_size, seqdim, 12288, 4*12288)) +names = ['batch_{0}_seq_{1}_model_{2}_hidden_{3}'.format(*vals) for vals in values] +@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) +def test_bench_matmul(batch, seq, model, hidden): + formatB = F.get_special_format_str() + + A = torch.randn(batch, seq, model, device='cuda').half() + B = torch.empty(hidden, model, dtype=torch.float16, device='cuda') + torch.nn.init.xavier_uniform_(B) + + linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() + linear8bit.eval() + + outliers = torch.randint(0, model, size=(5,)).cuda() + A[:, :, outliers] = 8.0 + + linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() + linearMixedBit.eval() + + # warmup + for i in range(100): + torch.matmul(A, B.t()) + torch.cuda.synchronize() + print('') + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + torch.matmul(A, B.t()) + torch.cuda.synchronize() + print(f'pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + bnb.matmul(A, B) + torch.cuda.synchronize() + print(f'bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) + C32A, SA = F.transform(CA, 'col32') + CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) + CxB, SB = F.transform(CB, to_order=formatB) + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) + torch.cuda.synchronize() + print(f'igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + + BA, statsB = F.vectorwise_quant(B, dim=1) + CxB, SB = F.nvidia_transform(CB, to_order=formatB) + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + A2 = A.view(-1, A.shape[-1]).contiguous() + CA, statsA = F.vectorwise_quant(A2, dim=1) + C32A, SA = F.nvidia_transform(CA, 'col32') + out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) + Cout, Sout = F.nvidia_transform(out32, 'row', state=Sout32) + F.vectorwise_mm_dequant(Cout, statsA, statsB.t()) + torch.cuda.synchronize() + print(f'vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + + BA, statsB = F.vectorwise_quant(B, dim=1, quant_type='linear') + CxB, SB = F.nvidia_transform(CB, to_order=formatB) + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + A2 = A.view(-1, A.shape[-1]).contiguous() + CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type='linear') + C32A, SA = F.nvidia_transform(CA, 'col32') + out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) + Cout, Sout = F.nvidia_transform(out32, 'row', state=Sout32) + out = Cout*statsB*statsA*(1.0/(127*127)) + torch.cuda.synchronize() + print(f'linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + + linear8bit(A) + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + linear8bit(A) + torch.cuda.synchronize() + print(f'bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + + + linearMixedBit(A) + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + linearMixedBit(A) + torch.cuda.synchronize() + print(f'bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + + +def test_zeropoint(): + def min_max(x): + maxA = torch.amax(x, dim=1, keepdim=True) + minA = torch.amin(x, dim=1, keepdim=True) + midpoint = (maxA-minA)/2.0 + dyna = 252/(maxA-minA) + #dyna *= 0.98 + x = dyna*x + x = x - torch.round((dyna*(minA+midpoint))) + return x.to(torch.int8), minA, midpoint, dyna + batch = 2 + seq = 2 + model = 4 + hidden = 2*model + #batch = 4 + #seq = 2048 + #model = 1024 + #hidden = 8*model + A = torch.randn(batch*seq, model, device='cuda').half()-0.4 + B = torch.nn.Parameter(torch.randn(model, hidden, device='cuda').half()) + + #A[0] = 0 + #B[:, 0] = 0 + #A = A*(A>0) + #A[0, 0] = 0 + #A[0, 0] = 6.0 + + Ac, minA, midpoint, dyna = min_max(A) + #print(Ac[0, 0], 'zero') + #print(Ac, Ac.min(), Ac.max()) + Bc, maxB = F.vectorwise_quant(B, quant_type='linear') + out = F.igemm(Ac, Bc) + out2 = torch.matmul(A,B) + offset = B.sum(0)*torch.round(dyna*(minA+midpoint))/dyna + out = out.float() + #print(out.shape, maxB.shape, scale.shape, offset.shape) + norm1 = maxB/127 + C4 = (out/dyna)*norm1+offset + + + B1 = torch.nn.Parameter(B.clone()) + B2 = torch.nn.Parameter(B.clone()) + B3 = torch.nn.Parameter(B.clone()) + B4 = torch.nn.Parameter(B.clone()) + + + C1 = torch.matmul(A, B1) + C2 = bnb.matmul_cublas(A, B2, None, 'linear') + C3 = bnb.matmul_cublas(A, B3, None, 'zeropoint') + C4 = bnb.matmul_cublas(A, B4, None, 'vector-zeropoint') + + err1 = torch.abs(C1-C2).mean().item() + err2 = torch.abs(C1-C3).mean().item() + err3 = torch.abs(C1-C4).mean().item() + print(err1, err2, err3) + #assert err1 > err2 + + loss1 = C1.mean() + loss2 = C2.mean() + loss3 = C3.mean() + loss4 = C4.mean() + + loss1.backward() + loss2.backward() + loss3.backward() + loss4.backward() + + print(B.grad) + print(B1.grad) + print(B2.grad) + print(B3.grad) + print(B4.grad) + err1 = torch.abs(B1.grad-B2.grad).mean().item() + err2 = torch.abs(B1.grad-B3.grad).mean().item() + err3 = torch.abs(B1.grad-B4.grad).mean().item() + print(err1, err2, err3) + + + + +def test_zp(): + def quant_zp(x): + dtype = x.dtype + x = x.float() + dyna = x.max() - x.min() + if dyna == 0: dyna = 1 + qx = 254./dyna + minx = x.min() + #zpx = torch.round(minx* qx) + #zpx = 127 - torch.round(x.max()* qx) + zpx = torch.round(x.min()* qx) - 127 + x = (qx*x) + zpx + return x, qx, zpx + batch = 2 + seq = 512 + model = 1024 + hidden = 4*model + A = torch.randn(batch*seq, model, device='cuda').half()*0.1 + B = torch.randn(model, hidden, device='cuda').half()*0.1 + + + C0 = torch.matmul(A, B) + + + #A, SA = F.vectorwise_quant(A, quant_type='linear') + #B, SB = F.vectorwise_quant(B, quant_type='linear') + A = A.float() + B = B.float() + + C1 = torch.matmul(A, B) + C3 = bnb.matmul(A.half(), B.t().contiguous().half()) + + zp = 1 + #C2 = torch.matmul(A-zp, B) + #C2 += B.sum(0).view(1, -1)*zp + C2 = torch.matmul(A, B-zp) + C2 -= A.sum(1).view(-1, 1)*zp + + ca, cqa, cza = quant_zp(A) + print(ca.min(), ca.max()) + print((ca-cza).min(), (ca-cza).max()) + + zp = 1 + scale = 2.0 + C5 = torch.matmul((A*scale)-zp, B) + C5 += B.sum(0)*zp + C5 /= scale + + CA, qa, zpa = quant_zp(A) + C4 = torch.matmul(CA, B) + C4 -= B.sum(0)*zpa + C4 /= qa + zpb = 1 + zpa = 1 + qa = 2 + qb = 2 + C6 = torch.matmul((A*qa)+zpa, (B*qb)+zpb) + C6 -= (qb*B.sum(0).view(1, -1)*zpa) + (qa*A.sum(1).view(-1, 1)*zpb) + C6 -= zpa*zpb*A.shape[1] + C6 /= qa*qb -def test_histogram(): - dim1, dim2 = 32, 32 - source = torch.rand(dim1, dim2, device='cuda') - idx1 = torch.randint(0, 255, size=(dim1, dim2), device='cuda').int() - idx2 = torch.randint(0, 255, size=(dim1, dim2), device='cuda').int() - histogram1 = torch.zeros((256, 256)).cuda() - histogram2 = torch.zeros((256, 256)).cuda() + CA, qa, zpa = quant_zp(A) + CB, qb, zpb = quant_zp(B) + C7 = torch.matmul(CA, CB) + C7 -= (qb*B.sum(0).view(1, -1)*zpa) + (qa*A.sum(1).view(-1, 1)*zpb) + C7 -= zpa*zpb*A.shape[1] + C7 /= qa*qb - F.histogram_scatter_add_2d(histogram2, idx1, idx2, source) + print('') + #print(C0.flatten()[:10]) + print(C1.flatten()[:10]) + print(C2.flatten()[:10]) + print(C3.flatten()[:10]) + print(C5.flatten()[:10]) + print(C6.flatten()[:10]) + print(C7.flatten()[:10]) + err1 = torch.abs(C1-C2).mean().item() + err2 = torch.abs(C1-C3).mean().item() + err3 = torch.abs(C1-C4).mean().item() + err4 = torch.abs(C1-C5).mean().item() + err5 = torch.abs(C1-C6).mean().item() + err6 = torch.abs(C1-C7).mean().item() + print(err1, err2, err3, err4, err5, err6) - for i in range(dim1): - for j in range(dim2): - histogram1[idx1[i, j].item(), idx2[i, j].item()] += source[i, j] - torch.testing.assert_allclose(histogram1, histogram2) - torch.testing.assert_allclose(histogram1.sum(), source.sum()) diff --git a/tests/test_modules.py b/tests/test_modules.py index a0379cb..a2c950b 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -1,42 +1,470 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. import pytest import torch + +from itertools import product +from torch import nn + import bitsandbytes as bnb +class MockArgs(object): + def __init__(self, initial_data): + for key in initial_data: + setattr(self, key, initial_data[key]) + +class MLP8bit(torch.nn.Module): + def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0): + super(MLP8bit, self).__init__() + self.fc1 = bnb.nn.Linear8bitLt(dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold) + self.fc2 = bnb.nn.Linear8bitLt(dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + +def get_args(): + args = MockArgs([]) + args.quant_type = 'vector' + args.use_8bit_training = 'full' + args.clip_freq = 9999 + return args + +def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10): + idx = torch.isclose(a, b, rtol, atol) + sumval = (idx==0).sum().item() + if sumval > count: + print(f'Too many values not close: assert {sumval} < {count}') + torch.testing.assert_allclose(a, b, rtol, atol) + +class LinearFunction(torch.autograd.Function): + + @staticmethod + def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0): + round_func = LinearFunction.round_stoachastic if stochastic else torch.round + norm = math.sqrt(math.pi)/math.sqrt(2.0) + #std = torch.abs(x).mean()*norm + std = torch.std(x) + max1 = std*trim_value + x = x/max1*127 + x = round_func(x) + x[x > 127] = 127 + x[x < -127] = -127 + x = x/127*max1 + + return x + + def quant(x, quant_type, dim=1): + if quant_type == 'linear': + max1 = torch.abs(x).max().float() + xq = torch.round(x/max1*127).to(torch.int8) + return xq, max1 + elif quant_type == 'vector': + max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) + xq = torch.round(x/max1*127).to(torch.int8) + return xq, max1 + elif quant_type == 'min-max': + maxA = torch.amax(x, dim=dim, keepdim=True).float() + minA = torch.amin(x, dim=dim, keepdim=True).float() + scale = (maxA-minA)/2.0 + xq = torch.round(127*(x-minA-scale)/scale).to(torch.int8) + return xq, (minA.float(), scale.float()) + else: return None + + def dequant(xq, S1, S2, dtype, quant_type): + if quant_type == 'linear': + norm = S1*S2/(127*127) + # double cast needed to prevent overflows + return (xq.float()*norm).to(dtype) + elif quant_type == 'vector': + x = xq.float() + if len(xq.shape) == 2 and len(S1.shape) == 3: S1 = S1.squeeze(0) + if len(xq.shape) == 2 and len(S2.shape) == 3: S2 = S2.squeeze(0) + #print(x.shape, S1.shape, S2.shape) + if len(S1.shape) == 2: + x *= S1.t()/127 + else: + x *= S1/127 + x *= S2/127 + return x.to(dtype) + else: return None + + def dequant_min_max(xq, A, B, SA, SB, dtype): + offset = B.float().t().sum(0)*(SA[0]+SA[1]) + x = xq.float() + if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0) + if len(xq.shape) == 2 and len(SA.shape) == 3: SA = SA.squeeze(0) + if len(SB.shape) == 2: + x *= SB.t()/127 + else: + x *= SB/127 + x *= SA[1]/127 + x +=offset + return x.to(dtype) + + + def get_8bit_linear(x, stochastic=False): + round_func = LinearFunction.round_stoachastic if stochastic else torch.round + max1 = torch.abs(x).max() + x = x/max1*127 + x = round_func(x)/127*max1 + #x = torch.round(x)/128*max1 + return x + + @staticmethod + def get_8bit_vector_wise(x, dim, stochastic=False): + round_func = LinearFunction.round_stoachastic if stochastic else torch.round + max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) + max1[max1==0] = 1.0 + x = (x*127)/max1 + x = round_func(x)/127*max1 + return x + + @staticmethod + def round_stoachastic(x): + sign = torch.sign(x) + absx = torch.abs(x) + decimal = absx-torch.floor(absx) + rdm = torch.rand_like(decimal) + return sign*(torch.floor(absx)+(rdm < decimal).to(x.dtype)) + + @staticmethod + def fake_8bit_storage(w, exponent_bits): + code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device) + absmax, C = bnb.functional.quantize_blockwise(w.data, code=code) + out = bnb.functional.dequantize_blockwise(absmax, C, code) + out = out.half() + w.copy_(out) + return out + + @staticmethod + def fake_8bit_storage_quantile(w, args): + code = bnb.functional.estimate_quantiles(w.data, offset=args.offset) + #C = bnb.functional.quantize_no_absmax(code, w) + #out = bnb.functional.dequantize_no_absmax(code, C, out=w.data) + #print(out) + #out = out.half() + code /= torch.max(torch.abs(code)) + absmax, C = bnb.functional.quantize_blockwise(w.data, code=code) + out = bnb.functional.dequantize_blockwise(absmax, C, code) + out = out.half() + w.copy_(out) + return out + + @staticmethod + def fake_8bit_storage_stoachstic(w): + rand = torch.rand(1024, device=w.device) + absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand) + out = bnb.functional.dequantize_blockwise(absmax, C) + out = out.half() + w.copy_(out) + return out + + @staticmethod + def fake_8bit_storage_with_max(w, topk=8): + blocked_w = einops.rearrange(w.flatten(), '(h b) -> h b', b=256) + max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True) + idx = idx[:, :topk] + max_val = max_val[:, :topk] + + mask = torch.zeros_like(blocked_w) + mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val)) + mask = mask.bool() + + # 1. zero out max values + # 2. quantize + dequantize + # 3. write back max values + # 4. copy matrix back to weight + + values = blocked_w[mask] + blocked_w[mask] = 0 + + code = bnb.functional.create_dynamic_map() + code = code.to(w.device) + absmax, C = bnb.functional.quantize_blockwise(blocked_w.data) + bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w) + + blocked_w[mask] = values + + unblocked_w = blocked_w.flatten().view(w.shape) + + w.copy_(unblocked_w) + return unblocked_w + + + @staticmethod + def forward(ctx, x, weight, bias=None, args=None): + if args.use_8bit_training != 'off': + weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1) + x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2) + outputq = bnb.functional.igemm(x8, weight8.t()) + output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type) + #if torch.rand(1) < 0.01: + #output32 = torch.matmul(x, weight.t()) + #err = torch.abs(output-output32).float() + #relerr = err/(torch.abs(output32).float()+1e-8) + #print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy) + else: + #output = torch.matmul(x, weight.t()) + output = torch.einsum('bsi,oi->bso', x, weight) + + ctx.save_for_backward(x, weight, bias) + ctx.args = args + + if bias is not None: + output += bias.unsqueeze(0).expand_as(output) + return output + + @staticmethod + def backward(ctx, grad_output): + x, weight, bias = ctx.saved_tensors + args = ctx.args + stochastic = False + grad_input = grad_weight = grad_bias = None + if bias is not None and ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0) + + # weight and x are already 8bit + # -> transform grad_output to 8-bit + if args.use_8bit_training == 'forward+wgrad': + grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1]) + x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) + grad_weight8 = bnb.functional.igemm(grad_output8, x8) + grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type) + + #grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x) + + grad_input = grad_output.matmul(weight) + elif args.use_8bit_training == 'full': + grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1]) + x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) + grad_weight8 = torch.zeros_like(weight, dtype=torch.int32) + bnb.functional.igemm(grad_output8, x8, out=grad_weight8) + grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type) + + grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2) + weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0) + grad_input8 = bnb.functional.igemm(grad_output8, weight8) + grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type) + + else: + grad_input = grad_output.matmul(weight) + grad_weight = torch.einsum('bsi,bso->oi', x, grad_output) -@pytest.mark.parametrize("embcls", [bnb.nn.Embedding, bnb.nn.StableEmbedding], ids=['Embedding', 'StableEmbedding']) -def test_embeddings(embcls): - bnb.optim.GlobalOptimManager.get_instance().initialize() - emb1 = torch.nn.Embedding(100, 512).cuda() - emb2 = embcls(100, 512).cuda() + return grad_input, grad_weight, grad_bias, None - adam1 = bnb.optim.Adam8bit(emb1.parameters()) - adam2 = bnb.optim.Adam8bit(emb2.parameters()) +class Linear8bit(nn.Module): + def __init__(self, input_features, output_features, bias=True, args=None): + super(Linear8bit, self).__init__() + self.input_features = input_features + self.output_features = output_features + self.args = args - batches = torch.randint(1, 100, size=(100, 4, 32)).cuda() + self.weight = nn.Parameter(torch.empty(output_features, input_features)) + if bias: + self.bias = nn.Parameter(torch.empty(output_features)) + else: + self.register_parameter('bias', None) + torch.nn.init.xavier_uniform_(self.weight) + if self.bias is not None: + torch.nn.init.zeros_(self.bias) + + def forward(self, x): + self.args.training = self.training + + return LinearFunction.apply(x, self.weight, self.bias, self.args) + + + +def test_linear8bit(): + l0 = torch.nn.Linear(32, 64).cuda().half() + l1 = bnb.nn.Linear8bit(32,64, args=get_args()).cuda().half() + l2 = Linear8bit(32, 64, args=get_args()).cuda().half() + l3 = bnb.nn.Linear8bitLt(32,64).cuda().half() + + l0.weight.data = l2.weight.data.clone() + l0.bias.data = l2.bias.data.clone() + + l1.weight.data = l2.weight.data.clone() + l1.bias.data = l2.bias.data.clone() + + l3.weight.data = l2.weight.data.clone() + l3.bias.data = l2.bias.data.clone() + + for i in range(100): + b1 = torch.randn(16, 8, 32, device='cuda').half() + t = torch.randn(16, 8, 64, device='cuda').half() + b2 = b1.clone() + b3 = b1.clone() + b0 = b1.clone() + + o0 = l0(b0) + o1 = l1(b1) + o2 = l2(b2) + o3 = l3(b3) + + assert_all_approx_close(o1, o2, atol=0.013, rtol=0.05, count=1) + assert_all_approx_close(o3, o2, atol=0.013, rtol=0.05, count=1) + + loss0 = torch.nn.functional.mse_loss(o0, t) + loss1 = torch.nn.functional.mse_loss(o1, t) + loss2 = torch.nn.functional.mse_loss(o2, t) + loss3 = torch.nn.functional.mse_loss(o3, t) + + loss0.backward() + loss1.backward() + loss2.backward() + loss3.backward() + + assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2) + assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2) + assert_all_approx_close(l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2) + assert_all_approx_close(l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2) + + err1 = torch.abs(l0.weight.grad-l1.weight.grad).mean().item() + err2 = torch.abs(l0.weight.grad-l2.weight.grad).mean().item() + err3 = torch.abs(l0.weight.grad-l3.weight.grad).mean().item() + + assert err1*0.8 < err2 + assert err2*0.8 < err3 + assert err3*0.8 < err1 + + l0.weight.grad = None + l1.weight.grad = None + l2.weight.grad = None + l3.weight.grad = None + l0.bias.grad = None + l1.bias.grad = None + l2.bias.grad = None + l3.bias.grad = None + + +threshold = [0.0, 3.0] +values = threshold +names = ['threshold_{0}'.format(vals) for vals in values] +@pytest.mark.parametrize("threshold", values, ids=names) +def test_linear8bitlt_inference(threshold): + l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold).cuda().half() + assert l1.weight.device.type == 'cuda' + assert l1.weight.dtype == torch.float16 + + l1.eval() for i in range(100): - batch = batches[i] + b1 = torch.randn(16, 8, 32, device='cuda').half() + o1 = l1(b1) + if i == 1: + assert l1.state.CxB is not None + +def test_linear8bitlt_accumulated_gradient(): + l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32,32).cuda().half() for i in range(2)]) + l2 = torch.nn.Sequential(*[torch.nn.Linear(32,32).cuda().half() for i in range(2)]) + l2[0].weight = torch.nn.Parameter(l1[0].weight.clone()) + l2[0].bias = torch.nn.Parameter(l1[0].bias.clone()) + l2[1].weight = torch.nn.Parameter(l1[1].weight.clone()) + l2[1].bias = torch.nn.Parameter(l1[1].bias.clone()) + opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001) + opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001) + + acc_steps = 10 + - embedded1 = emb1(batch) - embedded2 = emb2(batch) + for i in range(10): + b1 = torch.randn(16, 8, 32, device='cuda').half() + o1 = l1(b1) + o2 = l2(b1) + loss1 = o1.mean() + loss2 = o2.mean() + loss1.backward() + loss2.backward() + if i == 2: + assert l1[0].state.CxB is not None + assert l1[1].state.CxB is not None - l1 = embedded1.mean() - l2 = embedded2.mean() + if i > 0 and i % acc_steps == 0: + opt1.step() + opt1.zero_grad(True) + opt2.step() + opt2.zero_grad(True) + assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2) + assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2) + # we do this copy because otherwise we have small divergences over time that add up + l1[0].weight.data.copy_(l2[0].weight.data) + l1[1].weight.data.copy_(l2[1].weight.data) + else: + torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad) + torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad) - l1.backward() - l2.backward() - adam1.step() - adam2.step() +threshold = [0.0, 2.0] +values = threshold +names = ['threshold_{0}'.format(vals) for vals in values] +@pytest.mark.parametrize("threshold", values, ids=names) +def test_linear8bitlt_no_fp16_weights(threshold): + l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold, has_fp16_weights=False).cuda().half() + assert l1.weight.dtype == torch.int8 - adam1.zero_grad() - adam2.zero_grad() + l1.eval() + for i in range(100): + b1 = torch.randn(16, 8, 32, device='cuda').half() + o1 = l1(b1) + assert o1.dtype == torch.float16 + + mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda() + assert mlp.fc1.weight.dtype == torch.int8 + assert mlp.fc2.weight.dtype == torch.int8 - assert adam1.state[emb1.weight]['state1'].dtype == torch.uint8 - assert adam2.state[emb2.weight]['state1'].dtype == torch.float32 + for i in range(100): + b1 = torch.randn(16, 8, 32, device='cuda').half() + o1 = mlp(b1) + assert o1.dtype == torch.float16 + if threshold > 0: assert mlp.fc1.state.idx is not None + if threshold > 0: assert mlp.fc2.state.idx is not None + mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half() + assert mlp.fc1.weight.dtype == torch.int8 + assert mlp.fc2.weight.dtype == torch.int8 + + for i in range(100): + b1 = torch.randn(16, 8, 32, device='cuda').half() + o1 = mlp(b1) + assert o1.dtype == torch.float16 + if threshold > 0: assert mlp.fc1.state.idx is not None + if threshold > 0: assert mlp.fc2.state.idx is not None + mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda() + + for i in range(100): + b1 = torch.randn(16, 8, 32, device='cuda').half() + o1 = mlp(b1) + assert o1.dtype == torch.float16 + if threshold > 0: assert mlp.fc1.state.idx is not None + if threshold > 0: assert mlp.fc2.state.idx is not None + assert mlp.fc1.weight.dtype == torch.int8 + assert mlp.fc2.weight.dtype == torch.int8 + + + mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to('cuda') + + for i in range(100): + b1 = torch.randn(16, 8, 32, device='cuda').half() + o1 = mlp(b1) + assert o1.dtype == torch.float16 + if threshold > 0: assert mlp.fc1.state.idx is not None + if threshold > 0: assert mlp.fc2.state.idx is not None + assert mlp.fc1.weight.dtype == torch.int8 + assert mlp.fc2.weight.dtype == torch.int8 + assert mlp.fc1.weight.device.type == 'cuda' + assert mlp.fc2.weight.device.type == 'cuda' + + mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(torch.float16).to('cuda') + + for i in range(100): + b1 = torch.randn(16, 8, 32, device='cuda').half() + o1 = mlp(b1) + assert o1.dtype == torch.float16 + if threshold > 0: assert mlp.fc1.state.idx is not None + if threshold > 0: assert mlp.fc2.state.idx is not None + assert mlp.fc1.weight.dtype == torch.int8 + assert mlp.fc2.weight.dtype == torch.int8 + assert mlp.fc1.weight.device.type == 'cuda' + assert mlp.fc2.weight.device.type == 'cuda' diff --git a/tests/test_optim.py b/tests/test_optim.py index c80fe51..b173eaa 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -1,12 +1,9 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. import os import time import shutil import uuid import pytest +import ctypes import torch import bitsandbytes as bnb import bitsandbytes.functional as F @@ -14,7 +11,9 @@ import bitsandbytes.functional as F from os.path import join from itertools import product -import apex +#import apex + +k = 20 def get_temp_dir(): path = '/tmp/autoswap/{0}'.format(str(uuid.uuid4())) @@ -26,55 +25,47 @@ def rm_path(path): str2optimizers = {} str2optimizers['adam_pytorch'] = (None, torch.optim.Adam, bnb.optim.Adam) -str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam) -str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam) +#str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam) +#str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam) str2optimizers['momentum_pytorch'] = (None, lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), bnb.optim.Adam) -str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam) -str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam) +#str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam) +#str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam) str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam) -str2optimizers['adamw'] = (torch.optim.AdamW, bnb.optim.AdamW) -str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam) +#str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam) str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False)) str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9)) -str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB) +#str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB) str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False)) -str2optimizers['adagrad'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad(pxx, 0.01, block_wise=False)) str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False)) str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False)) str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False)) -str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit) +#str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit) str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9)) str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True)) -str2optimizers['adamw8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.AdamW8bit(pxx, block_wise=True)) str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True)) str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True)) -str2optimizers['adagrad8bit_blockwise'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad8bit(pxx, 0.01, block_wise=True)) str2statenames = {} str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')] -str2statenames['adamw'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')] str2statenames['momentum'] = [('momentum_buffer', 'state1')] str2statenames['lars'] = [('momentum_buffer', 'state1')] str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')] str2statenames['rmsprop'] = [('square_avg', 'state1')] -str2statenames['adagrad'] = [('sum', 'state1')] str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')] str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')] str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')] -str2statenames['adamw8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')] str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')] str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')] str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')] str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')] str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')] -str2statenames['adagrad8bit_blockwise'] = [('sum', 'state1', 'qmap1', 'absmax1')] dim1 = [1024] dim2 = [32, 1024, 4097, 1] gtype = [torch.float32, torch.float16] -optimizer_names = ['adam', 'adamw', 'momentum', 'rmsprop', 'lars', 'lamb', 'adagrad'] +optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb'] values = list(product(dim1,dim2, gtype, optimizer_names)) names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @@ -89,12 +80,12 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): bnb_optimizer = str2optimizers[optim_name][1]([p2]) if gtype == torch.float32: - atol, rtol = 2e-6, 1e-5 + atol, rtol = 1e-6, 1e-5 else: atol, rtol = 1e-4, 1e-3 - for i in range(50): + for i in range(k): g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01 p1.grad = g.clone().float() p2.grad = g.clone() @@ -107,7 +98,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol) - if i % 10 == 0 and i > 0: + if i % (k//5) == 0 and i > 0: path = get_temp_dir() torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt')) del bnb_optimizer @@ -148,7 +139,6 @@ def test_global_config(dim1, dim2, gtype): eps = 1e-8 bnb.optim.GlobalOptimManager.get_instance().initialize() - bnb.optim.GlobalOptimManager.get_instance().override_config(p2, 'skip_zeros', True) bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8) bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3]) @@ -163,8 +153,6 @@ def test_global_config(dim1, dim2, gtype): else: atol, rtol = 1e-4, 1e-3 - original_p2 = p2[mask].clone() - for i in range(50): g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 @@ -173,38 +161,17 @@ def test_global_config(dim1, dim2, gtype): p2.grad = g2 p3.grad = g3 - if i > 30 and i % 10 == 0: - g1.data[mask] = 0.0 - g2.data[mask] = 0.0 - p1.grad = g1 - p2.grad = g2 - original_p1 = p1[mask].clone() - original_p2 = p2[mask].clone() - og_s1 = adam2.state[p2]['state1'][mask].clone() - og_s2 = adam2.state[p2]['state2'][mask].clone() - og_s11 = adam2.state[p1]['state1'][mask].clone() - og_s21 = adam2.state[p1]['state2'][mask].clone() - adam2.step() assert adam2.state[p3]['state1'].dtype == torch.uint8 assert adam2.state[p3]['state2'].dtype == torch.uint8 - if i > 30 and i % 10 == 0: - torch.testing.assert_allclose(original_p2, p2[mask]) - torch.testing.assert_allclose(adam2.state[p2]['state1'][mask], og_s1) - torch.testing.assert_allclose(adam2.state[p2]['state2'][mask], og_s2) - assert ((p1[mask]- original_p1)==0.0).sum() < p1.numel() - assert ((adam2.state[p1]['state1'][mask]- og_s11)==0.0).sum() == 0.0 - assert ((adam2.state[p1]['state2'][mask]- og_s21)==0.0).sum() == 0.0 - - dim1 = [1024] dim2 = [32, 1024, 4097] gtype = [torch.float32, torch.float16] -optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'adamw8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise', 'adagrad8bit_blockwise'] +optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise'] values = list(product(dim1,dim2, gtype, optimizer_names)) names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @@ -370,13 +337,12 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): if dim1 == 1 and dim2 == 1: return p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 - bnb_optimizer = str2optimizers[optim_name][1]([p1]) g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01 p1.grad = g - for i in range(5000): - if i == 500: + for i in range(k): + if i == k//5: # 100 iterations for burn-in torch.cuda.synchronize() t0 = time.time() @@ -386,23 +352,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): torch.cuda.synchronize() s = time.time()-t0 print('') - params = 4500*4096*4096 + params = (k-k//5)*dim1*dim2 print(optim_name, gtype, s/params) #assert s < 3.9 - -def test_str_betas(): - betas = (0.80, 0.95) - strbetas = '(0.80, 0.95)' - - layer = torch.nn.Linear(10, 10) - - base = bnb.optim.Adam(layer.parameters(), betas=betas) - strbase = bnb.optim.Adam(layer.parameters(), betas=strbetas) - assert base.defaults['betas'][0] == 0.8 - assert base.defaults['betas'][1] == 0.95 - assert strbase.defaults['betas'][0] == 0.8 - assert strbase.defaults['betas'][1] == 0.95 - - -- cgit v1.2.3