summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/test_autograd.py270
-rw-r--r--tests/test_functional.py1763
-rw-r--r--tests/test_modules.py478
-rw-r--r--tests/test_optim.py87
4 files changed, 2446 insertions, 152 deletions
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
new file mode 100644
index 0000000..d2b5d59
--- /dev/null
+++ b/tests/test_autograd.py
@@ -0,0 +1,270 @@
+import pytest
+
+import torch
+import bitsandbytes as bnb
+
+from itertools import product
+
+n = 1
+k = 25
+dim1 = torch.randint(16,64, size=(n,)).tolist()
+dim2 = torch.randint(32,96, size=(n,)).tolist()
+dim3 = torch.randint(32,96, size=(n,)).tolist()
+dim4 = torch.randint(32,96, size=(n,)).tolist()
+funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
+str_funcs = ['bmm', 'matmul']
+req_grad = [(False, False), (True, False), (True, True), (False, True)]
+req_grad_str = ['FF', 'TF', 'TT', 'FT']
+transpose = [(False, False), (False, True), (True, True), (True, False)]
+str_transpose = ['FF', 'FT', 'TT', 'TF']
+dtype = [torch.float32, torch.float16]
+values = list(product(dim1,dim2,dim3,dim4,funcs, dtype, req_grad, transpose))
+str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, str_transpose))
+names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}'.format(*vals) for vals in str_values]
+@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
+def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
+ dim2 = dim2 - (dim2 % 16)
+ dim3 = dim3 - (dim3 % 16)
+ dim4 = dim4 - (dim4 % 16)
+ for i in range(k):
+
+ # normal multiply
+ if funcs[0] in [torch.mm, torch.matmul]:
+ dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
+ dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
+ A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0])
+ B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1])
+ target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1])
+ torch.nn.init.xavier_uniform_(B)
+
+ if not transpose[0] and not transpose[1]:
+ out_torch = funcs[0](A, B)
+ out_bnb = funcs[1](A, B)
+ elif not transpose[0] and transpose[1]:
+ out_torch = funcs[0](A, B.t())
+ out_bnb = funcs[1](A, B.t())
+ elif transpose[0] and not transpose[1]:
+ out_torch = funcs[0](A.t(), B)
+ out_bnb = funcs[1](A.t(), B)
+ elif transpose[0] and transpose[1]:
+ out_torch = funcs[0](A.t(), B.t())
+ out_bnb = funcs[1](A.t(), B.t())
+
+ n = out_bnb.numel()
+ idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
+ assert (idx==0).sum().item() < n*0.0175
+ idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
+ assert (idx==0).sum().item() < n*0.001
+
+ if any(req_grad):
+ out_bnb.data.copy_(out_torch)
+ torch.cuda.synchronize()
+ loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+ loss_bnb.backward()
+ gradA1 = A.grad
+ gradB1 = B.grad
+ A.grad = None
+ B.grad = None
+
+ loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch.backward()
+ gradA2 = A.grad
+ gradB2 = B.grad
+ A.grad = None
+ B.grad = None
+
+ if req_grad[0]:
+ torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ if req_grad[1]:
+ n = gradB1.numel()
+ idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
+ assert (idx==0).sum().item() < n*0.1
+ idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
+ assert (idx==0).sum().item() < n*0.02
+ torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
+
+ # batched matrix multiply
+ if funcs[0] in [torch.bmm, torch.matmul]:
+ A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0])
+ B = torch.randn(size=(dim1, dim3, dim4), device='cuda', requires_grad=req_grad[1])
+ target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1])
+ torch.nn.init.xavier_uniform_(B)
+
+ out_torch = funcs[0](A, B)
+ out_bnb = funcs[1](A, B)
+
+ n = out_bnb.numel()
+ idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
+ assert (idx==0).sum().item() < n*0.01
+ torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2)
+
+ if any(req_grad):
+ out_bnb.data.copy_(out_torch)
+ torch.cuda.synchronize()
+ loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+ loss_bnb.backward()
+ gradA1 = A.grad
+ gradB1 = B.grad
+ A.grad = None
+ B.grad = None
+
+ loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch.backward()
+ gradA2 = A.grad
+ gradB2 = B.grad
+ A.grad = None
+ B.grad = None
+
+ if req_grad[0]:
+ torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ if req_grad[1]:
+ n = gradB1.numel()
+ idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
+ assert (idx==0).sum().item() < n*0.1
+ idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
+ assert (idx==0).sum().item() < n*0.02
+
+ if funcs[0] in [torch.matmul]:
+ dim1 = dim1 - (dim1 % 16)
+ A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0])
+ dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
+ B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1])
+ target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1])
+ torch.nn.init.xavier_uniform_(B)
+
+ if transpose[1]:
+ out_torch = funcs[0](A, B.t())
+ out_bnb = funcs[1](A, B.t())
+ else:
+ out_torch = funcs[0](A, B)
+ out_bnb = funcs[1](A, B)
+
+ n = out_bnb.numel()
+ idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
+ assert (idx==0).sum().item() < n*0.0175
+ idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
+ assert (idx==0).sum().item() < n*0.001
+
+ if any(req_grad):
+ out_bnb.data.copy_(out_torch)
+ torch.cuda.synchronize()
+ loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+ loss_bnb.backward()
+ gradA1 = A.grad
+ gradB1 = B.grad
+ A.grad = None
+ B.grad = None
+
+ loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch.backward()
+ gradA2 = A.grad
+ gradB2 = B.grad
+ A.grad = None
+ B.grad = None
+
+ if req_grad[0]:
+ torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ if req_grad[1]:
+ n = gradB1.numel()
+ idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
+ assert (idx==0).sum().item() < n*0.1
+ idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
+ assert (idx==0).sum().item() < n*0.02
+
+
+n = 1
+k = 3
+dim1 = torch.randint(16,64, size=(n,)).tolist()
+dim2 = torch.randint(32,96, size=(n,)).tolist()
+dim3 = torch.randint(32,96, size=(n,)).tolist()
+dim4 = torch.randint(32,96, size=(n,)).tolist()
+
+#dim1 = (17,)
+#dim2 = (7,)
+#dim3 = (37,)
+#dim4 = (23,)
+
+decomp = [0.0, 6.0]
+funcs = [(torch.matmul, bnb.matmul)]
+str_funcs = ['matmul']
+req_grad = [(False, False), (True, False), (True, True), (False, True)]
+req_grad_str = ['FF', 'TF', 'TT', 'FT']
+transpose = [(False, True), (False, False)]
+str_transpose = ['NT', 'NN']
+dtype = [torch.float16]
+has_fp16_weights = [True, False]
+values = list(product(dim1,dim2,dim3,dim4,funcs, dtype, req_grad, transpose, decomp, has_fp16_weights))
+str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, str_transpose, decomp, has_fp16_weights))
+names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}'.format(*vals) for vals in str_values]
+@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights", values, ids=names)
+def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights):
+ dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
+ dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
+ outlier_dim = torch.randint(0, dimA[1], size=(dimA[1]//8,), device='cuda')
+
+ for i in range(k):
+
+ # normal multiply
+ if funcs[0] in [torch.mm, torch.matmul]:
+ A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0], dtype=dtype)
+ if decomp == 6.0:
+ with torch.no_grad():
+ A[:, outlier_dim] = 6.0
+ B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1], dtype=dtype)
+ target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1], dtype=dtype)
+ torch.nn.init.xavier_uniform_(B)
+ B2 = B.clone()
+
+ state = bnb.MatmulLtState()
+ state.threshold = decomp
+ state.has_fp16_weights = has_fp16_weights
+ if not has_fp16_weights:
+ if not transpose[0] and not transpose[1]: B2 = B2.t().contiguous()
+ state.CB, CBt, state.SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B2)
+ B2 = state.CB
+
+ if not transpose[0] and transpose[1]:
+ out_torch = funcs[0](A, B.t())
+ out_bnb = funcs[1](A, B2, state=state)
+ elif not transpose[0] and not transpose[1]:
+ out_torch = funcs[0](A, B)
+ out_bnb = funcs[1](A, B2.t(), state=state)
+
+ n = out_bnb.numel()
+ err = torch.abs(out_bnb-out_torch).mean().item()
+ #print(f'abs error {err:.4f}')
+ idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
+ assert (idx==0).sum().item() < n*0.0175
+ idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
+ assert (idx==0).sum().item() < n*0.001
+
+ if has_fp16_weights:
+ if any(req_grad):
+ out_bnb.data.copy_(out_torch)
+ torch.cuda.synchronize()
+ loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+ loss_bnb.backward()
+ gradA1 = A.grad
+ gradB1 = B.grad
+ A.grad = None
+ B.grad = None
+
+ loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch.backward()
+ gradA2 = A.grad
+ gradB2 = B.grad
+ A.grad = None
+ B.grad = None
+
+ if req_grad[0]:
+ torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ if req_grad[1]:
+ n = gradB1.numel()
+ assert torch.abs(gradB1).sum() > 0.0
+ assert torch.abs(gradB2).sum() > 0.0
+ idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
+ assert (idx==0).sum().item() < n*0.1
+ idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
+ assert (idx==0).sum().item() < n*0.02
+ torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
+
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 2a7d308..6cbe58f 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1,15 +1,76 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
import pytest
+import math
+import random
+import time
import torch
import bitsandbytes as bnb
+import einops
from itertools import product
from bitsandbytes import functional as F
+torch.set_printoptions(precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
+k = 20
+
+def assert_all_approx_close(a, b, rtol, atol, count):
+ idx = torch.isclose(a, b, rtol, atol)
+ sumval = (idx==0).sum().item()
+ if sumval > count:
+ print(f'Too many values not close: assert {sumval} < {count}')
+ torch.testing.assert_allclose(a, b, rtol, atol)
+
+class FFN(torch.nn.Module):
+ def __init__(self, input_features, hidden_size, bias=True):
+ super(FFN, self).__init__()
+ self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
+ self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)
+
+ with torch.no_grad():
+ torch.nn.init.xavier_uniform_(self.fc1.weight)
+ torch.nn.init.xavier_uniform_(self.fc2.weight)
+
+ def forward(self, x):
+ x = torch.relu(self.fc1(x))
+ x = self.fc2(x)
+ return x
+
+class Timer(object):
+ def __init__(self):
+ self.starts = {}
+ self.ends = {}
+ self.agg = {}
+
+ def tick(self, name='default'):
+ if name not in self.starts:
+ self.starts[name] = torch.cuda.Event(enable_timing=True)
+ self.ends[name] = torch.cuda.Event(enable_timing=True)
+ self.starts[name].record()
+ else:
+ ms = self.tock(name, evict=True, print_ms=False)
+
+ def tock(self, name='default', evict=True, print_ms=True):
+ if name in self.ends:
+ self.ends[name].record()
+ torch.cuda.synchronize()
+ ms = self.starts[name].elapsed_time(self.ends[name])
+ if name not in self.agg: self.agg[name] = 0.0
+ self.agg[name] += ms
+ if evict:
+ self.starts.pop(name)
+ self.ends.pop(name)
+
+ if print_ms and name in self.agg:
+ print('{0} took: {1:.5f}s'.format(name, self.agg[name]/1000.0))
+
+ return self.agg[name]
+
+ def reset(self):
+ self.starts = {}
+ self.ends = {}
+ self.agg = {}
+ print('Resetting benchmark data')
+
def setup():
pass
@@ -64,8 +125,8 @@ def test_dynamic_quantization():
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
assert diff.mean().item() < 0.0135
- print(sum(diffs)/len(diffs))
- print(sum(reldiffs)/len(reldiffs))
+ #print(sum(diffs)/len(diffs))
+ #print(sum(reldiffs)/len(reldiffs))
for i in range(100):
A1 = torch.rand(1024, 1024, device='cuda')
@@ -88,8 +149,8 @@ def test_dynamic_blockwise_quantization():
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
assert diffs[-1] < 0.011
- print(sum(diffs)/len(diffs))
- print(sum(reldiffs)/len(reldiffs))
+ #print(sum(diffs)/len(diffs))
+ #print(sum(reldiffs)/len(reldiffs))
diffs = []
for i in range(100):
@@ -125,7 +186,7 @@ def test_percentile_clipping(gtype):
n = 4
step = 0
percentile=5
- for i in range(1000):
+ for i in range(k):
step += 1
g = torch.randn(n, n, dtype=gtype, device='cuda')
gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile)
@@ -145,69 +206,1653 @@ def test_percentile_clipping(gtype):
torch.testing.assert_allclose(gnorm1, gnorm2)
+def quant(x):
+ max1 = torch.abs(x).max()
+ x = torch.round(x/max1*127)
+ return max1, x.to(torch.int8)
+
+def dequant(c, maxC):
+ return c.float()*(maxC/127)
+
+def mm_dequant(maxA, maxB, C):
+ return C.float()*(maxA/127)*(maxB/127)
+
+def quant_multi(x, dim):
+ max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
+ max1[max1==0] = 1.0
+ x = torch.round(x/max1*127)
+ return max1, x.to(torch.int8)
+
+def quant_multi_chunk(x, dim, chunk_size=32):
+ if dim==1:
+ x_chunked = einops.rearrange(x, '(c a) b -> c a b', c=chunk_size)
+ max1 = torch.amax(torch.abs(x_chunked), dim=dim+1, keepdim=True)
+ max1 = torch.tile(max1, (1, 1, x.shape[1]))
+ max1 = max1.view(x.shape)
+ elif dim==0:
+ x_chunked = einops.rearrange(x, 'a (b c) -> a b c', c=chunk_size)
+ max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
+ max1 = torch.tile(max1, (x.shape[0], 1, 1))
+ max1 = max1.view(x.shape)
+ max1[max1==0] = 1.0
+ x = torch.round(x/max1*127)
+ return max1, x.to(torch.int8)
+
+def quant_minmax(A):
+ minA = A.min()
+ maxA = A.max()
+
+def mean(xx):
+ return sum(xx)/float(len(xx))
+
+#dim1 = torch.randint(1,1024*4, size=(4,)).tolist()
+#dim2 = torch.randint(1,1024*4, size=(4,)).tolist()
+dim1 = [1024*2]
+dim2 = [1024*16]
+methods = [(lambda x, dim: quant(x), lambda x, dim: quant(x), dequant, dequant, mm_dequant)]
+methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant))
+#methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
+method_names = ['linear', 'vectorwise']
+batched = [False, True]
+values = list(product(dim1,dim2, methods, batched))
+values_names = list(product(dim1,dim2, method_names, batched))
+names = ['dim1_{0}_dim2_{1}_quant_{2}_batched_{3}'.format(*vals) for vals in values_names]
+@pytest.mark.parametrize("dim1, dim2, quant_methods, batched", values, ids=names)
+def test_approx_igemm(dim1, dim2, quant_methods, batched):
+ dim1 = dim1 - (dim1 % 32)
+ dim2 = dim2 - (dim2 % 32)
+ errors = []
+ relerrors = []
+ print('')
+ for i in range(5):
+ if batched:
+ A = torch.normal(0, 0.5, size=(32, dim1, dim2//32), device='cuda')
+ B = torch.normal(0, 0.5, size=(32, dim2//32, dim1), device='cuda')
+ maxA, Ac = quant_methods[0](A, 2)
+ maxB, Bc = quant_methods[1](B, 1)
+ else:
+ A = torch.normal(0, 0.5, size=(dim1, dim2), device='cuda')
+ B = torch.normal(0, 0.5, size=(dim2, dim1), device='cuda')
+ maxA, Ac = quant_methods[0](A, 1)
+ maxB, Bc = quant_methods[1](B, 0)
+ torch.testing.assert_allclose(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05)
+ if batched:
+ out2 = torch.bmm(A, B)
+ C = torch.bmm(Ac.float(), Bc.float())
+ else:
+ out2 = torch.mm(A, B)
+ C = F.igemm(Ac, Bc)
+ out = quant_methods[4](maxA, maxB, C)
+ std = out2.std()
+ out/= std
+ out2/= std
+ err = torch.abs(out-out2)
+ relerr = err/torch.abs(out2)
+ errors.append(err.mean().item())
+ relerrors.append(relerr.mean().item())
+ print(mean(errors))
+ print(mean(relerrors))
+
+
+
+
+
+
def test_stable_embedding():
layer = bnb.nn.StableEmbedding(1024, 1024)
layer.reset_parameters()
-def test_dynamic_blockwise_quantization_cpu():
- #A1 = torch.randn(1024, 1024, device='cpu')
- #code = F.create_dynamic_map()
- #for i in range(1000):
- # C, S = F.quantize_blockwise(A1, code=code)
- # A2 = F.dequantize_blockwise(C, S)
- for i in range(10):
- # equivalence with GPU blockwise quantization
- A1 = torch.randn(1024, 1024, device='cpu')
- C1, S1 = F.quantize_blockwise(A1)
- C2, S2 = F.quantize_blockwise(A1.cuda())
- torch.testing.assert_allclose(S1[0], S2[0].cpu())
- # there seems to be some issues with precision in CUDA vs CPU
- # not all elements are usually close, with couple off elements in a million
- idx = torch.isclose(C1, C2.cpu())
- assert (idx==0).sum().item() < 15
+n = 2
+hidden_dim = torch.randint(32,256, size=(n,)).tolist()
+batch_dim = torch.randint(16,256, size=(n,)).tolist()
+seq_dim = torch.randint(16,256, size=(n,)).tolist()
+transpose = [(False, False), (False, True), (True, False), (True, True)]
+values = list(product(hidden_dim,batch_dim, transpose, seq_dim))
+names = ['hidden_dim_{0}_batch_dim_{1}_transpose_{2}_seq_dim_{3}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("hidden_dim, batch_dim, transpose, seq_dim", values, ids=names)
+def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
+ hidden_dim = hidden_dim - (hidden_dim % 32)
+ batch_dim = batch_dim - (batch_dim % 16)
+ seq_dim = seq_dim - (seq_dim % 16)
+ for i in range(k):
+ shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
+ shapeB = ((32*random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32*random.randint(1, 4)))
+ A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
+ B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8)
+ if not transpose[0] and not transpose[1]:
+ out2 = torch.matmul(A.float(), B.float())
+ out = F.igemm(A, B)
+ elif not transpose[0] and transpose[1]:
+ out2 = torch.matmul(A.float(), B.t().float())
+ out = F.igemm(A, B.t())
+ elif transpose[0] and not transpose[1]:
+ out2 = torch.matmul(A.t().float(), B.float())
+ out = F.igemm(A.t(), B)
+ elif transpose[0] and transpose[1]:
+ out2 = torch.matmul(A.t().float(), B.t().float())
+ out = F.igemm(A.t(), B.t())
+ torch.testing.assert_allclose(out.float(), out2)
- diffs = []
- reldiffs = []
+ for i in range(k):
+ shapeA = (batch_dim, seq_dim, hidden_dim)
+ shapeB = ((32*random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32*random.randint(1, 4)))
+ A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
+ B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8)
+ if not transpose[0] and not transpose[1]:
+ out2 = torch.matmul(A.float(), B.float())
+ out = F.igemm(A, B)
+ elif not transpose[0] and transpose[1]:
+ out2 = torch.matmul(A.float(), B.t().float())
+ out = F.igemm(A, B.t())
+
+ torch.testing.assert_allclose(out.float(), out2)
+
+
+n = 3
+seq_dim = torch.randint(32,512, size=(n,)).tolist()
+hidden_dim = torch.randint(32,1024*4, size=(n,)).tolist()
+batch_dim = torch.randint(2,16, size=(n,)).tolist()
+values = list(product(seq_dim,hidden_dim,batch_dim))
+names = ['seq_dim{0}_hidden_dim{1}_batch_dim{2}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
+def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
+ seq_dim = seq_dim - (seq_dim % 32)
+ hidden_dim = hidden_dim - (hidden_dim % 32)
+ batch_dim = batch_dim - (batch_dim % 2)
+ for i in range(25):
+ A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device='cuda').to(torch.int8)
+ B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device='cuda').to(torch.int8)
+ out2 = torch.einsum('bsi, bso->io', A.float(), B.float())
+ iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device)
+ out = F.igemm(A, B, out=iout)
+
+ torch.testing.assert_allclose(out.float(), out2)
+
+n = 2
+seq_dim = torch.randint(32,512, size=(n,)).tolist()
+hidden_dim = torch.randint(32,1024*4, size=(n,)).tolist()
+batch_dim = torch.randint(2,16, size=(n,)).tolist()
+transpose = [False, True]
+values = list(product(seq_dim,hidden_dim,batch_dim, transpose))
+names = ['seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose_{3}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim, transpose", values, ids=names)
+def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
+
+ def min_max(x):
+ maxA = torch.amax(x, dim=2, keepdim=True)
+ minA = torch.amin(x, dim=2, keepdim=True)
+ scale = (maxA-minA)/2.0
+ return (127*(x-minA-scale)/scale).to(torch.int8), minA, scale
+
+ seq_dim = seq_dim - (seq_dim % 16)
+ hidden_dim = hidden_dim - (hidden_dim % 16)
+ batch_dim = batch_dim - (batch_dim % 2)
+ errs = []
+ relerrs = []
+ errs2 = []
+ relerrs2 = []
+ for i in range(k):
+ A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device='cuda')
+ if transpose:
+ B = torch.normal(0, 0.5, size=(256, hidden_dim), device='cuda')
+ else:
+ B = torch.normal(0, 0.5, size=(hidden_dim, 256), device='cuda')
+ Ac, minA, scale = min_max(A)
+ if transpose:
+ maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
+ out = F.igemm(Ac, Bc.t())
+ out2 = torch.matmul(A,B.t())
+ offset = B.t().sum(0)*(minA+scale)
+ out = out.float()
+ out = (out*maxB.t()*scale/(127*127))+offset
+
+ maxA, Ac = quant_multi(A, dim=2)
+ out3 = F.igemm(Ac, Bc.t())
+ out3 = mm_dequant(maxA, maxB.t(), out3)
+ else:
+ maxB, Bc = quant_multi(B, dim=0)
+ offset = B.sum(0)*(minA+scale)
+ out = F.igemm(Ac, Bc)
+ out2 = torch.matmul(A,B)
+ out = out.float()
+ out = (out*maxB*scale/(127*127))+offset
+
+ maxA, Ac = quant_multi(A, dim=2)
+ out3 = F.igemm(Ac, Bc)
+ out3 = mm_dequant(maxA, maxB, out3)
+
+ std = out2.std()
+ out2 /= std
+ out /= std
+ out3 /= std
+
+ err = torch.abs(out-out2)
+ relerr = err/(torch.abs(out2)+1e-7)
+
+ err2 = torch.abs(out3-out2)
+ relerr2 = err2/(torch.abs(out2)+1e-7)
+
+ errs.append(err.mean().item())
+ relerrs.append(relerr.mean().item())
+ errs2.append(err2.mean().item())
+ relerrs2.append(relerr2.mean().item())
+ #print(mean(errs))
+ #print(mean(relerrs))
+ #print(mean(errs2))
+ #print(mean(relerrs2))
+ assert mean(errs) < 0.015
+ assert mean(relerrs) < 0.3
+
+n = 2
+dim1 = torch.randint(1,64, size=(n,)).tolist()
+dim2 = torch.randint(32,128, size=(n,)).tolist()
+dim3 = torch.randint(32,256, size=(n,)).tolist()
+dim4 = torch.randint(32,256, size=(n,)).tolist()
+transpose = [(False, False), (True, False), (False, True), (True, True)]
+values = list(product(dim1,dim2,dim3,dim4,transpose))
+names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names)
+def test_ibmm(dim1, dim2, dim3, dim4, transpose):
+ dim2 = dim2 - (dim2 % 16)
+ dim3 = dim3 - (dim3 % 16)
+ dim4 = dim4 - (dim4 % 16)
+ for i in range(k):
+ shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
+ shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
+ A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
+ B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8)
+
+ if not transpose[0] and not transpose[1]:
+ out2 = torch.bmm(A.float(), B.float())
+ out = F.igemm(A, B)
+ elif not transpose[0] and transpose[1]:
+ out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float())
+ out = F.igemm(A, B.permute([0, 2, 1]))
+ elif transpose[0] and not transpose[1]:
+ out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
+ out = F.igemm(A.permute([0, 2, 1]), B)
+ elif transpose[0] and transpose[1]:
+ out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float())
+ out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
+ torch.testing.assert_allclose(out.float(), out2.float())
+
+n = 1
+dim1 = torch.randint(1,64, size=(n,)).tolist()
+dim2 = torch.randint(32,128, size=(n,)).tolist()
+dim3 = torch.randint(32,256, size=(n,)).tolist()
+values = list(product(dim1,dim2,dim3))
+names = ['dim1_{0}_dim2_{1}_dim3_{2}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
+def test_vector_quant(dim1, dim2, dim3):
+ dim2 = dim2 - (dim2 % 16)
+ dim3 = dim3 - (dim3 % 16)
+ for i in range(k):
+ A = torch.randn(size=(dim2, dim3), device='cuda')
+ qA, SA = F.vectorwise_quant(A, dim=0)
+ A1 = F.vectorwise_dequant(qA, SA)
+ torch.testing.assert_allclose(A1, A, atol=0.01, rtol=0.1)
+
+
+
+n = 2
+dim1 = torch.randint(2,256, size=(n,)).tolist()
+dim2 = torch.randint(2,256, size=(n,)).tolist()
+dim3 = torch.randint(2,256, size=(n,)).tolist()
+#dim1, dim2 = (256,), (256,)
+dtype = [torch.int8, torch.int32]
+a_order = ['row']
+out_order = ['col', 'row', 'col32']
+transpose = [False]
+dims = [2, 3]
+values = list(product(dim1,dim2,dim3, dims,dtype, a_order, out_order, transpose))
+
+names = ['dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names)
+def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
+    if dims == 3 and orderOut != 'col32': return
+    if dtype == torch.int32 and orderOut != 'col32': return
+ func = F.get_transform_func(dtype, orderA, orderOut, transpose)
+
+ if dims == 2:
+ A = torch.randint(-128, 127, size=(dim1, dim2), device='cuda').to(dtype)
+ elif dims == 3:
+ A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(dtype)
+
+ out, S = F.nvidia_transform(A, to_order=orderOut)
+
+ if orderOut == 'row':
+ torch.testing.assert_allclose(A.flatten(), out.flatten())
+ elif orderOut == 'col':
+ torch.testing.assert_allclose(A.t().flatten(), out.flatten())
+ elif orderOut == 'col32':
+ if dims == 2:
+ n = A.shape[0]*(A.shape[1] + (32 - (A.shape[1]%32)))
+ elif dims == 3:
+ n = A.shape[0]*A.shape[1]*(A.shape[2] + (32 - (A.shape[2]%32)))
+ assert out.numel() == n
+ elif orderOut == 'col_turing':
+ # 32 col 8 row tiles
+ n = (A.shape[0]+(8- A.shape[0]%8))*(A.shape[1] + (32 - (A.shape[1]%32)))
+ assert out.numel() == n
+ total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
+ for row in range(A.shape[0]):
+ for col in range(A.shape[1]):
+ i = row*A.shape[1]
+ j = col
+
+ coltile = (col // 32) + (1 if col % 32 != 0 else 0)
+ rowtile = ((row // 8) + (1 if row % 8 != 0 else 0))*total_coltile
+ offset = 32*8*(rowtile+coltile)
+ col2 = col % 32
+ row2 = (row%8)*32
+
+
+ assert A.flatten()[i+j] == A[row, col]
+ #assert A.flatten()[i+j] == out.flatten()[row2+col2]
+ #torch.testing.assert_allclose(A.flatten()[i+j], A[row, col])
+ #torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
+
+ if orderOut == 'col32':
+ out2, S = F.nvidia_transform(out, from_order=orderOut, to_order='row', state=S)
+ torch.testing.assert_allclose(A, out2)
+
+
+n = 1
+dim1 = torch.randint(1,256, size=(n,)).tolist()
+dim2 = torch.randint(32,512, size=(n,)).tolist()
+dim3 = torch.randint(32,1024, size=(n,)).tolist()
+dim4 = torch.randint(32,1024, size=(n,)).tolist()
+
+#dim1 = [2]
+#dim2 = [2]
+#dim3 = [2]
+#dim4 = [2]
+
+dims = (2,3)
+ldb = [0]
+#ldb = list(range(256, 1*1024, 256))
+values = list(product(dim1,dim2,dim3,dim4,dims, ldb))
+names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names)
+def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
+ for i in range(k):
+ if dims == 2:
+ A = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
+ elif dims == 3:
+ A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
+ B = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
+ C1 = torch.matmul(A.float(), B.t().float())
+
+ A2, SA = F.transform(A, 'col32')
+ B2, SB = F.transform(B, 'col_turing')
+ C2, SC = F.igemmlt(A2, B2, SA, SB)
+ C3, S = F.nvidia_transform(C2, 'row', state=SC)
+ torch.testing.assert_allclose(C1, C3.float())
+
+ # transpose
+ B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
+ C1 = torch.matmul(A.float(), B.float())
+
+ B2t, SBt = F.transform(B, 'col_turing', transpose=True)
+ C2, SC = F.igemmlt(A2, B2t, SA, SBt)
+ C3, S = F.nvidia_transform(C2, 'row', state=SC)
+ torch.testing.assert_allclose(C1, C3.float())
+
+dim1 = [32]
+dim2 = [32]
+dim3 = [32]
+dim4 = [32]
+
+dims = (2,)
+#ldb = list(range(256, 1*1024, 256))
+values = list(product(dim1,dim2,dim3,dim4,dims))
+names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names)
+def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
+ formatB = F.get_special_format_str()
+ for i in range(k):
+ if dims == 2:
+ A = torch.normal(0, 0.5, size=(dim1, dim3), device='cuda').half()
+ elif dims == 3:
+ A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device='cuda').half()
+ B = torch.randn((dim4, dim3), device='cuda').half()
+ torch.nn.init.xavier_uniform_(B)
+ C1 = torch.matmul(A, B.t())
+ C2 = bnb.matmul(A, B.t())
+
+ A = A.view(-1, A.shape[-1])
+
+ CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
+ CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
+ C32A, SA = F.transform(CA, 'col32')
+ CxB, SB = F.transform(CB, to_order=formatB)
+ out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB)
+ output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt)
+
+ #print('')
+ #print(output.flatten()[:10])
+ #print(C1.flatten()[:10])
+ #print(C2.flatten()[:10])
+
+
+ #torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
+
+ # transpose
+ #B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
+ #C1 = torch.matmul(A.float(), B.float())
+
+ #B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
+ #C2, SC = F.igemmlt(A2, B2t, SA, SBt)
+ #C3, S = F.transform(C2, 'row', state=SC)
+ #torch.testing.assert_allclose(C1, C3.float())
+
+batch_size = 2
+seqdim = 512
+#values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
+values = [(batch_size, seqdim, 4*1024, 3*4*1024),(batch_size, seqdim, 5120, 3*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
+
+
+#values = list(product(batch, seq, model, hidden))
+names = ['batch_{0}_seq_{1}_model_{2}_hidden_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_8bit_training(batch, seq, model, hidden):
    """Benchmark the fp16 fc1 matmul for one transformer-sized layer.

    The int8 comparison path (double_quant / transform2 / igemmlt per layer)
    is currently disabled; only the fp16 torch.matmul timing is reported.
    """
    formatB = F.get_special_format_str()  # kept for the disabled int8 path
    A = torch.randn(batch, seq, model, device='cuda').half()
    grad = torch.randn(batch, seq, model, device='cuda').half()
    w1 = torch.randint(-128, 127, size=(hidden, model), device='cuda').half()
    w2 = torch.randint(-128, 127, size=(model, hidden), device='cuda').half()
    print('')

    dtype = torch.int8
    # flatten the batch/sequence dimensions for the matmul benchmark
    A = A.view(-1, A.shape[-1]).contiguous()
    grad = grad.view(-1, grad.shape[-1]).contiguous()

    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(k):
        out1 = torch.matmul(A, w1.t())  # fc1
    torch.cuda.synchronize()
    t16 = time.time() - t0
    print(t16)
+
+
+
+
+
# Parameter grid for test_dequant_mm: random outer dims, 2D only,
# both Turing and Ampere tiled layouts.
n = 2
dim1 = torch.randint(64, 256, size=(n,)).tolist()
dim4 = torch.randint(64, 1024, size=(n,)).tolist()
dims = (2,)
formatB = ['col_turing', 'col_ampere']
values = list(product(dim1, dim4, dims, formatB))
names = ['dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names)
def test_dequant_mm(dim1, dim4, dims, formatB):
    """Compare igemmlt + mm_dequant against a vectorwise reference dequant."""
    inner = torch.randint(1, 128, size=(1,)).item()
    # NOTE(review): this overrides the parametrized formatB, so the
    # 'col_turing' and 'col_ampere' ids all run the same special format —
    # confirm whether the parameter should be honored instead.
    formatB = F.get_special_format_str()
    for _ in range(k):
        A = torch.randn(dim1, inner, device='cuda')
        B = torch.randn(dim4, inner, device='cuda')
        C1 = torch.matmul(A.half(), B.t().half())

        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        A2, SA = F.nvidia_transform(A1, 'col32')
        B2, SB = F.nvidia_transform(B1, formatB)
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, 'row', state=SC)
        C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())

        # tolerate a small fraction of mismatches from int8 rounding
        count = (torch.isclose(C1, C4, atol=0.01, rtol=0.1) == 0).sum().item()
        n = C1.numel()
        p = 0.06
        assert count/n < p, f'error in more than {p} of elements: {count}/{n}={count/n}'

        # fused dequant kernel must match the reference dequant exactly
        C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten())
        torch.testing.assert_allclose(C5, C4)
+
+
+
# Parameter grid for test_colrow_absmax (fixed 1024x1024, 2D only).
n = 2
dim1 = [1*1024]
dim2 = [1*1024]
dims = (2,)
values = list(product(dim1, dim2, dims))
names = ['dim1_{0}_dim2_{1}_dims_{2}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
def test_colrow_absmax(dim1, dim2, dims):
    """get_colrow_absmax must match reference row/col abs-max stats and the
    per-block nnz pointer computed with a pure-pytorch reference."""
    for iteration in range(k):
        threshold = 3.0
        A = torch.randn(dim1, dim2, device='cuda').half()
        # zero out the outliers the kernel is expected to exclude
        A_truncated = A.clone()
        A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0
        if dims == 2:
            row_stats1, _ = torch.abs(A.float()).max(1)
            col_stats1, _ = torch.abs(A.float()).max(0)
            row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1)
            col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0)
        else:
            assert False

        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold)

        # reference nnz-per-block prefix sum, built from 16x(64*4) tiles
        A_blocked = einops.rearrange(torch.abs(A), '(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size', row_tiles=16, block_size=64*4)
        nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten()
        nnz_block_ptr1 = torch.zeros(nnz_rows1_counts.shape[0]+1, dtype=nnz_rows1_counts.dtype, device=nnz_rows1_counts.device)
        nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)

        torch.testing.assert_allclose(col_stats1_trunc, col_stats2)
        torch.testing.assert_allclose(row_stats1_trunc, row_stats2)
        torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2)

        # with threshold=0.0 nothing is truncated and no nnz pointer is built
        row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0)
        torch.testing.assert_allclose(col_stats1, col_stats2)
        torch.testing.assert_allclose(row_stats1, row_stats2)
        assert nnz_block_ptr2 is None
+
+
+
# Parameter grid for test_double_quant: random shapes up to 4k x 4k.
n = 2
dim1 = torch.randint(1, 4*1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4*1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_double_quant(dim1, dim2):
    """double_quant must agree with separate row-wise and col-wise
    vectorwise_quant calls up to one unit of rounding error."""
    for _ in range(k):
        A = torch.randn(dim1, dim2, device='cuda').half()
        out_col1, Scol = F.vectorwise_quant(A, dim=0)
        out_row1, Srow = F.vectorwise_quant(A, dim=1)

        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)

        # max difference is 1 due to rounding differences
        torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0)
        torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0)

        # allow for 1:500 elements to differ by that one rounding unit
        n = CAt.numel()
        num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
        num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()

        min_error = 1/500
        if num_not_close_cols > (min_error*n):
            print(f'Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}')
            assert False
        if num_not_close_rows > (min_error*n):
            print(f'Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}')
            assert False

        torch.testing.assert_allclose(Srow.flatten(), statsA)
        torch.testing.assert_allclose(Scol.flatten(), statsAt)
+
+
# Parameter grid for test_integrated_igemmlt. The random draws are kept
# (they advance the global RNG stream) but are immediately replaced by a
# small fixed configuration.
n = 4
dim1 = torch.randint(1, 4*1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4*1024, size=(n,)).tolist()
inner = torch.randint(1, 4*1024, size=(n,)).tolist()

dim1 = [6]
dim4 = [4]
inner = [8]

values = list(zip(dim1, dim4, inner))
names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
def test_integrated_igemmlt(dim1, dim4, inner):
    """double_quant + igemmlt + mm_dequant must be at least as accurate as
    the vectorwise_quant reference pipeline (within 1%)."""
    for _ in range(k):
        A = torch.randn(dim1, inner, device='cuda').half()
        B = torch.randn(dim4, inner, device='cuda').half()

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        A1, maxA = F.vectorwise_quant(A, dim=1)
        B1, maxB = F.vectorwise_quant(B, dim=1)

        # the two quantizers must agree on stats exactly and on values
        # up to one rounding unit
        torch.testing.assert_allclose(maxA.flatten(), stats1a)
        torch.testing.assert_allclose(maxB.flatten(), stats2a)
        torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1)
        torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1)

        # integrated path: fused dequant kernel
        A2, SA = F.nvidia_transform(C1a, 'col32')
        B2, SB = F.nvidia_transform(C2a, 'col_turing')
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        # reference path: pytorch dequant
        A2, SA = F.nvidia_transform(A1, 'col32')
        B2, SB = F.nvidia_transform(B1, 'col_turing')
        C2, SC = F.igemmlt(A2, B2, SA, SB)

        C3, S = F.nvidia_transform(C2, 'row', state=SC)
        out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())

        err1 = torch.abs(out1-out2).mean().item()
        err2 = torch.abs(out1-out3).mean().item()
        assert err2 <= err1*1.01
+
+
# Parameter grid for test_igemmlt_row_scale: six random (dim1, dim4, inner)
# triples up to 4k per dimension.
n = 6
dim1 = torch.randint(1, 4*1024, size=(n,)).tolist()
dim4 = torch.randint(1, 4*1024, size=(n,)).tolist()
inner = torch.randint(1, 4*1024, size=(n,)).tolist()

values = list(zip(dim1, dim4, inner))
names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
def test_igemmlt_row_scale(dim1, dim4, inner):
    """Compare int8-output igemmlt with a row scale against the int32-output
    vector-wise path and a pure-pytorch reference; prints mean abs errors."""
    formatB = F.get_special_format_str()
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    for _ in range(k):
        A = torch.randn(dim1, inner, device='cuda').half()
        B = torch.randn(dim4, inner, device='cuda').half()
        torch.nn.init.xavier_uniform_(B)
        C1 = torch.matmul(A, B.t())

        out1 = torch.matmul(A.half(), B.t().half())

        C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
        CB, absmaxB = F.vectorwise_quant(B, quant_type='linear')
        A2, SA = F.nvidia_transform(C1a, 'col32')
        B2, SB = F.nvidia_transform(CB, formatB)
        A1, maxA = F.vectorwise_quant(A, dim=1)

        # int8-output igemmlt with a uniform row scale 1/c
        c = 10.0*inner*scale
        row_scale = torch.ones_like(maxA)/c
        outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
        C3, S = F.nvidia_transform(outC32, 'row', state=SC)
        maxval = torch.abs(C3).max()
        # adapt the scale for the next iteration to stay below int8 saturation
        if maxval == 127:
            scale = 1.5
        else:
            scale = maxval/120
        out3 = C3*maxA*absmaxB*c/(127*127)

        C4 = torch.matmul(C1a.float(), CB.float().t())

        # int32-output igemmlt + fused dequant (vector-wise path)
        C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
        B2, SB = F.nvidia_transform(C2a, formatB)
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
        out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)

        # pure-pytorch reference with vector/linear quantization
        CA, SA = F.vectorwise_quant(A, dim=1, quant_type='vector')
        CB, SB = F.vectorwise_quant(B, dim=1, quant_type='linear')

        C = torch.matmul(CA.float(), CB.t().float())
        out4 = C*SA*SB/(127*127)

        err1.append(torch.abs(out1-out2).mean().item())
        err2.append(torch.abs(out1-out3).mean().item())
        err3.append(torch.abs(out1-out4).mean().item())

    print('')
    print(sum(err1)/len(err1))
    print(sum(err2)/len(err2))
    print(sum(err3)/len(err3))
+
+
# Fixed, transformer-sized shapes for the row-scale benchmark.
dim1 = [1024, 2048]
inner = [12288*4, 4096*4]
dim4 = [12288, 4096]

values = list(zip(dim1, dim4, inner))
names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
def test_row_scale_bench(dim1, dim4, inner):
    """Benchmark fp16 matmul vs int8 igemmlt with row-wise and vector-wise
    scaling; prints one timing line per variant."""
    err1, err2, err3 = [], [], []
    relerr1, relerr2 = [], []
    scale = 1
    # FIX: formatB was read from module scope, where the nearest visible
    # assignment is the parametrization list ['col_turing', 'col_ampere'],
    # not a valid format string. Resolve it locally, as the sibling
    # benchmarks (e.g. test_igemmlt_row_scale) do.
    formatB = F.get_special_format_str()
    A = torch.randn(dim1, inner, device='cuda').half()
    B = torch.randn(dim4, inner, device='cuda').half()
    torch.nn.init.xavier_uniform_(B)
    # warmup
    for i in range(k):
        C1 = torch.matmul(A, B.t())

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print('16', time.time()-t0)

    C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
    CB, absmaxB = F.vectorwise_quant(B, quant_type='linear')
    A2, SA = F.nvidia_transform(C1a, 'col32')
    B2, SB = F.nvidia_transform(CB, formatB)
    A1, maxA = F.vectorwise_quant(A, dim=1)

    # int8-output igemmlt with a per-row scale
    c = 10.0*inner*scale
    row_scale = maxA/c
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
    torch.cuda.synchronize()
    print('row-wise', time.time()-t0)

    # int32-output igemmlt (vector-wise scaling happens at dequant time)
    C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
    B2, SB = F.nvidia_transform(C2a, formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        outC32, SC = F.igemmlt(A2, B2, SA, SB)
    torch.cuda.synchronize()
    print('vector-wise', time.time()-t0)
+
+
+
+
# Parameter grid for test_transform: random 2D shapes, int8, row input,
# all three tiled output layouts, with and without transpose.
n = 2
dim1 = torch.randint(2, 1024, size=(n,)).tolist()
dim2 = torch.randint(2, 1024, size=(n,)).tolist()

dim3 = [0]
dtype = [torch.int8]
a_order = ['row']
out_order = ['col32', 'col_turing', 'col_ampere']
transpose = [False, True]
dims = [2]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
names = ['dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names)
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
    """F.transform must agree with the reference F.nvidia_transform."""
    for _ in range(k):
        if dims == 2:
            A = torch.randint(10, 99, size=(dim1, dim2), device='cuda').to(dtype)
        elif dims == 3:
            A = torch.randint(10, 99, size=(dim1, dim2, dim3), device='cuda').to(dtype)

        # sentinel value at the tail to catch layout mistakes
        A.view(-1)[-1] = -1
        if transpose:
            At = A.t().contiguous()
            out1, S1 = F.nvidia_transform(At, to_order=orderOut)
        else:
            out1, S1 = F.nvidia_transform(A, to_order=orderOut)
        out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose)

        # the recorded shapes in both transform states must match
        assert S1[0][0] == S2[0][0]
        assert S1[0][1] == S2[0][1]

        torch.testing.assert_allclose(out1, out2)
+
# Parameter grid for test_transform_to_row: a single tiny int8 case,
# col_turing back to row.
n = 2
dim1 = [1]
dim2 = [33]

dtype = [torch.int8]
a_order = ['col_turing']
out_order = ['row']
values = list(product(dim1, dim2, dtype, a_order, out_order))
names = ['dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dtype, orderA, orderOut", values, ids=names)
def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut):
    """Round-trip a tensor through a tiled layout back to row order;
    currently only checks shapes and prints the intermediates."""
    for _ in range(1):
        A = torch.randint(-127, 127, size=(dim1, dim2), device='cuda').to(dtype)

        out2, S2 = F.transform(A, to_order=orderA)
        A2, S3 = F.transform(out2, from_order=orderA, to_order='row', state=S2)
        assert A2.shape[0] == A.shape[0]
        assert A2.shape[1] == A.shape[1]

        print('')
        print(A)
        print(out2)
        print(A2)
+
+
+
+
def test_overflow():
    """Smoke test for int8-output igemmlt on values whose products exceed
    int8 range; just runs both the igemmlt and the float reference."""
    formatB = F.get_special_format_str()
    for _ in range(2):
        a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
        b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)

        Ca, Sa = F.nvidia_transform(a, 'col32')
        Cb, Sb = F.nvidia_transform(b, formatB)

        c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8)
        c2 = torch.matmul(a.float(), b.float().t())
+
+
# Parameter grid for test_coo_double_quant: random shapes up to 4k x 4k.
n = 2
dim1 = torch.randint(1, 4*1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4*1024, size=(n,)).tolist()

values = list(product(dim1, dim2))
names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_coo_double_quant(dim1, dim2):
    """double_quant with a threshold must route outliers into the COO
    tensor exactly and quantize the remaining values approximately."""
    threshold = 3.00
    for _ in range(k):
        A = torch.randn(dim1, dim2, device='cuda').half()

        idx = (torch.abs(A) >= threshold)
        CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
        CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)

        if coo_tensor is not None:
            # outliers must round-trip exactly through the sparse tensor
            A1 = A*idx
            A2 = torch.zeros_like(A)
            A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values
            torch.testing.assert_allclose(A1, A2)

            # the dense int8 part must approximate the non-outlier values
            A1 = A*(idx == 0)
            A2 = (CA.float()*statsA.unsqueeze(1)/127).half()
            torch.testing.assert_allclose(A*(idx == 0), A2, rtol=0.05, atol=1.5e-2)
+
# Parameter grid for test_spmm_coo: random shapes up to 1k, B optionally
# supplied transposed.
n = 2
dim1 = torch.randint(1, 1*1024, size=(n,)).tolist()
dim2 = torch.randint(1, 1*1024, size=(n,)).tolist()
transposed_B = [False, True]
values = list(product(dim1, dim2, transposed_B))
names = ['dim1_{0}_dim2_{1}_transposed_B_{2}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
def test_spmm_coo(dim1, dim2, transposed_B):
    """spmm_coo must match a dense matmul restricted to the above-threshold
    entries of A, for both B layouts."""
    threshold = 1.5
    dim3 = torch.randint(32, 128, size=(1,)).item()
    for _ in range(k):
        A = torch.randn(dim1, dim2).cuda().half()
        if transposed_B:
            B = torch.randn(dim3, dim2).cuda().half()
        else:
            B = torch.randn(dim2, dim3).cuda().half()

        # build the COO tensor from the outliers of A
        idx = torch.abs(A) >= threshold
        nnz = (idx == 1).sum().item()
        rows, cols = torch.where(idx)
        values = A[idx]
        cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
        A2 = A*idx  # dense equivalent of the sparse part

        if transposed_B:
            out2 = F.spmm_coo(cooA, B.t())
            out1 = torch.matmul(A2, B.t())
        else:
            out2 = F.spmm_coo(cooA, B)
            out1 = torch.matmul(A2, B)

        assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)
+
+
+
def test_spmm_bench():
    """Benchmark spmm_coo against a dense bnb.matmul at ~the same sparsity;
    prints the sparse time, dense time, and their ratio."""
    batch = 2
    model = 1024*1
    hidden = model*4
    seq = 1024
    dim1 = batch*seq
    dim2 = model
    dim3 = hidden
    threshold = 4
    A = torch.randn(dim1, dim2, device='cuda').half()
    B = torch.randn(dim2, dim3, device='cuda').half()
    # dense warmup
    for i in range(10):
        C1 = bnb.matmul(A, B)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        C1 = bnb.matmul(A, B)
    torch.cuda.synchronize()
    t8 = time.time()-t0

    # build the COO tensor from the outliers of A
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    print(nnz/idx.numel())
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
    # sparse warmup
    for i in range(10):
        out2 = F.spmm_coo(cooA, B)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(k):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    tsp = time.time()-t0
    print(tsp, t8)
    print(tsp/t8)
+
+
# Parameter grid for test_integrated_sparse_decomp: random shapes 256-1024.
n = 2
dim1 = torch.randint(256, 1*1024, size=(n,)).tolist()
dim2 = torch.randint(256, 1*1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values]
+def test_integrated_sparse_decomp(dim1, dim2):
+ threshold = 3.0
+ formatB = 'col_turing'
+ for i in range(k):
+ A = torch.randn(dim1, dim2).cuda().half()
+ w1 = torch.randn(dim1, dim2).cuda().half()
+ out1 = torch.matmul(A, w1.t())
+
+ Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
+ CTw1, Sw1 = F.transform(Cw1, formatB)
+
+ CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
+ C32A, SA = F.transform(CA, 'col32')
+
+ out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
+ out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
+
+ CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
+ C32A, SA = F.transform(CA, 'col32')
+
+ out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
+ out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
+
+ assert coo_tensor is not None
+
+ out4 = F.spmm_coo(coo_tensor, w1.t())
+ out5 = out3 + out4
+
+ err1 = torch.abs(out1-out2).mean().item()
+ err2 = torch.abs(out1-out5).mean().item()
+ assert err2 < err1
+
+
def test_matmuls():
    """bnb.matmul must stay within 0.2 mean absolute error of torch.matmul."""
    a = torch.randn(256, 256).half().cuda()
    b = torch.randn(256, 256).half().cuda()
    c1 = torch.matmul(a, b)
    c2 = bnb.matmul(a, b)
    # NOTE(review): c3 repeats exactly the same call as c2 — presumably one
    # of the two was meant to exercise a different code path; confirm intent.
    c3 = bnb.matmul(a, b)

    err1 = torch.abs(c1-c2).mean().item()
    err2 = torch.abs(c1-c3).mean().item()
    assert err1 < 0.2
    assert err2 < 0.2
+
+
+
# Parameter grid for test_spmm_coo_very_sparse: one large fp16 shape,
# output pre-filled with either zeros or ones.
n = 2
dim1 = [1*2048]
dim2 = [12288]
dtype = [torch.float16]
out_function = ['zeros', 'ones']
values = list(product(dim1, dim2, dtype, out_function))
names = ['dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
    """spmm_coo_very_sparse must accumulate into a pre-filled output and
    approximately match the dense matmul on the sparse part of A."""
    out_func = getattr(torch, out_func)

    threshold = 3.3
    A = torch.randn(dim1, dim2, device='cuda').half()
    if dtype == torch.float16:
        B = torch.randn(dim2, dim2*4, device='cuda').half()
        torch.nn.init.xavier_uniform_(B)
    else:
        B = torch.randn(dim2, dim2*4, device='cuda').half()
        torch.nn.init.xavier_uniform_(B)
        B, SB = F.vectorwise_quant(B, quant_type='linear')

    print('')
    # build the COO tensor from the outliers of A
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
    A2 = A*idx
    out1 = torch.matmul(A2.half(), B.half())
    # pre-fill the output buffer; the kernel accumulates into it
    out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
    out1 += out.clone()
    out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
    # allow a tiny fraction of mismatches, normalized by the output std
    p = 200/(2048*12288*4)
    n = out1.numel()
    count = math.ceil(p*n)
    std = out1.std()
    out1 /= std
    out2 /= std
    assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)

    idx_col = torch.randint(0, A2.shape[-1], size=(15,))
+
def test_layout():
    """Print the col_turing tiled layout of a small byte tensor for visual
    inspection (no assertions)."""
    a1 = torch.rand(16, 64, device='cuda', dtype=torch.float16)
    a1 = torch.arange(16*64, device='cuda').reshape(16, 64).byte()
    a2, s2 = F.transform(a1, 'col_turing')
    print(a2.shape)

    print(a1.flatten()[8*64:8*64+32])
    for i in range(4):
        print(a2.flatten()[i*8*32:i*8*32+32], 0)
+
+
def test_coo2csr():
    """coo2csr must produce row pointers and values consistent with the
    dense nonzero pattern."""
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
    A2 = A*idx
    csrA = F.coo2csr(cooA)
    # per-row nonzero counts from the rowptr prefix sums
    counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
    assert counts.numel() == A.shape[0]

    torch.testing.assert_allclose(counts, (A2 != 0).sum(1))
    idx = (A2 != 0)
    torch.testing.assert_allclose(A2[idx], csrA.values)
+
+
def test_coo2csc():
    """coo2csc must produce column pointers and values consistent with the
    dense nonzero pattern."""
    threshold = 1
    A = torch.randn(128, 128).half().cuda()
    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
    A2 = A*idx
    cscA = F.coo2csc(cooA)
    # per-column nonzero counts from the colptr prefix sums
    counts = cscA.colptr[1:] - cscA.colptr[:-1]
    assert counts.numel() == A.shape[1]

    torch.testing.assert_allclose(counts, (A2 != 0).sum(0))
    # torch uses row-major -> use transpose to transfer to col-major
    idx = (A2.t() != 0)
    torch.testing.assert_allclose(A2.t()[idx], cscA.values)
+
+
+
# Parameter grid for test_spmm_coo_dequant: one fixed 2048x2048 int8 case.
n = 2
dim1 = [1*2048]
dim2 = [2048]
dtype = [torch.int8]
values = list(product(dim1, dim2, dtype))
names = ['dim1_{0}_dim2_{1}_dtype_{2}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
def test_spmm_coo_dequant(dim1, dim2, dtype):
    """Check spmm_coo_very_sparse with fused int8 dequantization against
    both the unfused variant and a dense matmul, then print a series of
    timing comparisons for the different sparse/dense combinations."""
    threshold = 6.0
    A = torch.randn(dim1, dim2, device='cuda').half()
    B = torch.empty(dim2, dim2*4, device='cuda', dtype=torch.float16)
    torch.nn.init.xavier_uniform_(B)
    Bt = B.t().contiguous()

    CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)

    # force outliers into a fixed set of columns
    rowidx = torch.randint(0, A.shape[-1], size=(15,))
    A[:, rowidx] = 8.0

    idx = torch.abs(A) >= threshold
    nnz = (idx == 1).sum().item()
    rows, cols = torch.where(idx)
    values = A[idx]
    cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
    A2 = A*idx
    # fused dequant vs manual dequant vs dense reference
    out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    out1 = torch.matmul(A2, B.half())
    out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
    out3 = out3*statsBt.half()/127

    values, counts = torch.unique(cooA.rowidx, return_counts=True)
    offset = counts.cumsum(0).int()
    max_count, max_idx = torch.sort(counts, descending=True)
    print(torch.median(max_count.float()))

    torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001)

    # allow a tiny fraction of mismatches against the dense reference
    p = 200/(2048*12288*4)
    n = out1.numel()
    count = math.ceil(p*n)
    assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo(cooA, B)
    torch.cuda.synchronize()
    print('cusparse fp16', time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt)
    torch.cuda.synchronize()
    print('int8', time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
    torch.cuda.synchronize()
    print('int8+dequant', time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out2 = torch.matmul(A, B)
    torch.cuda.synchronize()
    print('matmul', time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
        out = out1+out2
    torch.cuda.synchronize()
    print('sparse+ matmul', time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
        torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1)
    torch.cuda.synchronize()
    print('partial matmul', time.time() - t0)

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out1 = bnb.matmul(A, Bt)
    torch.cuda.synchronize()
    print('partial matmul', time.time() - t0)
+
# Parameter grid for the matmul benchmark; larger model sizes can be
# appended here for longer benchmark runs.
batch_size = 1
seqdim = 2048
values = [(batch_size, seqdim, 768, 4*768)]
names = ['batch_{0}_seq_{1}_model_{2}_hidden_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden):
    """Print timing comparisons of torch.matmul, bnb.matmul, raw igemmlt,
    two quantized reference pipelines, and the Linear8bitLt modules."""
    formatB = F.get_special_format_str()

    A = torch.randn(batch, seq, model, device='cuda').half()
    B = torch.empty(hidden, model, dtype=torch.float16, device='cuda')
    torch.nn.init.xavier_uniform_(B)

    linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
    linear8bit.eval()

    # inject outlier columns so the threshold path has work to do
    outliers = torch.randint(0, model, size=(5,)).cuda()
    A[:, :, outliers] = 8.0

    linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
    linearMixedBit.eval()

    # warmup
    for i in range(100):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print('')

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        torch.matmul(A, B.t())
    torch.cuda.synchronize()
    print(f'pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')

    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        bnb.matmul(A, B)
    torch.cuda.synchronize()
    print(f'bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')

    # raw igemmlt kernel, quantization done once up front
    CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
    C32A, SA = F.transform(CA, 'col32')
    CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
    CxB, SB = F.transform(CB, to_order=formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
    torch.cuda.synchronize()
    print(f'igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')

    # full vector-wise pipeline, re-quantizing A every iteration
    BA, statsB = F.vectorwise_quant(B, dim=1)
    CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        A2 = A.view(-1, A.shape[-1]).contiguous()
        CA, statsA = F.vectorwise_quant(A2, dim=1)
        C32A, SA = F.nvidia_transform(CA, 'col32')
        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
        Cout, Sout = F.nvidia_transform(out32, 'row', state=Sout32)
        F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
    torch.cuda.synchronize()
    print(f'vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')

    # full linear-quant pipeline
    BA, statsB = F.vectorwise_quant(B, dim=1, quant_type='linear')
    CxB, SB = F.nvidia_transform(CB, to_order=formatB)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        A2 = A.view(-1, A.shape[-1]).contiguous()
        CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type='linear')
        C32A, SA = F.nvidia_transform(CA, 'col32')
        out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
        Cout, Sout = F.nvidia_transform(out32, 'row', state=Sout32)
        out = Cout*statsB*statsA*(1.0/(127*127))
    torch.cuda.synchronize()
    print(f'linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')

    linear8bit(A)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        linear8bit(A)
    torch.cuda.synchronize()
    print(f'bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')

    linearMixedBit(A)
    torch.cuda.synchronize()
    t0 = time.time()
    for i in range(100):
        linearMixedBit(A)
    torch.cuda.synchronize()
    print(f'bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
+
+
+def test_zeropoint():
+ def min_max(x):
+ maxA = torch.amax(x, dim=1, keepdim=True)
+ minA = torch.amin(x, dim=1, keepdim=True)
+ midpoint = (maxA-minA)/2.0
+ dyna = 252/(maxA-minA)
+ #dyna *= 0.98
+ x = dyna*x
+ x = x - torch.round((dyna*(minA+midpoint)))
+ return x.to(torch.int8), minA, midpoint, dyna
+ batch = 2
+ seq = 2
+ model = 4
+ hidden = 2*model
+ #batch = 4
+ #seq = 2048
+ #model = 1024
+ #hidden = 8*model
+ A = torch.randn(batch*seq, model, device='cuda').half()-0.4
+ B = torch.nn.Parameter(torch.randn(model, hidden, device='cuda').half())
+
+ #A[0] = 0
+ #B[:, 0] = 0
+ #A = A*(A>0)
+ #A[0, 0] = 0
+ #A[0, 0] = 6.0
+
+ Ac, minA, midpoint, dyna = min_max(A)
+ #print(Ac[0, 0], 'zero')
+ #print(Ac, Ac.min(), Ac.max())
+ Bc, maxB = F.vectorwise_quant(B, quant_type='linear')
+ out = F.igemm(Ac, Bc)
+ out2 = torch.matmul(A,B)
+ offset = B.sum(0)*torch.round(dyna*(minA+midpoint))/dyna
+ out = out.float()
+ #print(out.shape, maxB.shape, scale.shape, offset.shape)
+ norm1 = maxB/127
+ C4 = (out/dyna)*norm1+offset
+
+
+ B1 = torch.nn.Parameter(B.clone())
+ B2 = torch.nn.Parameter(B.clone())
+ B3 = torch.nn.Parameter(B.clone())
+ B4 = torch.nn.Parameter(B.clone())
+
+
+ C1 = torch.matmul(A, B1)
+ C2 = bnb.matmul_cublas(A, B2, None, 'linear')
+ C3 = bnb.matmul_cublas(A, B3, None, 'zeropoint')
+ C4 = bnb.matmul_cublas(A, B4, None, 'vector-zeropoint')
+
+ err1 = torch.abs(C1-C2).mean().item()
+ err2 = torch.abs(C1-C3).mean().item()
+ err3 = torch.abs(C1-C4).mean().item()
+ print(err1, err2, err3)
+ #assert err1 > err2
+
+ loss1 = C1.mean()
+ loss2 = C2.mean()
+ loss3 = C3.mean()
+ loss4 = C4.mean()
+
+ loss1.backward()
+ loss2.backward()
+ loss3.backward()
+ loss4.backward()
+
+ print(B.grad)
+ print(B1.grad)
+ print(B2.grad)
+ print(B3.grad)
+ print(B4.grad)
+ err1 = torch.abs(B1.grad-B2.grad).mean().item()
+ err2 = torch.abs(B1.grad-B3.grad).mean().item()
+ err3 = torch.abs(B1.grad-B4.grad).mean().item()
+ print(err1, err2, err3)
+
+
+
+
+def test_zp():
+ def quant_zp(x):
+ dtype = x.dtype
+ x = x.float()
+ dyna = x.max() - x.min()
+ if dyna == 0: dyna = 1
+ qx = 254./dyna
+ minx = x.min()
+ #zpx = torch.round(minx* qx)
+ #zpx = 127 - torch.round(x.max()* qx)
+ zpx = torch.round(x.min()* qx) - 127
+ x = (qx*x) + zpx
+ return x, qx, zpx
+ batch = 2
+ seq = 512
+ model = 1024
+ hidden = 4*model
+ A = torch.randn(batch*seq, model, device='cuda').half()*0.1
+ B = torch.randn(model, hidden, device='cuda').half()*0.1
+
+
+ C0 = torch.matmul(A, B)
+
+
+ #A, SA = F.vectorwise_quant(A, quant_type='linear')
+ #B, SB = F.vectorwise_quant(B, quant_type='linear')
+ A = A.float()
+ B = B.float()
+
+ C1 = torch.matmul(A, B)
+ C3 = bnb.matmul(A.half(), B.t().contiguous().half())
+
+ zp = 1
+ #C2 = torch.matmul(A-zp, B)
+ #C2 += B.sum(0).view(1, -1)*zp
+ C2 = torch.matmul(A, B-zp)
+ C2 -= A.sum(1).view(-1, 1)*zp
+
+ ca, cqa, cza = quant_zp(A)
+ print(ca.min(), ca.max())
+ print((ca-cza).min(), (ca-cza).max())
+
+ zp = 1
+ scale = 2.0
+ C5 = torch.matmul((A*scale)-zp, B)
+ C5 += B.sum(0)*zp
+ C5 /= scale
+
+ CA, qa, zpa = quant_zp(A)
+ C4 = torch.matmul(CA, B)
+ C4 -= B.sum(0)*zpa
+ C4 /= qa
+ zpb = 1
+ zpa = 1
+ qa = 2
+ qb = 2
+ C6 = torch.matmul((A*qa)+zpa, (B*qb)+zpb)
+ C6 -= (qb*B.sum(0).view(1, -1)*zpa) + (qa*A.sum(1).view(-1, 1)*zpb)
+ C6 -= zpa*zpb*A.shape[1]
+ C6 /= qa*qb
-def test_histogram():
- dim1, dim2 = 32, 32
- source = torch.rand(dim1, dim2, device='cuda')
- idx1 = torch.randint(0, 255, size=(dim1, dim2), device='cuda').int()
- idx2 = torch.randint(0, 255, size=(dim1, dim2), device='cuda').int()
- histogram1 = torch.zeros((256, 256)).cuda()
- histogram2 = torch.zeros((256, 256)).cuda()
+ CA, qa, zpa = quant_zp(A)
+ CB, qb, zpb = quant_zp(B)
+ C7 = torch.matmul(CA, CB)
+ C7 -= (qb*B.sum(0).view(1, -1)*zpa) + (qa*A.sum(1).view(-1, 1)*zpb)
+ C7 -= zpa*zpb*A.shape[1]
+ C7 /= qa*qb
- F.histogram_scatter_add_2d(histogram2, idx1, idx2, source)
+ print('')
+ #print(C0.flatten()[:10])
+ print(C1.flatten()[:10])
+ print(C2.flatten()[:10])
+ print(C3.flatten()[:10])
+ print(C5.flatten()[:10])
+ print(C6.flatten()[:10])
+ print(C7.flatten()[:10])
+ err1 = torch.abs(C1-C2).mean().item()
+ err2 = torch.abs(C1-C3).mean().item()
+ err3 = torch.abs(C1-C4).mean().item()
+ err4 = torch.abs(C1-C5).mean().item()
+ err5 = torch.abs(C1-C6).mean().item()
+ err6 = torch.abs(C1-C7).mean().item()
+ print(err1, err2, err3, err4, err5, err6)
- for i in range(dim1):
- for j in range(dim2):
- histogram1[idx1[i, j].item(), idx2[i, j].item()] += source[i, j]
- torch.testing.assert_allclose(histogram1, histogram2)
- torch.testing.assert_allclose(histogram1.sum(), source.sum())
diff --git a/tests/test_modules.py b/tests/test_modules.py
index a0379cb..a2c950b 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -1,42 +1,470 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
import pytest
import torch
+
+from itertools import product
+from torch import nn
+
import bitsandbytes as bnb
+class MockArgs(object):
+ def __init__(self, initial_data):
+ for key in initial_data:
+ setattr(self, key, initial_data[key])
+
+class MLP8bit(torch.nn.Module):
+ def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
+ super(MLP8bit, self).__init__()
+ self.fc1 = bnb.nn.Linear8bitLt(dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold)
+ self.fc2 = bnb.nn.Linear8bitLt(dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.fc2(x)
+ return x
+
+
+def get_args():
+ args = MockArgs([])
+ args.quant_type = 'vector'
+ args.use_8bit_training = 'full'
+ args.clip_freq = 9999
+ return args
+
+def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
+ idx = torch.isclose(a, b, rtol, atol)
+ sumval = (idx==0).sum().item()
+ if sumval > count:
+ print(f'Too many values not close: assert {sumval} < {count}')
+ torch.testing.assert_allclose(a, b, rtol, atol)
+
+class LinearFunction(torch.autograd.Function):
+
+ @staticmethod
+ def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
+ round_func = LinearFunction.round_stoachastic if stochastic else torch.round
+ norm = math.sqrt(math.pi)/math.sqrt(2.0)
+ #std = torch.abs(x).mean()*norm
+ std = torch.std(x)
+ max1 = std*trim_value
+ x = x/max1*127
+ x = round_func(x)
+ x[x > 127] = 127
+ x[x < -127] = -127
+ x = x/127*max1
+
+ return x
+
+ def quant(x, quant_type, dim=1):
+ if quant_type == 'linear':
+ max1 = torch.abs(x).max().float()
+ xq = torch.round(x/max1*127).to(torch.int8)
+ return xq, max1
+ elif quant_type == 'vector':
+ max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
+ xq = torch.round(x/max1*127).to(torch.int8)
+ return xq, max1
+ elif quant_type == 'min-max':
+ maxA = torch.amax(x, dim=dim, keepdim=True).float()
+ minA = torch.amin(x, dim=dim, keepdim=True).float()
+ scale = (maxA-minA)/2.0
+ xq = torch.round(127*(x-minA-scale)/scale).to(torch.int8)
+ return xq, (minA.float(), scale.float())
+ else: return None
+
+ def dequant(xq, S1, S2, dtype, quant_type):
+ if quant_type == 'linear':
+ norm = S1*S2/(127*127)
+ # double cast needed to prevent overflows
+ return (xq.float()*norm).to(dtype)
+ elif quant_type == 'vector':
+ x = xq.float()
+ if len(xq.shape) == 2 and len(S1.shape) == 3: S1 = S1.squeeze(0)
+ if len(xq.shape) == 2 and len(S2.shape) == 3: S2 = S2.squeeze(0)
+ #print(x.shape, S1.shape, S2.shape)
+ if len(S1.shape) == 2:
+ x *= S1.t()/127
+ else:
+ x *= S1/127
+ x *= S2/127
+ return x.to(dtype)
+ else: return None
+
+ def dequant_min_max(xq, A, B, SA, SB, dtype):
+ offset = B.float().t().sum(0)*(SA[0]+SA[1])
+ x = xq.float()
+ if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0)
+ if len(xq.shape) == 2 and len(SA.shape) == 3: SA = SA.squeeze(0)
+ if len(SB.shape) == 2:
+ x *= SB.t()/127
+ else:
+ x *= SB/127
+ x *= SA[1]/127
+ x +=offset
+ return x.to(dtype)
+
+
+ def get_8bit_linear(x, stochastic=False):
+ round_func = LinearFunction.round_stoachastic if stochastic else torch.round
+ max1 = torch.abs(x).max()
+ x = x/max1*127
+ x = round_func(x)/127*max1
+ #x = torch.round(x)/128*max1
+ return x
+
+ @staticmethod
+ def get_8bit_vector_wise(x, dim, stochastic=False):
+ round_func = LinearFunction.round_stoachastic if stochastic else torch.round
+ max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
+ max1[max1==0] = 1.0
+ x = (x*127)/max1
+ x = round_func(x)/127*max1
+ return x
+
+ @staticmethod
+ def round_stoachastic(x):
+ sign = torch.sign(x)
+ absx = torch.abs(x)
+ decimal = absx-torch.floor(absx)
+ rdm = torch.rand_like(decimal)
+ return sign*(torch.floor(absx)+(rdm < decimal).to(x.dtype))
+
+ @staticmethod
+ def fake_8bit_storage(w, exponent_bits):
+ code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device)
+ absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
+ out = bnb.functional.dequantize_blockwise(absmax, C, code)
+ out = out.half()
+ w.copy_(out)
+ return out
+
+ @staticmethod
+ def fake_8bit_storage_quantile(w, args):
+ code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
+ #C = bnb.functional.quantize_no_absmax(code, w)
+ #out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
+ #print(out)
+ #out = out.half()
+ code /= torch.max(torch.abs(code))
+ absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
+ out = bnb.functional.dequantize_blockwise(absmax, C, code)
+ out = out.half()
+ w.copy_(out)
+ return out
+
+ @staticmethod
+ def fake_8bit_storage_stoachstic(w):
+ rand = torch.rand(1024, device=w.device)
+ absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand)
+ out = bnb.functional.dequantize_blockwise(absmax, C)
+ out = out.half()
+ w.copy_(out)
+ return out
+
+ @staticmethod
+ def fake_8bit_storage_with_max(w, topk=8):
+ blocked_w = einops.rearrange(w.flatten(), '(h b) -> h b', b=256)
+ max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
+ idx = idx[:, :topk]
+ max_val = max_val[:, :topk]
+
+ mask = torch.zeros_like(blocked_w)
+ mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val))
+ mask = mask.bool()
+
+ # 1. zero out max values
+ # 2. quantize + dequantize
+ # 3. write back max values
+ # 4. copy matrix back to weight
+
+ values = blocked_w[mask]
+ blocked_w[mask] = 0
+
+ code = bnb.functional.create_dynamic_map()
+ code = code.to(w.device)
+ absmax, C = bnb.functional.quantize_blockwise(blocked_w.data)
+ bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w)
+
+ blocked_w[mask] = values
+
+ unblocked_w = blocked_w.flatten().view(w.shape)
+
+ w.copy_(unblocked_w)
+ return unblocked_w
+
+
+ @staticmethod
+ def forward(ctx, x, weight, bias=None, args=None):
+ if args.use_8bit_training != 'off':
+ weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
+ x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
+ outputq = bnb.functional.igemm(x8, weight8.t())
+ output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
+ #if torch.rand(1) < 0.01:
+ #output32 = torch.matmul(x, weight.t())
+ #err = torch.abs(output-output32).float()
+ #relerr = err/(torch.abs(output32).float()+1e-8)
+ #print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
+ else:
+ #output = torch.matmul(x, weight.t())
+ output = torch.einsum('bsi,oi->bso', x, weight)
+
+ ctx.save_for_backward(x, weight, bias)
+ ctx.args = args
+
+ if bias is not None:
+ output += bias.unsqueeze(0).expand_as(output)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ x, weight, bias = ctx.saved_tensors
+ args = ctx.args
+ stochastic = False
+ grad_input = grad_weight = grad_bias = None
+ if bias is not None and ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0)
+
+        # weight and x are already 8-bit
+ # -> transform grad_output to 8-bit
+ if args.use_8bit_training == 'forward+wgrad':
+ grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
+ x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
+ grad_weight8 = bnb.functional.igemm(grad_output8, x8)
+ grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
+
+ #grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
+
+ grad_input = grad_output.matmul(weight)
+ elif args.use_8bit_training == 'full':
+ grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
+ x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
+ grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
+ bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
+ grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
+
+ grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
+ weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
+ grad_input8 = bnb.functional.igemm(grad_output8, weight8)
+ grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type)
+
+ else:
+ grad_input = grad_output.matmul(weight)
+ grad_weight = torch.einsum('bsi,bso->oi', x, grad_output)
-@pytest.mark.parametrize("embcls", [bnb.nn.Embedding, bnb.nn.StableEmbedding], ids=['Embedding', 'StableEmbedding'])
-def test_embeddings(embcls):
- bnb.optim.GlobalOptimManager.get_instance().initialize()
- emb1 = torch.nn.Embedding(100, 512).cuda()
- emb2 = embcls(100, 512).cuda()
+ return grad_input, grad_weight, grad_bias, None
- adam1 = bnb.optim.Adam8bit(emb1.parameters())
- adam2 = bnb.optim.Adam8bit(emb2.parameters())
+class Linear8bit(nn.Module):
+ def __init__(self, input_features, output_features, bias=True, args=None):
+ super(Linear8bit, self).__init__()
+ self.input_features = input_features
+ self.output_features = output_features
+ self.args = args
- batches = torch.randint(1, 100, size=(100, 4, 32)).cuda()
+ self.weight = nn.Parameter(torch.empty(output_features, input_features))
+ if bias:
+ self.bias = nn.Parameter(torch.empty(output_features))
+ else:
+ self.register_parameter('bias', None)
+ torch.nn.init.xavier_uniform_(self.weight)
+ if self.bias is not None:
+ torch.nn.init.zeros_(self.bias)
+
+ def forward(self, x):
+ self.args.training = self.training
+
+ return LinearFunction.apply(x, self.weight, self.bias, self.args)
+
+
+
+def test_linear8bit():
+ l0 = torch.nn.Linear(32, 64).cuda().half()
+ l1 = bnb.nn.Linear8bit(32,64, args=get_args()).cuda().half()
+ l2 = Linear8bit(32, 64, args=get_args()).cuda().half()
+ l3 = bnb.nn.Linear8bitLt(32,64).cuda().half()
+
+ l0.weight.data = l2.weight.data.clone()
+ l0.bias.data = l2.bias.data.clone()
+
+ l1.weight.data = l2.weight.data.clone()
+ l1.bias.data = l2.bias.data.clone()
+
+ l3.weight.data = l2.weight.data.clone()
+ l3.bias.data = l2.bias.data.clone()
+
+ for i in range(100):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ t = torch.randn(16, 8, 64, device='cuda').half()
+ b2 = b1.clone()
+ b3 = b1.clone()
+ b0 = b1.clone()
+
+ o0 = l0(b0)
+ o1 = l1(b1)
+ o2 = l2(b2)
+ o3 = l3(b3)
+
+ assert_all_approx_close(o1, o2, atol=0.013, rtol=0.05, count=1)
+ assert_all_approx_close(o3, o2, atol=0.013, rtol=0.05, count=1)
+
+ loss0 = torch.nn.functional.mse_loss(o0, t)
+ loss1 = torch.nn.functional.mse_loss(o1, t)
+ loss2 = torch.nn.functional.mse_loss(o2, t)
+ loss3 = torch.nn.functional.mse_loss(o3, t)
+
+ loss0.backward()
+ loss1.backward()
+ loss2.backward()
+ loss3.backward()
+
+ assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
+ assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
+ assert_all_approx_close(l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)
+ assert_all_approx_close(l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)
+
+ err1 = torch.abs(l0.weight.grad-l1.weight.grad).mean().item()
+ err2 = torch.abs(l0.weight.grad-l2.weight.grad).mean().item()
+ err3 = torch.abs(l0.weight.grad-l3.weight.grad).mean().item()
+
+ assert err1*0.8 < err2
+ assert err2*0.8 < err3
+ assert err3*0.8 < err1
+
+ l0.weight.grad = None
+ l1.weight.grad = None
+ l2.weight.grad = None
+ l3.weight.grad = None
+ l0.bias.grad = None
+ l1.bias.grad = None
+ l2.bias.grad = None
+ l3.bias.grad = None
+
+
+threshold = [0.0, 3.0]
+values = threshold
+names = ['threshold_{0}'.format(vals) for vals in values]
+@pytest.mark.parametrize("threshold", values, ids=names)
+def test_linear8bitlt_inference(threshold):
+ l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold).cuda().half()
+ assert l1.weight.device.type == 'cuda'
+ assert l1.weight.dtype == torch.float16
+
+ l1.eval()
for i in range(100):
- batch = batches[i]
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = l1(b1)
+ if i == 1:
+ assert l1.state.CxB is not None
+
+def test_linear8bitlt_accumulated_gradient():
+ l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32,32).cuda().half() for i in range(2)])
+ l2 = torch.nn.Sequential(*[torch.nn.Linear(32,32).cuda().half() for i in range(2)])
+ l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
+ l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
+ l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
+ l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
+ opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
+ opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)
+
+ acc_steps = 10
+
- embedded1 = emb1(batch)
- embedded2 = emb2(batch)
+ for i in range(10):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = l1(b1)
+ o2 = l2(b1)
+ loss1 = o1.mean()
+ loss2 = o2.mean()
+ loss1.backward()
+ loss2.backward()
+ if i == 2:
+ assert l1[0].state.CxB is not None
+ assert l1[1].state.CxB is not None
- l1 = embedded1.mean()
- l2 = embedded2.mean()
+ if i > 0 and i % acc_steps == 0:
+ opt1.step()
+ opt1.zero_grad(True)
+ opt2.step()
+ opt2.zero_grad(True)
+ assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2)
+ assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2)
+ # we do this copy because otherwise we have small divergences over time that add up
+ l1[0].weight.data.copy_(l2[0].weight.data)
+ l1[1].weight.data.copy_(l2[1].weight.data)
+ else:
+ torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad)
+ torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad)
- l1.backward()
- l2.backward()
- adam1.step()
- adam2.step()
+threshold = [0.0, 2.0]
+values = threshold
+names = ['threshold_{0}'.format(vals) for vals in values]
+@pytest.mark.parametrize("threshold", values, ids=names)
+def test_linear8bitlt_no_fp16_weights(threshold):
+ l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold, has_fp16_weights=False).cuda().half()
+ assert l1.weight.dtype == torch.int8
- adam1.zero_grad()
- adam2.zero_grad()
+ l1.eval()
+ for i in range(100):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = l1(b1)
+ assert o1.dtype == torch.float16
+
+ mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda()
+ assert mlp.fc1.weight.dtype == torch.int8
+ assert mlp.fc2.weight.dtype == torch.int8
- assert adam1.state[emb1.weight]['state1'].dtype == torch.uint8
- assert adam2.state[emb2.weight]['state1'].dtype == torch.float32
+ for i in range(100):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = mlp(b1)
+ assert o1.dtype == torch.float16
+ if threshold > 0: assert mlp.fc1.state.idx is not None
+ if threshold > 0: assert mlp.fc2.state.idx is not None
+ mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
+ assert mlp.fc1.weight.dtype == torch.int8
+ assert mlp.fc2.weight.dtype == torch.int8
+
+ for i in range(100):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = mlp(b1)
+ assert o1.dtype == torch.float16
+ if threshold > 0: assert mlp.fc1.state.idx is not None
+ if threshold > 0: assert mlp.fc2.state.idx is not None
+ mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
+
+ for i in range(100):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = mlp(b1)
+ assert o1.dtype == torch.float16
+ if threshold > 0: assert mlp.fc1.state.idx is not None
+ if threshold > 0: assert mlp.fc2.state.idx is not None
+ assert mlp.fc1.weight.dtype == torch.int8
+ assert mlp.fc2.weight.dtype == torch.int8
+
+
+ mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to('cuda')
+
+ for i in range(100):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = mlp(b1)
+ assert o1.dtype == torch.float16
+ if threshold > 0: assert mlp.fc1.state.idx is not None
+ if threshold > 0: assert mlp.fc2.state.idx is not None
+ assert mlp.fc1.weight.dtype == torch.int8
+ assert mlp.fc2.weight.dtype == torch.int8
+ assert mlp.fc1.weight.device.type == 'cuda'
+ assert mlp.fc2.weight.device.type == 'cuda'
+
+ mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(torch.float16).to('cuda')
+
+ for i in range(100):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = mlp(b1)
+ assert o1.dtype == torch.float16
+ if threshold > 0: assert mlp.fc1.state.idx is not None
+ if threshold > 0: assert mlp.fc2.state.idx is not None
+ assert mlp.fc1.weight.dtype == torch.int8
+ assert mlp.fc2.weight.dtype == torch.int8
+ assert mlp.fc1.weight.device.type == 'cuda'
+ assert mlp.fc2.weight.device.type == 'cuda'
diff --git a/tests/test_optim.py b/tests/test_optim.py
index c80fe51..b173eaa 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -1,12 +1,9 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
import os
import time
import shutil
import uuid
import pytest
+import ctypes
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F
@@ -14,7 +11,9 @@ import bitsandbytes.functional as F
from os.path import join
from itertools import product
-import apex
+#import apex
+
+k = 20
def get_temp_dir():
path = '/tmp/autoswap/{0}'.format(str(uuid.uuid4()))
@@ -26,55 +25,47 @@ def rm_path(path):
str2optimizers = {}
str2optimizers['adam_pytorch'] = (None, torch.optim.Adam, bnb.optim.Adam)
-str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
-str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
+#str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
+#str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers['momentum_pytorch'] = (None, lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), bnb.optim.Adam)
-str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
-str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
+#str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
+#str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam)
-str2optimizers['adamw'] = (torch.optim.AdamW, bnb.optim.AdamW)
-str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
+#str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9))
-str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
+#str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False))
-str2optimizers['adagrad'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad(pxx, 0.01, block_wise=False))
str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False))
-str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
+#str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9))
str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
-str2optimizers['adamw8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.AdamW8bit(pxx, block_wise=True))
str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True))
str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True))
-str2optimizers['adagrad8bit_blockwise'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad8bit(pxx, 0.01, block_wise=True))
str2statenames = {}
str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
-str2statenames['adamw'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
str2statenames['momentum'] = [('momentum_buffer', 'state1')]
str2statenames['lars'] = [('momentum_buffer', 'state1')]
str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
str2statenames['rmsprop'] = [('square_avg', 'state1')]
-str2statenames['adagrad'] = [('sum', 'state1')]
str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
-str2statenames['adamw8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')]
str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')]
str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')]
-str2statenames['adagrad8bit_blockwise'] = [('sum', 'state1', 'qmap1', 'absmax1')]
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
-optimizer_names = ['adam', 'adamw', 'momentum', 'rmsprop', 'lars', 'lamb', 'adagrad']
+optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb']
values = list(product(dim1,dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -89,12 +80,12 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
bnb_optimizer = str2optimizers[optim_name][1]([p2])
if gtype == torch.float32:
- atol, rtol = 2e-6, 1e-5
+ atol, rtol = 1e-6, 1e-5
else:
atol, rtol = 1e-4, 1e-3
- for i in range(50):
+ for i in range(k):
g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
p1.grad = g.clone().float()
p2.grad = g.clone()
@@ -107,7 +98,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
- if i % 10 == 0 and i > 0:
+ if i % (k//5) == 0 and i > 0:
path = get_temp_dir()
torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt'))
del bnb_optimizer
@@ -148,7 +139,6 @@ def test_global_config(dim1, dim2, gtype):
eps = 1e-8
bnb.optim.GlobalOptimManager.get_instance().initialize()
- bnb.optim.GlobalOptimManager.get_instance().override_config(p2, 'skip_zeros', True)
bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8)
bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
@@ -163,8 +153,6 @@ def test_global_config(dim1, dim2, gtype):
else:
atol, rtol = 1e-4, 1e-3
- original_p2 = p2[mask].clone()
-
for i in range(50):
g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
@@ -173,38 +161,17 @@ def test_global_config(dim1, dim2, gtype):
p2.grad = g2
p3.grad = g3
- if i > 30 and i % 10 == 0:
- g1.data[mask] = 0.0
- g2.data[mask] = 0.0
- p1.grad = g1
- p2.grad = g2
- original_p1 = p1[mask].clone()
- original_p2 = p2[mask].clone()
- og_s1 = adam2.state[p2]['state1'][mask].clone()
- og_s2 = adam2.state[p2]['state2'][mask].clone()
- og_s11 = adam2.state[p1]['state1'][mask].clone()
- og_s21 = adam2.state[p1]['state2'][mask].clone()
-
adam2.step()
assert adam2.state[p3]['state1'].dtype == torch.uint8
assert adam2.state[p3]['state2'].dtype == torch.uint8
- if i > 30 and i % 10 == 0:
- torch.testing.assert_allclose(original_p2, p2[mask])
- torch.testing.assert_allclose(adam2.state[p2]['state1'][mask], og_s1)
- torch.testing.assert_allclose(adam2.state[p2]['state2'][mask], og_s2)
- assert ((p1[mask]- original_p1)==0.0).sum() < p1.numel()
- assert ((adam2.state[p1]['state1'][mask]- og_s11)==0.0).sum() == 0.0
- assert ((adam2.state[p1]['state2'][mask]- og_s21)==0.0).sum() == 0.0
-
-
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
-optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'adamw8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise', 'adagrad8bit_blockwise']
+optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise']
values = list(product(dim1,dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -370,13 +337,12 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
if dim1 == 1 and dim2 == 1: return
p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1
-
bnb_optimizer = str2optimizers[optim_name][1]([p1])
g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
p1.grad = g
- for i in range(5000):
- if i == 500:
+ for i in range(k):
+ if i == k//5:
# 100 iterations for burn-in
torch.cuda.synchronize()
t0 = time.time()
@@ -386,23 +352,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
torch.cuda.synchronize()
s = time.time()-t0
print('')
- params = 4500*4096*4096
+ params = (k-k//5)*dim1*dim2
print(optim_name, gtype, s/params)
#assert s < 3.9
-
-def test_str_betas():
- betas = (0.80, 0.95)
- strbetas = '(0.80, 0.95)'
-
- layer = torch.nn.Linear(10, 10)
-
- base = bnb.optim.Adam(layer.parameters(), betas=betas)
- strbase = bnb.optim.Adam(layer.parameters(), betas=strbetas)
- assert base.defaults['betas'][0] == 0.8
- assert base.defaults['betas'][1] == 0.95
- assert strbase.defaults['betas'][0] == 0.8
- assert strbase.defaults['betas'][1] == 0.95
-
-