From bfa0e33294f2b1dc25e65a33be2397f989824298 Mon Sep 17 00:00:00 2001 From: Titus von Koeller Date: Mon, 1 Aug 2022 03:31:48 -0700 Subject: ran black and isort for coherent code formatting --- tests/test_functional.py | 1453 +++++++++++++++++++++++++--------------------- 1 file changed, 801 insertions(+), 652 deletions(-) (limited to 'tests/test_functional.py') diff --git a/tests/test_functional.py b/tests/test_functional.py index bfc3e28..11cd198 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,25 +1,29 @@ -import pytest import math import random import time -import torch -import bitsandbytes as bnb -import einops - from itertools import product +import einops +import pytest +import torch + +import bitsandbytes as bnb from bitsandbytes import functional as F -torch.set_printoptions(precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000) +torch.set_printoptions( + precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000 +) k = 20 + def assert_all_approx_close(a, b, rtol, atol, count): idx = torch.isclose(a, b, rtol, atol) - sumval = (idx==0).sum().item() + sumval = (idx == 0).sum().item() if sumval > count: - print(f'Too many values not close: assert {sumval} < {count}') + print(f"Too many values not close: assert {sumval} < {count}") torch.testing.assert_allclose(a, b, rtol, atol) + class FFN(torch.nn.Module): def __init__(self, input_features, hidden_size, bias=True): super(FFN, self).__init__() @@ -35,13 +39,14 @@ class FFN(torch.nn.Module): x = self.fc2(x) return x + class Timer(object): def __init__(self): self.starts = {} self.ends = {} self.agg = {} - def tick(self, name='default'): + def tick(self, name="default"): if name not in self.starts: self.starts[name] = torch.cuda.Event(enable_timing=True) self.ends[name] = torch.cuda.Event(enable_timing=True) @@ -49,66 +54,70 @@ class Timer(object): else: ms = self.tock(name, evict=True, print_ms=False) - def tock(self, name='default', evict=True, 
print_ms=True): + def tock(self, name="default", evict=True, print_ms=True): if name in self.ends: self.ends[name].record() torch.cuda.synchronize() ms = self.starts[name].elapsed_time(self.ends[name]) - if name not in self.agg: self.agg[name] = 0.0 + if name not in self.agg: + self.agg[name] = 0.0 self.agg[name] += ms if evict: self.starts.pop(name) self.ends.pop(name) if print_ms and name in self.agg: - print('{0} took: {1:.5f}s'.format(name, self.agg[name]/1000.0)) + print("{0} took: {1:.5f}s".format(name, self.agg[name] / 1000.0)) return self.agg[name] def reset(self): - self.starts = {} + self.starts = {} self.ends = {} self.agg = {} - print('Resetting benchmark data') + print("Resetting benchmark data") + def setup(): pass + def teardown(): pass -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['float', 'half']) + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"]) def test_estimate_quantiles(dtype): - A = torch.rand(1024, 1024, device='cuda') + A = torch.rand(1024, 1024, device="cuda") A = A.to(dtype) code = F.estimate_quantiles(A) - percs = torch.linspace(1/512, 511/512, 256, device=A.device) + percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device) torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2) - A = torch.randn(1024, 1024, device='cuda') + A = torch.randn(1024, 1024, device="cuda") A = A.to(dtype) code = F.estimate_quantiles(A) quantiles = torch.quantile(A.float(), percs) - diff = torch.abs(code-quantiles) + diff = torch.abs(code - quantiles) assert (diff > 5e-02).sum().item() == 0 def test_quantile_quantization(): for i in range(100): - A1 = torch.randn(1024, 1024, device='cuda') + A1 = torch.randn(1024, 1024, device="cuda") code = F.estimate_quantiles(A1) C = F.quantize_no_absmax(A1, code) A2 = F.dequantize_no_absmax(C, code) - diff = torch.abs(A1-A2).mean().item() + diff = torch.abs(A1 - A2).mean().item() assert diff < 0.0075 - A1 = torch.rand(1024, 1024, 
device='cuda') + A1 = torch.rand(1024, 1024, device="cuda") code = F.estimate_quantiles(A1) C = F.quantize_no_absmax(A1, code) A2 = F.dequantize_no_absmax(C, code) - diff = torch.abs(A1-A2).mean().item() + diff = torch.abs(A1 - A2).mean().item() torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0) assert diff < 0.001 @@ -117,22 +126,22 @@ def test_dynamic_quantization(): diffs = [] reldiffs = [] for i in range(100): - A1 = torch.randn(1024, 1024, device='cuda') + A1 = torch.randn(1024, 1024, device="cuda") C, S = F.quantize(A1) A2 = F.dequantize(C, S) - diff = torch.abs(A1-A2) - reldiff = diff/torch.abs(A1+1e-8) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) assert diff.mean().item() < 0.0135 - #print(sum(diffs)/len(diffs)) - #print(sum(reldiffs)/len(reldiffs)) + # print(sum(diffs)/len(diffs)) + # print(sum(reldiffs)/len(reldiffs)) for i in range(100): - A1 = torch.rand(1024, 1024, device='cuda') + A1 = torch.rand(1024, 1024, device="cuda") C, S = F.quantize(A1) A2 = F.dequantize(C, S) - diff = torch.abs(A1-A2).mean().item() + diff = torch.abs(A1 - A2).mean().item() torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) assert diff < 0.004 @@ -141,56 +150,60 @@ def test_dynamic_blockwise_quantization(): diffs = [] reldiffs = [] for i in range(100): - A1 = torch.randn(1024, 1024, device='cuda') + A1 = torch.randn(1024, 1024, device="cuda") C, S = F.quantize_blockwise(A1) A2 = F.dequantize_blockwise(C, S) - diff = torch.abs(A1-A2) - reldiff = diff/torch.abs(A1+1e-8) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) assert diffs[-1] < 0.011 - #print(sum(diffs)/len(diffs)) - #print(sum(reldiffs)/len(reldiffs)) + # print(sum(diffs)/len(diffs)) + # print(sum(reldiffs)/len(reldiffs)) diffs = [] for i in range(100): - A1 = torch.rand(1024, 1024, device='cuda') + A1 = torch.rand(1024, 
1024, device="cuda") C, S = F.quantize_blockwise(A1) A2 = F.dequantize_blockwise(C, S) - diff = torch.abs(A1-A2).mean().item() + diff = torch.abs(A1 - A2).mean().item() assert diff < 0.0033 diffs.append(diff) torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) - #print(sum(diffs)/len(diffs)) + # print(sum(diffs)/len(diffs)) + def test_dynamic_blockwise_stochastic_quantization(): diffs = [] reldiffs = [] rand = torch.rand(1024).cuda() for i in range(100): - A1 = torch.randn(1024, 1024, device='cuda') + A1 = torch.randn(1024, 1024, device="cuda") C1, S1 = F.quantize_blockwise(A1, rand=rand) C2, S2 = F.quantize_blockwise(A1) # a maximunm distance of quantized values of 1 torch.testing.assert_allclose(C1, C2, atol=1, rtol=0) - fraction_smaller = (C1<C2).float().sum()/C1.numel() - fraction_larger = (C1>C2).float().sum()/C1.numel() - torch.testing.assert_allclose(fraction_larger, fraction_smaller, atol=0.01, rtol=0) + fraction_smaller = (C1 < C2).float().sum() / C1.numel() + fraction_larger = (C1 > C2).float().sum() / C1.numel() + torch.testing.assert_allclose( + fraction_larger, fraction_smaller, atol=0.01, rtol=0 + ) - -@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=['float', 'half']) +@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"]) def test_percentile_clipping(gtype): - gnorm_vec1 = torch.zeros(100, device='cuda') - gnorm_vec2 = torch.zeros(100, device='cuda') + gnorm_vec1 = torch.zeros(100, device="cuda") + gnorm_vec2 = torch.zeros(100, device="cuda") n = 4 step = 0 - percentile=5 + percentile = 5 for i in range(k): step += 1 - g = torch.randn(n, n, dtype=gtype, device='cuda') - gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile) - assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2/gnorm1 
gnorm2 = torch.norm(g.float()) if step == 1: @@ -208,74 +221,89 @@ def test_percentile_clipping(gtype): def quant(x): max1 = torch.abs(x).max() - x = torch.round(x/max1*127) + x = torch.round(x / max1 * 127) return max1, x.to(torch.int8) + def dequant(c, maxC): - return c.float()*(maxC/127) + return c.float() * (maxC / 127) + def mm_dequant(maxA, maxB, C): - return C.float()*(maxA/127)*(maxB/127) + return C.float() * (maxA / 127) * (maxB / 127) + def quant_multi(x, dim): max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) - max1[max1==0] = 1.0 - x = torch.round(x/max1*127) + max1[max1 == 0] = 1.0 + x = torch.round(x / max1 * 127) return max1, x.to(torch.int8) + def quant_multi_chunk(x, dim, chunk_size=32): - if dim==1: - x_chunked = einops.rearrange(x, '(c a) b -> c a b', c=chunk_size) - max1 = torch.amax(torch.abs(x_chunked), dim=dim+1, keepdim=True) + if dim == 1: + x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size) + max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True) max1 = torch.tile(max1, (1, 1, x.shape[1])) max1 = max1.view(x.shape) - elif dim==0: - x_chunked = einops.rearrange(x, 'a (b c) -> a b c', c=chunk_size) + elif dim == 0: + x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size) max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True) max1 = torch.tile(max1, (x.shape[0], 1, 1)) max1 = max1.view(x.shape) - max1[max1==0] = 1.0 - x = torch.round(x/max1*127) + max1[max1 == 0] = 1.0 + x = torch.round(x / max1 * 127) return max1, x.to(torch.int8) + def quant_minmax(A): minA = A.min() maxA = A.max() + def mean(xx): - return sum(xx)/float(len(xx)) + return sum(xx) / float(len(xx)) + -#dim1 = torch.randint(1,1024*4, size=(4,)).tolist() -#dim2 = torch.randint(1,1024*4, size=(4,)).tolist() -dim1 = [1024*2] -dim2 = [1024*16] -methods = [(lambda x, dim: quant(x), lambda x, dim: quant(x), dequant, dequant, mm_dequant)] +# dim1 = torch.randint(1,1024*4, size=(4,)).tolist() +# dim2 = torch.randint(1,1024*4, 
size=(4,)).tolist() +dim1 = [1024 * 2] +dim2 = [1024 * 16] +methods = [ + (lambda x, dim: quant(x), lambda x, dim: quant(x), dequant, dequant, mm_dequant) +] methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant)) -#methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant)) -method_names = ['linear', 'vectorwise'] +# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant)) +method_names = ["linear", "vectorwise"] batched = [False, True] -values = list(product(dim1,dim2, methods, batched)) -values_names = list(product(dim1,dim2, method_names, batched)) -names = ['dim1_{0}_dim2_{1}_quant_{2}_batched_{3}'.format(*vals) for vals in values_names] +values = list(product(dim1, dim2, methods, batched)) +values_names = list(product(dim1, dim2, method_names, batched)) +names = [ + "dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals) for vals in values_names +] + + @pytest.mark.parametrize("dim1, dim2, quant_methods, batched", values, ids=names) def test_approx_igemm(dim1, dim2, quant_methods, batched): dim1 = dim1 - (dim1 % 32) dim2 = dim2 - (dim2 % 32) errors = [] relerrors = [] - print('') + print("") for i in range(5): if batched: - A = torch.normal(0, 0.5, size=(32, dim1, dim2//32), device='cuda') - B = torch.normal(0, 0.5, size=(32, dim2//32, dim1), device='cuda') + A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda") + B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda") maxA, Ac = quant_methods[0](A, 2) maxB, Bc = quant_methods[1](B, 1) else: - A = torch.normal(0, 0.5, size=(dim1, dim2), device='cuda') - B = torch.normal(0, 0.5, size=(dim2, dim1), device='cuda') + A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda") + B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda") maxA, Ac = quant_methods[0](A, 1) maxB, Bc = quant_methods[1](B, 0) - 
torch.testing.assert_allclose(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05) + torch.testing.assert_allclose( + quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05 + ) if batched: out2 = torch.bmm(A, B) C = torch.bmm(Ac.float(), Bc.float()) @@ -284,43 +312,49 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): C = F.igemm(Ac, Bc) out = quant_methods[4](maxA, maxB, C) std = out2.std() - out/= std - out2/= std - err = torch.abs(out-out2) - relerr = err/torch.abs(out2) + out /= std + out2 /= std + err = torch.abs(out - out2) + relerr = err / torch.abs(out2) errors.append(err.mean().item()) relerrors.append(relerr.mean().item()) print(mean(errors)) print(mean(relerrors)) - - - - def test_stable_embedding(): layer = bnb.nn.StableEmbedding(1024, 1024) layer.reset_parameters() - n = 2 -hidden_dim = torch.randint(32,256, size=(n,)).tolist() -batch_dim = torch.randint(16,256, size=(n,)).tolist() -seq_dim = torch.randint(16,256, size=(n,)).tolist() +hidden_dim = torch.randint(32, 256, size=(n,)).tolist() +batch_dim = torch.randint(16, 256, size=(n,)).tolist() +seq_dim = torch.randint(16, 256, size=(n,)).tolist() transpose = [(False, False), (False, True), (True, False), (True, True)] -values = list(product(hidden_dim,batch_dim, transpose, seq_dim)) -names = ['hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}'.format(*vals) for vals in values] +values = list(product(hidden_dim, batch_dim, transpose, seq_dim)) +names = [ + "hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}".format(*vals) + for vals in values +] + + @pytest.mark.parametrize("hidden_dim, batch_dim, transpose, seq_dim", values, ids=names) def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): hidden_dim = hidden_dim - (hidden_dim % 32) batch_dim = batch_dim - (batch_dim % 16) seq_dim = seq_dim - (seq_dim % 16) for i in range(k): - shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim) - shapeB = ((32*random.randint(1, 4), hidden_dim) if transpose[1] else 
(hidden_dim, 32*random.randint(1, 4))) - A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) - B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8) + shapeA = ( + (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim) + ) + shapeB = ( + (32 * random.randint(1, 4), hidden_dim) + if transpose[1] + else (hidden_dim, 32 * random.randint(1, 4)) + ) + A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) + B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) if not transpose[0] and not transpose[1]: out2 = torch.matmul(A.float(), B.float()) out = F.igemm(A, B) @@ -338,9 +372,13 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): for i in range(k): shapeA = (batch_dim, seq_dim, hidden_dim) - shapeB = ((32*random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32*random.randint(1, 4))) - A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) - B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8) + shapeB = ( + (32 * random.randint(1, 4), hidden_dim) + if transpose[1] + else (hidden_dim, 32 * random.randint(1, 4)) + ) + A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) + B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) if not transpose[0] and not transpose[1]: out2 = torch.matmul(A.float(), B.float()) out = F.igemm(A, B) @@ -352,40 +390,51 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): n = 3 -seq_dim = torch.randint(32,512, size=(n,)).tolist() -hidden_dim = torch.randint(32,1024*4, size=(n,)).tolist() -batch_dim = torch.randint(2,16, size=(n,)).tolist() -values = list(product(seq_dim,hidden_dim,batch_dim)) -names = ['seq_dim{0}_hidden_dim{1}_batch_dim{2}'.format(*vals) for vals in values] +seq_dim = torch.randint(32, 512, size=(n,)).tolist() +hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist() +batch_dim = torch.randint(2, 16, size=(n,)).tolist() +values = 
list(product(seq_dim, hidden_dim, batch_dim)) +names = ["seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values] + + @pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names) def test_dim3_igemm(seq_dim, hidden_dim, batch_dim): seq_dim = seq_dim - (seq_dim % 32) hidden_dim = hidden_dim - (hidden_dim % 32) batch_dim = batch_dim - (batch_dim % 2) for i in range(25): - A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device='cuda').to(torch.int8) - B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device='cuda').to(torch.int8) - out2 = torch.einsum('bsi, bso->io', A.float(), B.float()) + A = torch.randint( + -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda" + ).to(torch.int8) + B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to( + torch.int8 + ) + out2 = torch.einsum("bsi, bso->io", A.float(), B.float()) iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device) out = F.igemm(A, B, out=iout) torch.testing.assert_allclose(out.float(), out2) + n = 2 -seq_dim = torch.randint(32,512, size=(n,)).tolist() -hidden_dim = torch.randint(32,1024*4, size=(n,)).tolist() -batch_dim = torch.randint(2,16, size=(n,)).tolist() +seq_dim = torch.randint(32, 512, size=(n,)).tolist() +hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist() +batch_dim = torch.randint(2, 16, size=(n,)).tolist() transpose = [False, True] -values = list(product(seq_dim,hidden_dim,batch_dim, transpose)) -names = ['seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}'.format(*vals) for vals in values] +values = list(product(seq_dim, hidden_dim, batch_dim, transpose)) +names = [ + "seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}".format(*vals) + for vals in values +] + + @pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim, transpose", values, ids=names) def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose): - def min_max(x): maxA = torch.amax(x, dim=2, 
keepdim=True) minA = torch.amin(x, dim=2, keepdim=True) - scale = (maxA-minA)/2.0 - return (127*(x-minA-scale)/scale).to(torch.int8), minA, scale + scale = (maxA - minA) / 2.0 + return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale seq_dim = seq_dim - (seq_dim % 16) hidden_dim = hidden_dim - (hidden_dim % 16) @@ -395,30 +444,30 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose): errs2 = [] relerrs2 = [] for i in range(k): - A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device='cuda') + A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda") if transpose: - B = torch.normal(0, 0.5, size=(256, hidden_dim), device='cuda') + B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda") else: - B = torch.normal(0, 0.5, size=(hidden_dim, 256), device='cuda') + B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda") Ac, minA, scale = min_max(A) if transpose: maxB, Bc = quant_multi(B, dim=(1 if transpose else 0)) out = F.igemm(Ac, Bc.t()) - out2 = torch.matmul(A,B.t()) - offset = B.t().sum(0)*(minA+scale) + out2 = torch.matmul(A, B.t()) + offset = B.t().sum(0) * (minA + scale) out = out.float() - out = (out*maxB.t()*scale/(127*127))+offset + out = (out * maxB.t() * scale / (127 * 127)) + offset maxA, Ac = quant_multi(A, dim=2) out3 = F.igemm(Ac, Bc.t()) out3 = mm_dequant(maxA, maxB.t(), out3) else: maxB, Bc = quant_multi(B, dim=0) - offset = B.sum(0)*(minA+scale) + offset = B.sum(0) * (minA + scale) out = F.igemm(Ac, Bc) - out2 = torch.matmul(A,B) + out2 = torch.matmul(A, B) out = out.float() - out = (out*maxB*scale/(127*127))+offset + out = (out * maxB * scale / (127 * 127)) + offset maxA, Ac = quant_multi(A, dim=2) out3 = F.igemm(Ac, Bc) @@ -429,31 +478,36 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose): out /= std out3 /= std - err = torch.abs(out-out2) - relerr = err/(torch.abs(out2)+1e-7) + err = torch.abs(out - out2) + relerr = err / (torch.abs(out2) + 1e-7) - 
err2 = torch.abs(out3-out2) - relerr2 = err2/(torch.abs(out2)+1e-7) + err2 = torch.abs(out3 - out2) + relerr2 = err2 / (torch.abs(out2) + 1e-7) errs.append(err.mean().item()) relerrs.append(relerr.mean().item()) errs2.append(err2.mean().item()) relerrs2.append(relerr2.mean().item()) - #print(mean(errs)) - #print(mean(relerrs)) - #print(mean(errs2)) - #print(mean(relerrs2)) + # print(mean(errs)) + # print(mean(relerrs)) + # print(mean(errs2)) + # print(mean(relerrs2)) assert mean(errs) < 0.015 assert mean(relerrs) < 0.3 + n = 2 -dim1 = torch.randint(1,64, size=(n,)).tolist() -dim2 = torch.randint(32,128, size=(n,)).tolist() -dim3 = torch.randint(32,256, size=(n,)).tolist() -dim4 = torch.randint(32,256, size=(n,)).tolist() +dim1 = torch.randint(1, 64, size=(n,)).tolist() +dim2 = torch.randint(32, 128, size=(n,)).tolist() +dim3 = torch.randint(32, 256, size=(n,)).tolist() +dim4 = torch.randint(32, 256, size=(n,)).tolist() transpose = [(False, False), (True, False), (False, True), (True, True)] -values = list(product(dim1,dim2,dim3,dim4,transpose)) -names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}'.format(*vals) for vals in values] +values = list(product(dim1, dim2, dim3, dim4, transpose)) +names = [ + "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals) for vals in values +] + + @pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names) def test_ibmm(dim1, dim2, dim3, dim4, transpose): dim2 = dim2 - (dim2 % 16) @@ -462,8 +516,8 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose): for i in range(k): shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3) shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4) - A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) - B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8) + A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) + B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) 
if not transpose[0] and not transpose[1]: out2 = torch.bmm(A.float(), B.float()) @@ -479,146 +533,174 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose): out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1])) torch.testing.assert_allclose(out.float(), out2.float()) + n = 1 -dim1 = torch.randint(1,64, size=(n,)).tolist() -dim2 = torch.randint(32,128, size=(n,)).tolist() -dim3 = torch.randint(32,256, size=(n,)).tolist() -values = list(product(dim1,dim2,dim3)) -names = ['dim1_{0}_dim2_{1}_dim3_{2}'.format(*vals) for vals in values] +dim1 = torch.randint(1, 64, size=(n,)).tolist() +dim2 = torch.randint(32, 128, size=(n,)).tolist() +dim3 = torch.randint(32, 256, size=(n,)).tolist() +values = list(product(dim1, dim2, dim3)) +names = ["dim1_{0}_dim2_{1}_dim3_{2}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names) def test_vector_quant(dim1, dim2, dim3): dim2 = dim2 - (dim2 % 16) dim3 = dim3 - (dim3 % 16) for i in range(k): - A = torch.randn(size=(dim2, dim3), device='cuda') + A = torch.randn(size=(dim2, dim3), device="cuda") qA, SA = F.vectorwise_quant(A, dim=0) A1 = F.vectorwise_dequant(qA, SA) torch.testing.assert_allclose(A1, A, atol=0.01, rtol=0.1) - n = 2 -dim1 = torch.randint(2,256, size=(n,)).tolist() -dim2 = torch.randint(2,256, size=(n,)).tolist() -dim3 = torch.randint(2,256, size=(n,)).tolist() -#dim1, dim2 = (256,), (256,) +dim1 = torch.randint(2, 256, size=(n,)).tolist() +dim2 = torch.randint(2, 256, size=(n,)).tolist() +dim3 = torch.randint(2, 256, size=(n,)).tolist() +# dim1, dim2 = (256,), (256,) dtype = [torch.int8, torch.int32] -a_order = ['row'] -out_order = ['col', 'row', 'col32'] +a_order = ["row"] +out_order = ["col", "row", "col32"] transpose = [False] dims = [2, 3] -values = list(product(dim1,dim2,dim3, dims,dtype, a_order, out_order, transpose)) +values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)) + +names = [ + 
"dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format( + *vals + ) + for vals in values +] + -names = ['dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}'.format(*vals) for vals in values] -@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names) +@pytest.mark.parametrize( + "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names +) def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): - if dims == 3 and out_order != 'col32': return - if dtype == torch.int32 and out_order != 'col32': return + if dims == 3 and out_order != "col32": + return + if dtype == torch.int32 and out_order != "col32": + return func = F.get_transform_func(dtype, orderA, orderOut, transpose) if dims == 2: - A = torch.randint(-128, 127, size=(dim1, dim2), device='cuda').to(dtype) + A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype) elif dims == 3: - A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(dtype) + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype) out, S = F.nvidia_transform(A, to_order=orderOut) - if orderOut == 'row': + if orderOut == "row": torch.testing.assert_allclose(A.flatten(), out.flatten()) - elif orderOut == 'col': + elif orderOut == "col": torch.testing.assert_allclose(A.t().flatten(), out.flatten()) - elif orderOut == 'col32': + elif orderOut == "col32": if dims == 2: - n = A.shape[0]*(A.shape[1] + (32 - (A.shape[1]%32))) + n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32))) elif dims == 3: - n = A.shape[0]*A.shape[1]*(A.shape[2] + (32 - (A.shape[2]%32))) + n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32))) assert out.numel() == n - elif orderOut == 'col_turing': + elif orderOut == "col_turing": # 32 col 8 row tiles - n = (A.shape[0]+(8- A.shape[0]%8))*(A.shape[1] + (32 - (A.shape[1]%32))) + n = 
(A.shape[0] + (8 - A.shape[0] % 8)) * ( + A.shape[1] + (32 - (A.shape[1] % 32)) + ) assert out.numel() == n total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0) for row in range(A.shape[0]): for col in range(A.shape[1]): - i = row*A.shape[1] + i = row * A.shape[1] j = col coltile = (col // 32) + (1 if col % 32 != 0 else 0) - rowtile = ((row // 8) + (1 if row % 8 != 0 else 0))*total_coltile - offset = 32*8*(rowtile+coltile) + rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile + offset = 32 * 8 * (rowtile + coltile) col2 = col % 32 - row2 = (row%8)*32 + row2 = (row % 8) * 32 + assert A.flatten()[i + j] == A[row, col] + # assert A.flatten()[i+j] == out.flatten()[row2+col2] + # torch.testing.assert_allclose(A.flatten()[i+j], A[row, col]) + # torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset]) - assert A.flatten()[i+j] == A[row, col] - #assert A.flatten()[i+j] == out.flatten()[row2+col2] - #torch.testing.assert_allclose(A.flatten()[i+j], A[row, col]) - #torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset]) - - if orderOut == 'col32': - out2, S = F.nvidia_transform(out, from_order=orderOut, to_order='row', state=S) + if orderOut == "col32": + out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S) torch.testing.assert_allclose(A, out2) n = 1 -dim1 = torch.randint(1,256, size=(n,)).tolist() -dim2 = torch.randint(32,512, size=(n,)).tolist() -dim3 = torch.randint(32,1024, size=(n,)).tolist() -dim4 = torch.randint(32,1024, size=(n,)).tolist() +dim1 = torch.randint(1, 256, size=(n,)).tolist() +dim2 = torch.randint(32, 512, size=(n,)).tolist() +dim3 = torch.randint(32, 1024, size=(n,)).tolist() +dim4 = torch.randint(32, 1024, size=(n,)).tolist() -#dim1 = [2] -#dim2 = [2] -#dim3 = [2] -#dim4 = [2] +# dim1 = [2] +# dim2 = [2] +# dim3 = [2] +# dim4 = [2] -dims = (2,3) +dims = (2, 3) ldb = [0] -#ldb = list(range(256, 1*1024, 256)) -values = 
list(product(dim1,dim2,dim3,dim4,dims, ldb)) -names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}'.format(*vals) for vals in values] +# ldb = list(range(256, 1*1024, 256)) +values = list(product(dim1, dim2, dim3, dim4, dims, ldb)) +names = [ + "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}".format(*vals) + for vals in values +] + + @pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names) def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): for i in range(k): if dims == 2: - A = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8) + A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to( + torch.int8 + ) elif dims == 3: - A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8) - B = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8) + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to( + torch.int8 + ) + B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8) C1 = torch.matmul(A.float(), B.t().float()) - A2, SA = F.transform(A, 'col32') - B2, SB = F.transform(B, 'col_turing') + A2, SA = F.transform(A, "col32") + B2, SB = F.transform(B, "col_turing") C2, SC = F.igemmlt(A2, B2, SA, SB) - C3, S = F.nvidia_transform(C2, 'row', state=SC) + C3, S = F.nvidia_transform(C2, "row", state=SC) torch.testing.assert_allclose(C1, C3.float()) # transpose - B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8) + B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8) C1 = torch.matmul(A.float(), B.float()) - B2t, SBt = F.transform(B, 'col_turing', transpose=True) + B2t, SBt = F.transform(B, "col_turing", transpose=True) C2, SC = F.igemmlt(A2, B2t, SA, SBt) - C3, S = F.nvidia_transform(C2, 'row', state=SC) + C3, S = F.nvidia_transform(C2, "row", state=SC) torch.testing.assert_allclose(C1, C3.float()) + dim1 = [32] dim2 = [32] dim3 = [32] dim4 = [32] dims = (2,) -#ldb = 
list(range(256, 1*1024, 256)) -values = list(product(dim1,dim2,dim3,dim4,dims)) -names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}'.format(*vals) for vals in values] +# ldb = list(range(256, 1*1024, 256)) +values = list(product(dim1, dim2, dim3, dim4, dims)) +names = [ + "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals) for vals in values +] + + @pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names) def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): formatB = F.get_special_format_str() for i in range(k): if dims == 2: - A = torch.normal(0, 0.5, size=(dim1, dim3), device='cuda').half() + A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half() elif dims == 3: - A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device='cuda').half() - B = torch.randn((dim4, dim3), device='cuda').half() + A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half() + B = torch.randn((dim4, dim3), device="cuda").half() torch.nn.init.xavier_uniform_(B) C1 = torch.matmul(A, B.t()) C2 = bnb.matmul(A, B.t()) @@ -627,50 +709,56 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B) - C32A, SA = F.transform(CA, 'col32') + C32A, SA = F.transform(CA, "col32") CxB, SB = F.transform(CB, to_order=formatB) out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB) output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt) - #print('') - #print(output.flatten()[:10]) - #print(C1.flatten()[:10]) - #print(C2.flatten()[:10]) + # print('') + # print(output.flatten()[:10]) + # print(C1.flatten()[:10]) + # print(C2.flatten()[:10]) - - #torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) + # torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) # transpose - #B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8) - #C1 = torch.matmul(A.float(), 
B.float()) + # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8) + # C1 = torch.matmul(A.float(), B.float()) + + # B2t, SBt = F.transform2(B, 'col_turing', transpose=True) + # C2, SC = F.igemmlt(A2, B2t, SA, SBt) + # C3, S = F.transform(C2, 'row', state=SC) + # torch.testing.assert_allclose(C1, C3.float()) - #B2t, SBt = F.transform2(B, 'col_turing', transpose=True) - #C2, SC = F.igemmlt(A2, B2t, SA, SBt) - #C3, S = F.transform(C2, 'row', state=SC) - #torch.testing.assert_allclose(C1, C3.float()) batch_size = 2 seqdim = 512 -#values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)] -values = [(batch_size, seqdim, 4*1024, 3*4*1024),(batch_size, seqdim, 5120, 3*5120),(batch_size, seqdim, 12*1024, 4*12*1024)] +# values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)] +values = [ + (batch_size, seqdim, 4 * 1024, 3 * 4 * 1024), + (batch_size, seqdim, 5120, 3 * 5120), + (batch_size, seqdim, 12 * 1024, 4 * 12 * 1024), +] + + +# values = list(product(batch, seq, model, hidden)) +names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values] -#values = list(product(batch, seq, model, hidden)) -names = ['batch_{0}_seq_{1}_model_{2}_hidden_{3}'.format(*vals) for vals in values] @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) def test_bench_8bit_training(batch, seq, model, hidden): formatB = F.get_special_format_str() - A = torch.randn(batch, seq, model, device='cuda').half() - grad = torch.randn(batch, seq, model, device='cuda').half() - w1 = torch.randint(-128, 127, size=(hidden, model), device='cuda').half() - w2 = torch.randint(-128, 127, size=(model, hidden), device='cuda').half() - print('') + A = torch.randn(batch, seq, model, device="cuda").half() + grad = torch.randn(batch, seq, model, device="cuda").half() + w1 = torch.randint(-128, 127, size=(hidden, model), 
device="cuda").half() + w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half() + print("") - #torch.cuda.synchronize() + # torch.cuda.synchronize() ## warmup - #for i in range(100): + # for i in range(100): # torch.matmul(A, w1.t()) - #torch.cuda.synchronize() + # torch.cuda.synchronize() dtype = torch.int8 A = A.view(-1, A.shape[-1]).contiguous() @@ -679,77 +767,77 @@ def test_bench_8bit_training(batch, seq, model, hidden): t0 = time.time() for i in range(k): - out1 = torch.matmul(A, w1.t()) # fc1 - #out2 = torch.matmul(out1, w2.t())# fc2 + out1 = torch.matmul(A, w1.t()) # fc1 + # out2 = torch.matmul(out1, w2.t())# fc2 - #d1 = torch.matmul(grad, w2) # delta1 - #d2 = torch.matmul(d1, w1) # delta2 + # d1 = torch.matmul(grad, w2) # delta1 + # d2 = torch.matmul(d1, w1) # delta2 - #grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2 - #grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1 + # grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2 + # grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1 torch.cuda.synchronize() t16 = time.time() - t0 print(t16) - #torch.cuda.empty_cache() + # torch.cuda.empty_cache() - #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) - #Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) + # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) + # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) - #CTw1, Sw1 = F.transform2(Cw1, formatB) - #CTw2, Sw2 = F.transform2(Cw2, formatB) - #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True) - #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True) + # CTw1, Sw1 = F.transform2(Cw1, formatB) + # CTw2, Sw2 = F.transform2(Cw2, formatB) + # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True) + # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True) - #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) - #C32A, SA = F.transform2(CA, 'col32') + # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + # C32A, 
SA = F.transform2(CA, 'col32') ## fc1 - #out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype) + # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype) ##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t) ## fc2 - #Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1) - #C32out1, Sout1 = F.transform2(Cout1, 'col32') - #out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype) + # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1) + # C32out1, Sout1 = F.transform2(Cout1, 'col32') + # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype) ##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t) ## delta1 - #Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad) - #C32grad, Sgrad = F.transform2(Cgrad, 'col32') + # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad) + # C32grad, Sgrad = F.transform2(Cgrad, 'col32') ##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype) ##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2) ## delta2 - #Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1) - #C32d1, Sd1 = F.transform2(Cd1, 'col32') + # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1) + # C32d1, Sd1 = F.transform2(Cd1, 'col32') ##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype) ##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1) ## grad1 - #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True) - #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True) + # C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True) + # CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True) ##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype) ##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad) ## grad2 - #C32At, SAt = F.transform2(CAt, 'col32', transpose=True) - #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True) + # 
C32At, SAt = F.transform2(CAt, 'col32', transpose=True) + # CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True) ##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype) ##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1) - #Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) + # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) - #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) - #Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) + # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) + # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) - #CTw1, Sw1 = F.transform2(Cw1, formatB) - #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True) - #CTw2, Sw2 = F.transform2(Cw2, formatB) - #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(k): + # CTw1, Sw1 = F.transform2(Cw1, formatB) + # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True) + # CTw2, Sw2 = F.transform2(Cw2, formatB) + # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(k): # #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) # #CTw1, Sw1 = F.transform2(Cw1, formatB) # #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) @@ -802,74 +890,76 @@ def test_bench_8bit_training(batch, seq, model, hidden): # #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype) # #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t) - #torch.cuda.synchronize() - #t8 = time.time() - t0 - #print(t8) - - - + # torch.cuda.synchronize() + # t8 = time.time() - t0 + # print(t8) n = 2 -dim1 = torch.randint(64,256, size=(n,)).tolist() -dim4 = torch.randint(64,1024, size=(n,)).tolist() +dim1 = torch.randint(64, 256, size=(n,)).tolist() +dim4 = torch.randint(64, 1024, size=(n,)).tolist() -#dim1 = [2*1024] -#dim4 = [2*1024] +# dim1 = [2*1024] +# 
dim4 = [2*1024] -#dim1 = [4] -#dim4 = [4] +# dim1 = [4] +# dim4 = [4] dims = (2,) -#ldb = list(range(256, 1*1024, 256)) -formatB = ['col_turing', 'col_ampere'] -values = list(product(dim1,dim4,dims, formatB)) -names = ['dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}'.format(*vals) for vals in values] +# ldb = list(range(256, 1*1024, 256)) +formatB = ["col_turing", "col_ampere"] +values = list(product(dim1, dim4, dims, formatB)) +names = ["dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names) def test_dequant_mm(dim1, dim4, dims, formatB): inner = torch.randint(1, 128, size=(1,)).item() formatB = F.get_special_format_str() for i in range(k): - A = torch.randn(dim1, inner, device='cuda') - B = torch.randn(dim4, inner, device='cuda') + A = torch.randn(dim1, inner, device="cuda") + B = torch.randn(dim4, inner, device="cuda") C1 = torch.matmul(A.half(), B.t().half()) A1, maxA = F.vectorwise_quant(A, dim=1) B1, maxB = F.vectorwise_quant(B, dim=1) - A2, SA = F.nvidia_transform(A1, 'col32') + A2, SA = F.nvidia_transform(A1, "col32") B2, SB = F.nvidia_transform(B1, formatB) C2, SC = F.igemmlt(A2, B2, SA, SB) - C3, S = F.nvidia_transform(C2, 'row', state=SC) + C3, S = F.nvidia_transform(C2, "row", state=SC) C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t()) count = (torch.isclose(C1, C4, atol=0.01, rtol=0.1) == 0).sum().item() n = C1.numel() p = 0.06 - assert count/n < p, f'error in more than {p} of elements: {count}/{n}={count/n}' + assert ( + count / n < p + ), f"error in more than {p} of elements: {count}/{n}={count/n}" C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten()) torch.testing.assert_allclose(C5, C4) - #print(C2) - + # print(C2) n = 2 -dim1 = [1*1024] -dim2 = [1*1024] -#dim1 = torch.randint(1,4*1024, size=(n,)).tolist() -#dim2 = torch.randint(1,4*1024, size=(n,)).tolist() +dim1 = [1 * 1024] +dim2 = [1 * 1024] +# dim1 = torch.randint(1,4*1024, size=(n,)).tolist() +# 
dim2 = torch.randint(1,4*1024, size=(n,)).tolist() dims = (2,) -#ldb = list(range(256, 1*1024, 256)) -values = list(product(dim1,dim2,dims)) -names = ['dim1_{0}_dim2_{1}_dims_{2}'.format(*vals) for vals in values] +# ldb = list(range(256, 1*1024, 256)) +values = list(product(dim1, dim2, dims)) +names = ["dim1_{0}_dim2_{1}_dims_{2}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim2, dims", values, ids=names) def test_colrow_absmax(dim1, dim2, dims): for i in range(k): threshold = 3.0 - A = torch.randn(dim1, dim2, device='cuda').half() + A = torch.randn(dim1, dim2, device="cuda").half() A_truncated = A.clone() A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0 if dims == 2: @@ -880,11 +970,22 @@ def test_colrow_absmax(dim1, dim2, dims): else: assert False - row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold) - - A_blocked = einops.rearrange(torch.abs(A), '(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size', row_tiles=16, block_size=64*4) - nnz_rows1_counts = (torch.abs(A_blocked)>=threshold).sum(3).flatten() - nnz_block_ptr1 = torch.zeros(nnz_rows1_counts.shape[0]+1, dtype=nnz_rows1_counts.dtype, device=nnz_rows1_counts.device) + row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax( + A, threshold=threshold + ) + + A_blocked = einops.rearrange( + torch.abs(A), + "(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size", + row_tiles=16, + block_size=64 * 4, + ) + nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten() + nnz_block_ptr1 = torch.zeros( + nnz_rows1_counts.shape[0] + 1, + dtype=nnz_rows1_counts.dtype, + device=nnz_rows1_counts.device, + ) nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0) torch.testing.assert_allclose(col_stats1_trunc, col_stats2) @@ -898,19 +999,20 @@ def test_colrow_absmax(dim1, dim2, dims): assert nnz_block_ptr2 is None - n = 2 -#dim1 = [8*1024] -#dim2 = [4*1024] -dim1 = torch.randint(1,4*1024, size=(n,)).tolist() -dim2 = 
torch.randint(1,4*1024, size=(n,)).tolist() +# dim1 = [8*1024] +# dim2 = [4*1024] +dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() +dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist() + +values = list(product(dim1, dim2)) +names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values] + -values = list(product(dim1,dim2)) -names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim2", values, ids=names) def test_double_quant(dim1, dim2): for i in range(k): - A = torch.randn(dim1, dim2, device='cuda').half() + A = torch.randn(dim1, dim2, device="cuda").half() out_col1, Scol = F.vectorwise_quant(A, dim=0) out_row1, Srow = F.vectorwise_quant(A, dim=1) @@ -920,18 +1022,21 @@ def test_double_quant(dim1, dim2): torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0) torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0) - n = CAt.numel() - num_not_close_rows = (torch.isclose(CA, out_row1, atol=1)==0).sum().item() - num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1)==0).sum().item() + num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item() + num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item() # allow for 1:500 error due to rounding differences - min_error = 1/500 - if num_not_close_cols > (min_error*n): - print(f'Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}') + min_error = 1 / 500 + if num_not_close_cols > (min_error * n): + print( + f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}" + ) assert False - if num_not_close_rows > (min_error*n): - print(f'Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}') + if num_not_close_rows > (min_error * n): + print( + f"Min error exceeded {num_not_close_rows} elements are different. 
Error: {num_not_close_rows/n:.4f}" + ) assert False torch.testing.assert_allclose(Srow.flatten(), statsA) @@ -939,21 +1044,23 @@ def test_double_quant(dim1, dim2): n = 4 -dim1 = torch.randint(1,4*1024, size=(n,)).tolist() -dim4 = torch.randint(1,4*1024, size=(n,)).tolist() -inner = torch.randint(1,4*1024, size=(n,)).tolist() +dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() +dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist() +inner = torch.randint(1, 4 * 1024, size=(n,)).tolist() dim1 = [6] dim4 = [4] inner = [8] values = list(zip(dim1, dim4, inner)) -names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values] +names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) def test_integrated_igemmlt(dim1, dim4, inner): for i in range(k): - A = torch.randn(dim1, inner, device='cuda').half() - B = torch.randn(dim4, inner, device='cuda').half() + A = torch.randn(dim1, inner, device="cuda").half() + B = torch.randn(dim4, inner, device="cuda").half() out1 = torch.matmul(A.half(), B.t().half()) @@ -967,30 +1074,32 @@ def test_integrated_igemmlt(dim1, dim4, inner): torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1) torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1) - A2, SA = F.nvidia_transform(C1a, 'col32') - B2, SB = F.nvidia_transform(C2a, 'col_turing') + A2, SA = F.nvidia_transform(C1a, "col32") + B2, SB = F.nvidia_transform(C2a, "col_turing") outC32, SC = F.igemmlt(A2, B2, SA, SB) out2 = F.mm_dequant(outC32, SC, stats1a, stats2a) - A2, SA = F.nvidia_transform(A1, 'col32') - B2, SB = F.nvidia_transform(B1, 'col_turing') + A2, SA = F.nvidia_transform(A1, "col32") + B2, SB = F.nvidia_transform(B1, "col_turing") C2, SC = F.igemmlt(A2, B2, SA, SB) - C3, S = F.nvidia_transform(C2, 'row', state=SC) + C3, S = F.nvidia_transform(C2, "row", state=SC) out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t()) - err1 = torch.abs(out1-out2).mean().item() - err2 = 
torch.abs(out1-out3).mean().item() - assert err2 <= err1*1.01 + err1 = torch.abs(out1 - out2).mean().item() + err2 = torch.abs(out1 - out3).mean().item() + assert err2 <= err1 * 1.01 n = 6 -dim1 = torch.randint(1,4*1024, size=(n,)).tolist() -dim4 = torch.randint(1,4*1024, size=(n,)).tolist() -inner = torch.randint(1,4*1024, size=(n,)).tolist() +dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() +dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist() +inner = torch.randint(1, 4 * 1024, size=(n,)).tolist() values = list(zip(dim1, dim4, inner)) -names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values] +names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) @pytest.mark.skip("Row scale has some bugs for ampere") def test_igemmlt_row_scale(dim1, dim4, inner): @@ -999,79 +1108,79 @@ def test_igemmlt_row_scale(dim1, dim4, inner): relerr1, relerr2 = [], [] scale = 1 for i in range(k): - A = torch.randn(dim1, inner, device='cuda').half() - B = torch.randn(dim4, inner, device='cuda').half() + A = torch.randn(dim1, inner, device="cuda").half() + B = torch.randn(dim4, inner, device="cuda").half() torch.nn.init.xavier_uniform_(B) C1 = torch.matmul(A, B.t()) out1 = torch.matmul(A.half(), B.t().half()) - C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A) - CB, absmaxB = F.vectorwise_quant(B, quant_type='linear') - A2, SA = F.nvidia_transform(C1a, 'col32') + CB, absmaxB = F.vectorwise_quant(B, quant_type="linear") + A2, SA = F.nvidia_transform(C1a, "col32") B2, SB = F.nvidia_transform(CB, formatB) A1, maxA = F.vectorwise_quant(A, dim=1) - c = 10.0*inner*scale - row_scale = torch.ones_like(maxA)/c + c = 10.0 * inner * scale + row_scale = torch.ones_like(maxA) / c outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) - C3, S = F.nvidia_transform(outC32, 'row', state=SC) + C3, S = F.nvidia_transform(outC32, "row", state=SC) maxval = 
torch.abs(C3).max() if maxval == 127: scale = 1.5 else: - scale = maxval/120 - out3 = C3*maxA*absmaxB*c/(127*127) + scale = maxval / 120 + out3 = C3 * maxA * absmaxB * c / (127 * 127) C4 = torch.matmul(C1a.float(), CB.float().t()) - C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B) B2, SB = F.nvidia_transform(C2a, formatB) outC32, SC = F.igemmlt(A2, B2, SA, SB) out2 = F.mm_dequant(outC32, SC, stats1a, stats2a) - CA, SA = F.vectorwise_quant(A, dim=1, quant_type='vector') - CB, SB = F.vectorwise_quant(B, dim=1, quant_type='linear') + CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector") + CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear") C = torch.matmul(CA.float(), CB.t().float()) - out4 = C*SA*SB/(127*127) - #out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127) + out4 = C * SA * SB / (127 * 127) + # out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127) - #print('='*80) - #print(out1) - #print(out2) - #print(out3) + # print('='*80) + # print(out1) + # print(out2) + # print(out3) - #print(out1) - #print(out2) - #print(out3) - err1.append(torch.abs(out1-out2).mean().item()) - err2.append(torch.abs(out1-out3).mean().item()) - err3.append(torch.abs(out1-out4).mean().item()) + # print(out1) + # print(out2) + # print(out3) + err1.append(torch.abs(out1 - out2).mean().item()) + err2.append(torch.abs(out1 - out3).mean().item()) + err3.append(torch.abs(out1 - out4).mean().item()) - #assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10) - print('') - print(sum(err1)/len(err1)) - print(sum(err2)/len(err2)) - print(sum(err3)/len(err3)) + # assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10) + print("") + print(sum(err1) / len(err1)) + print(sum(err2) / len(err2)) + print(sum(err3) / len(err3)) dim1 = [1024, 2048] -inner = [12288*4, 4096*4] +inner = [12288 * 4, 4096 * 4] dim4 = [12288, 4096] values = list(zip(dim1, dim4, inner)) -names = 
['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values] +names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) @pytest.mark.skip("Row scale has some bugs for ampere") def test_row_scale_bench(dim1, dim4, inner): err1, err2, err3 = [], [], [] relerr1, relerr2 = [], [] scale = 1 - A = torch.randn(dim1, inner, device='cuda').half() - B = torch.randn(dim4, inner, device='cuda').half() + A = torch.randn(dim1, inner, device="cuda").half() + B = torch.randn(dim4, inner, device="cuda").half() torch.nn.init.xavier_uniform_(B) # warmpup for i in range(k): @@ -1082,23 +1191,22 @@ def test_row_scale_bench(dim1, dim4, inner): for i in range(k): C1 = torch.matmul(A, B.t()) torch.cuda.synchronize() - print('16', time.time()-t0) + print("16", time.time() - t0) C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A) - CB, absmaxB = F.vectorwise_quant(B, quant_type='linear') - A2, SA = F.nvidia_transform(C1a, 'col32') + CB, absmaxB = F.vectorwise_quant(B, quant_type="linear") + A2, SA = F.nvidia_transform(C1a, "col32") B2, SB = F.nvidia_transform(CB, formatB) A1, maxA = F.vectorwise_quant(A, dim=1) - c = 10.0*inner*scale - row_scale = maxA/c + c = 10.0 * inner * scale + row_scale = maxA / c torch.cuda.synchronize() t0 = time.time() for i in range(k): outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) torch.cuda.synchronize() - print('row-wise', time.time()-t0) - + print("row-wise", time.time() - t0) C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B) B2, SB = F.nvidia_transform(C2a, formatB) @@ -1107,32 +1215,39 @@ def test_row_scale_bench(dim1, dim4, inner): for i in range(k): outC32, SC = F.igemmlt(A2, B2, SA, SB) torch.cuda.synchronize() - print('vector-wise', time.time()-t0) - - + print("vector-wise", time.time() - t0) n = 2 -dim1 = torch.randint(2,1024, size=(n,)).tolist() -dim2 = torch.randint(2,1024, size=(n,)).tolist() -#dim1 = [8*1024] -#dim2 
= [4*1024] +dim1 = torch.randint(2, 1024, size=(n,)).tolist() +dim2 = torch.randint(2, 1024, size=(n,)).tolist() +# dim1 = [8*1024] +# dim2 = [4*1024] dim3 = [0] dtype = [torch.int8] -a_order = ['row'] -out_order = ['col32', 'col_turing', 'col_ampere'] +a_order = ["row"] +out_order = ["col32", "col_turing", "col_ampere"] transpose = [False, True] dims = [2] -values = list(product(dim1,dim2,dim3, dims,dtype, a_order, out_order, transpose)) -names = ['dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}'.format(*vals) for vals in values] -@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names) +values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)) +names = [ + "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format( + *vals + ) + for vals in values +] + + +@pytest.mark.parametrize( + "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names +) def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): for i in range(k): if dims == 2: - A = torch.randint(10, 99, size=(dim1, dim2), device='cuda').to(dtype) + A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype) elif dims == 3: - A = torch.randint(10, 99, size=(dim1, dim2, dim3), device='cuda').to(dtype) + A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype) A.view(-1)[-1] = -1 if transpose: @@ -1144,53 +1259,55 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): assert S1[0][0] == S2[0][0] assert S1[0][1] == S2[0][1] - #print(out1) - #print(out2) + # print(out1) + # print(out2) torch.testing.assert_allclose(out1, out2) + n = 2 -#dim1 = torch.randint(2,1024, size=(n,)).tolist() -#dim2 = torch.randint(2,1024, size=(n,)).tolist() +# dim1 = torch.randint(2,1024, size=(n,)).tolist() +# dim2 = torch.randint(2,1024, size=(n,)).tolist() dim1 = [1] dim2 = [33] dtype = [torch.int8] -#a_order 
= ['col_turing', 'col_ampere'] -a_order = ['col_turing'] -out_order = ['row'] -values = list(product(dim1,dim2,dtype, a_order, out_order)) -names = ['dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}'.format(*vals) for vals in values] +# a_order = ['col_turing', 'col_ampere'] +a_order = ["col_turing"] +out_order = ["row"] +values = list(product(dim1, dim2, dtype, a_order, out_order)) +names = [ + "dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}".format(*vals) + for vals in values +] + + @pytest.mark.parametrize("dim1, dim2, dtype, orderA, orderOut", values, ids=names) def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut): for i in range(1): - A = torch.randint(-127, 127, size=(dim1, dim2), device='cuda').to(dtype) + A = torch.randint(-127, 127, size=(dim1, dim2), device="cuda").to(dtype) out2, S2 = F.transform(A, to_order=orderA) - A2, S3 = F.transform(out2, from_order=orderA, to_order='row', state=S2) + A2, S3 = F.transform(out2, from_order=orderA, to_order="row", state=S2) assert A2.shape[0] == A.shape[0] assert A2.shape[1] == A.shape[1] - - print('') + print("") print(A) print(out2) print(A2) - - #torch.testing.assert_allclose(A, A2) - - + # torch.testing.assert_allclose(A, A2) def test_overflow(): formatB = F.get_special_format_str() print(formatB) for i in range(2): - a = torch.arange(5, 15).cuda().to(torch.int8).view(-1,1 ) - b = torch.arange(5, 15).cuda().to(torch.int8).view(-1,1 ) + a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1) + b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1) - Ca, Sa = F.nvidia_transform(a, 'col32') + Ca, Sa = F.nvidia_transform(a, "col32") Cb, Sb = F.nvidia_transform(b, formatB) c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8) @@ -1198,46 +1315,51 @@ def test_overflow(): n = 2 -dim1 = torch.randint(1,4*1024, size=(n,)).tolist() -dim2 = torch.randint(1,4*1024, size=(n,)).tolist() -#dim1 = [4] -#dim2 = [5] +dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() +dim2 = torch.randint(1, 4 * 1024, 
size=(n,)).tolist() +# dim1 = [4] +# dim2 = [5] + +values = list(product(dim1, dim2)) +names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values] + -values = list(product(dim1,dim2)) -names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim2", values, ids=names) def test_coo_double_quant(dim1, dim2): threshold = 3.00 for i in range(k): - A = torch.randn(dim1, dim2, device='cuda').half() + A = torch.randn(dim1, dim2, device="cuda").half() - idx = (torch.abs(A) >= threshold) + idx = torch.abs(A) >= threshold CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) if coo_tensor is not None: - A1 = A*idx + A1 = A * idx A2 = torch.zeros_like(A) A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values torch.testing.assert_allclose(A1, A2) - A1 = A*(idx==0) - A2 = (CA.float()*statsA.unsqueeze(1)/127).half() - torch.testing.assert_allclose(A*(idx==0), A2, rtol=0.05, atol=1.5e-2) + A1 = A * (idx == 0) + A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() + torch.testing.assert_allclose(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) + n = 2 -dim1 = torch.randint(1,1*1024, size=(n,)).tolist() -dim2 = torch.randint(1,1*1024, size=(n,)).tolist() -#dim1 = [7] -#dim2 = [11] +dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist() +dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist() +# dim1 = [7] +# dim2 = [11] transposed_B = [False, True] -values = list(product(dim1,dim2, transposed_B)) -names = ['dim1_{0}_dim2_{1}_transposed_B_{2}'.format(*vals) for vals in values] +values = list(product(dim1, dim2, transposed_B)) +names = ["dim1_{0}_dim2_{1}_transposed_B_{2}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names) def test_spmm_coo(dim1, dim2, transposed_B): threshold = 1.5 dim3 = torch.randint(32, 128, size=(1,)).item() - #dim3 = 17 + # dim3 = 17 for i in range(k): A = 
torch.randn(dim1, dim2).cuda().half() if transposed_B: @@ -1249,8 +1371,10 @@ def test_spmm_coo(dim1, dim2, transposed_B): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) - A2 = A*idx + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) + A2 = A * idx if transposed_B: out2 = F.spmm_coo(cooA, B.t()) @@ -1262,18 +1386,17 @@ def test_spmm_coo(dim1, dim2, transposed_B): assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30) - def test_spmm_bench(): batch = 2 - model = 1024*1 - hidden = model*4 + model = 1024 * 1 + hidden = model * 4 seq = 1024 - dim1 = batch*seq + dim1 = batch * seq dim2 = model dim3 = hidden threshold = 4 - A = torch.randn(dim1, dim2, device='cuda').half() - B = torch.randn(dim2, dim3, device='cuda').half() + A = torch.randn(dim1, dim2, device="cuda").half() + B = torch.randn(dim2, dim3, device="cuda").half() for i in range(10): C1 = bnb.matmul(A, B) @@ -1282,14 +1405,16 @@ def test_spmm_bench(): for i in range(k): C1 = bnb.matmul(A, B) torch.cuda.synchronize() - t8 = time.time()-t0 + t8 = time.time() - t0 idx = torch.abs(A) >= threshold nnz = (idx == 1).sum().item() - print(nnz/idx.numel()) + print(nnz / idx.numel()) rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) for i in range(10): out2 = F.spmm_coo(cooA, B) @@ -1299,20 +1424,22 @@ def test_spmm_bench(): for i in range(k): out2 = F.spmm_coo(cooA, B) torch.cuda.synchronize() - tsp = time.time()-t0 + tsp = time.time() - t0 print(tsp, t8) - print(tsp/t8) + print(tsp / t8) n = 2 -dim1 = torch.randint(256,1*1024, size=(n,)).tolist() -dim2 = torch.randint(256,1*1024, size=(n,)).tolist() -values = list(product(dim1,dim2)) -names = 
['dim1_{0}_dim2_{1}'.format(*vals) for vals in values] +dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist() +dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist() +values = list(product(dim1, dim2)) +names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim2", values, ids=names) def test_integrated_sparse_decomp(dim1, dim2): threshold = 3.0 - formatB = 'col_turing' + formatB = "col_turing" for i in range(k): A = torch.randn(dim1, dim2).cuda().half() w1 = torch.randn(dim1, dim2).cuda().half() @@ -1322,13 +1449,13 @@ def test_integrated_sparse_decomp(dim1, dim2): CTw1, Sw1 = F.transform(Cw1, formatB) CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) - C32A, SA = F.transform(CA, 'col32') + C32A, SA = F.transform(CA, "col32") out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) - C32A, SA = F.transform(CA, 'col32') + C32A, SA = F.transform(CA, "col32") out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) @@ -1338,8 +1465,8 @@ def test_integrated_sparse_decomp(dim1, dim2): out4 = F.spmm_coo(coo_tensor, w1.t()) out5 = out3 + out4 - err1 = torch.abs(out1-out2).mean().item() - err2 = torch.abs(out1-out5).mean().item() + err1 = torch.abs(out1 - out2).mean().item() + err2 = torch.abs(out1 - out5).mean().item() assert err2 < err1 @@ -1350,91 +1477,95 @@ def test_matmuls(): c2 = bnb.matmul(a, b) c3 = bnb.matmul(a, b) - err1 = torch.abs(c1-c2).mean().item() - err2 = torch.abs(c1-c3).mean().item() + err1 = torch.abs(c1 - c2).mean().item() + err2 = torch.abs(c1 - c3).mean().item() assert err1 < 0.2 assert err2 < 0.2 - n = 2 -#dim1 = torch.randint(1,1*1024, size=(n,)).tolist() -#dim2 = torch.randint(1,4*1024, size=(n,)).tolist() -dim1 = [1*2048] +# dim1 = torch.randint(1,1*1024, size=(n,)).tolist() +# dim2 = torch.randint(1,4*1024, 
size=(n,)).tolist() +dim1 = [1 * 2048] dim2 = [12288] -#dim1 = [32] -#dim2 = [32] -#dtype = [torch.float16, torch.int8] +# dim1 = [32] +# dim2 = [32] +# dtype = [torch.float16, torch.int8] dtype = [torch.float16] -out_function = ['zeros', 'ones'] -values = list(product(dim1,dim2, dtype, out_function)) -names = ['dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}'.format(*vals) for vals in values] +out_function = ["zeros", "ones"] +values = list(product(dim1, dim2, dtype, out_function)) +names = ["dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names) def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): out_func = getattr(torch, out_func) threshold = 3.3 - #threshold = 2.8 - #threshold = 0.0 - A = torch.randn(dim1, dim2, device='cuda').half() + # threshold = 2.8 + # threshold = 0.0 + A = torch.randn(dim1, dim2, device="cuda").half() if dtype == torch.float16: - B = torch.randn(dim2, dim2*4, device='cuda').half() + B = torch.randn(dim2, dim2 * 4, device="cuda").half() torch.nn.init.xavier_uniform_(B) else: - B = torch.randn(dim2, dim2*4, device='cuda').half() + B = torch.randn(dim2, dim2 * 4, device="cuda").half() torch.nn.init.xavier_uniform_(B) - B, SB = F.vectorwise_quant(B, quant_type='linear') - #B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8) + B, SB = F.vectorwise_quant(B, quant_type="linear") + # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8) - print('') + print("") idx = torch.abs(A) >= threshold nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) - A2 = A*idx + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) + A2 = A * idx out1 = torch.matmul(A2.half(), B.half()) out = out_func(out1.shape, dtype=torch.float16, device=out1.device) out1 += 
out.clone() out2 = F.spmm_coo_very_sparse(cooA, B, out=out) - #print(B) - #print(out1) - #print(out2) - p = 200/(2048*12288*4) + # print(B) + # print(out1) + # print(out2) + p = 200 / (2048 * 12288 * 4) n = out1.numel() - count = math.ceil(p*n) + count = math.ceil(p * n) std = out1.std() out1 /= std out2 /= std assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count) - #assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count) + # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count) idx_col = torch.randint(0, A2.shape[-1], size=(15,)) - #torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001) + # torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001) - #Bt = torch.randn(dim2*4, dim2, device='cuda').half() - #torch.cuda.synchronize() - #t0 = time.time() - #print(A2.shape, B.shape) - #for i in range(100): + # Bt = torch.randn(dim2*4, dim2, device='cuda').half() + # torch.cuda.synchronize() + # t0 = time.time() + # print(A2.shape, B.shape) + # for i in range(100): # #out3 = F.spmm_coo(cooA, Bt.t()) # #out2 = F.spmm_coo(cooA, B) # #out2 = F.spmm_coo_very_sparse(cooA, B) # #out1 = torch.matmul(A, Bt.t()) - #torch.cuda.synchronize() - #print(time.time() - t0) + # torch.cuda.synchronize() + # print(time.time() - t0) + def test_layout(): - a1 = torch.rand(16, 64, device='cuda', dtype=torch.float16) - a1 = torch.arange(16* 64, device='cuda').reshape(16, 64).byte() - a2, s2 = F.transform(a1, 'col_turing') + a1 = torch.rand(16, 64, device="cuda", dtype=torch.float16) + a1 = torch.arange(16 * 64, device="cuda").reshape(16, 64).byte() + a2, s2 = F.transform(a1, "col_turing") print(a2.shape) - print(a1.flatten()[8*64:8*64+32]) + print(a1.flatten()[8 * 64 : 8 * 64 + 32]) for i in range(4): - print(a2.flatten()[i*8*32:i*8*32+32], 0) + print(a2.flatten()[i * 8 * 32 : i * 8 * 32 + 32], 0) def test_coo2csr(): @@ -1444,14 +1575,16 @@ def test_coo2csr(): nnz = (idx == 1).sum().item() 
rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) - A2 = A*idx + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) + A2 = A * idx csrA = F.coo2csr(cooA) counts = csrA.rowptr[1:] - csrA.rowptr[:-1] assert counts.numel() == A.shape[0] - torch.testing.assert_allclose(counts, (A2!=0).sum(1)) - idx = (A2!=0) + torch.testing.assert_allclose(counts, (A2 != 0).sum(1)) + idx = A2 != 0 torch.testing.assert_allclose(A2[idx], csrA.values) @@ -1462,41 +1595,43 @@ def test_coo2csc(): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) - A2 = A*idx + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) + A2 = A * idx cscA = F.coo2csc(cooA) counts = cscA.colptr[1:] - cscA.colptr[:-1] assert counts.numel() == A.shape[1] - torch.testing.assert_allclose(counts, (A2!=0).sum(0)) + torch.testing.assert_allclose(counts, (A2 != 0).sum(0)) # torch uses row-major -> use transpose to transfer to col-major - idx = (A2.t()!=0) + idx = A2.t() != 0 torch.testing.assert_allclose(A2.t()[idx], cscA.values) - n = 2 -#dim1 = torch.randint(1,1*1024, size=(n,)).tolist() -#dim2 = torch.randint(1,4*1024, size=(n,)).tolist() -dim1 = [1*2048] -#dim2 = [12288] +# dim1 = torch.randint(1,1*1024, size=(n,)).tolist() +# dim2 = torch.randint(1,4*1024, size=(n,)).tolist() +dim1 = [1 * 2048] +# dim2 = [12288] dim2 = [2048] -#dim1 = [2] -#dim2 = [2] +# dim1 = [2] +# dim2 = [2] dtype = [torch.int8] -values = list(product(dim1,dim2, dtype)) -names = ['dim1_{0}_dim2_{1}_dtype_{2}'.format(*vals) for vals in values] +values = list(product(dim1, dim2, dtype)) +names = ["dim1_{0}_dim2_{1}_dtype_{2}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names) def test_spmm_coo_dequant(dim1, dim2, dtype): 
threshold = 6.0 - #threshold = 2.8 - #threshold = 0.0 - A = torch.randn(dim1, dim2, device='cuda').half() - B = torch.empty(dim2, dim2*4, device='cuda', dtype=torch.float16) + # threshold = 2.8 + # threshold = 0.0 + A = torch.randn(dim1, dim2, device="cuda").half() + B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16) torch.nn.init.xavier_uniform_(B) Bt = B.t().contiguous() - CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B) rowidx = torch.randint(0, A.shape[-1], size=(15,)) @@ -1507,12 +1642,14 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) - A2 = A*idx + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) + A2 = A * idx out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) out1 = torch.matmul(A2, B.half()) out3 = F.spmm_coo_very_sparse(cooA, CBt.half()) - out3 = out3*statsBt.half()/127 + out3 = out3 * statsBt.half() / 127 values, counts = torch.unique(cooA.rowidx, return_counts=True) offset = counts.cumsum(0).int() @@ -1521,56 +1658,54 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001) - p = 200/(2048*12288*4) + p = 200 / (2048 * 12288 * 4) n = out1.numel() - count = math.ceil(p*n) + count = math.ceil(p * n) assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count) - - - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(100): + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(100): # out2 = F.spmm_coo_very_sparse(cooA, B) - #torch.cuda.synchronize() - #print('fp16', time.time() - t0) + # torch.cuda.synchronize() + # print('fp16', time.time() - t0) torch.cuda.synchronize() t0 = time.time() for i in range(100): - out2 = F.spmm_coo(cooA, B) + out2 = F.spmm_coo(cooA, B) torch.cuda.synchronize() - print('cusparse fp16', 
time.time() - t0) + print("cusparse fp16", time.time() - t0) torch.cuda.synchronize() t0 = time.time() for i in range(100): - out2 = F.spmm_coo_very_sparse(cooA, CBt) + out2 = F.spmm_coo_very_sparse(cooA, CBt) torch.cuda.synchronize() - print('int8', time.time() - t0) + print("int8", time.time() - t0) torch.cuda.synchronize() t0 = time.time() for i in range(100): - out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) + out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) torch.cuda.synchronize() - print('int8+dequant', time.time() - t0) + print("int8+dequant", time.time() - t0) torch.cuda.synchronize() t0 = time.time() for i in range(100): - out2 = torch.matmul(A, B) + out2 = torch.matmul(A, B) torch.cuda.synchronize() - print('matmul', time.time() - t0) + print("matmul", time.time() - t0) torch.cuda.synchronize() t0 = time.time() for i in range(100): out1 = bnb.matmul(A, Bt) out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) - out = out1+out2 + out = out1 + out2 torch.cuda.synchronize() - print('sparse+ matmul', time.time() - t0) + print("sparse+ matmul", time.time() - t0) torch.cuda.synchronize() t0 = time.time() @@ -1578,33 +1713,36 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): out1 = bnb.matmul(A, Bt) torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1) torch.cuda.synchronize() - print('partial matmul', time.time() - t0) + print("partial matmul", time.time() - t0) torch.cuda.synchronize() t0 = time.time() for i in range(100): out1 = bnb.matmul(A, Bt) torch.cuda.synchronize() - print('partial matmul', time.time() - t0) + print("partial matmul", time.time() - t0) + batch_size = 1 seqdim = 2048 values = [] -values.append((batch_size, seqdim, 768, 4*768)) -#values.append((batch_size, seqdim, 1024, 4*1024)) -#values.append((batch_size, seqdim, 1536, 4*1536)) -#values.append((batch_size, seqdim, 2048, 4*2048)) -#values.append((batch_size, seqdim, 2560, 4*2560)) -#values.append((batch_size, seqdim, 4096, 4*4096)) 
-#values.append((batch_size, seqdim, 5140, 4*5140)) -#values.append((batch_size, seqdim, 12288, 4*12288)) -names = ['batch_{0}_seq_{1}_model_{2}_hidden_{3}'.format(*vals) for vals in values] +values.append((batch_size, seqdim, 768, 4 * 768)) +# values.append((batch_size, seqdim, 1024, 4*1024)) +# values.append((batch_size, seqdim, 1536, 4*1536)) +# values.append((batch_size, seqdim, 2048, 4*2048)) +# values.append((batch_size, seqdim, 2560, 4*2560)) +# values.append((batch_size, seqdim, 4096, 4*4096)) +# values.append((batch_size, seqdim, 5140, 4*5140)) +# values.append((batch_size, seqdim, 12288, 4*12288)) +names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values] + + @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) def test_bench_matmul(batch, seq, model, hidden): formatB = F.get_special_format_str() - A = torch.randn(batch, seq, model, device='cuda').half() - B = torch.empty(hidden, model, dtype=torch.float16, device='cuda') + A = torch.randn(batch, seq, model, device="cuda").half() + B = torch.empty(hidden, model, dtype=torch.float16, device="cuda") torch.nn.init.xavier_uniform_(B) linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() @@ -1613,31 +1751,37 @@ def test_bench_matmul(batch, seq, model, hidden): outliers = torch.randint(0, model, size=(5,)).cuda() A[:, :, outliers] = 8.0 - linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() + linearMixedBit = ( + bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() + ) linearMixedBit.eval() # warmup for i in range(100): torch.matmul(A, B.t()) torch.cuda.synchronize() - print('') + print("") torch.cuda.synchronize() t0 = time.time() for i in range(100): torch.matmul(A, B.t()) torch.cuda.synchronize() - print(f'pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + print( + f"pytorch: [{batch},{seq},{model}], 
[{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) torch.cuda.synchronize() t0 = time.time() for i in range(100): bnb.matmul(A, B) torch.cuda.synchronize() - print(f'bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + print( + f"bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) - C32A, SA = F.transform(CA, 'col32') + C32A, SA = F.transform(CA, "col32") CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) CxB, SB = F.transform(CB, to_order=formatB) torch.cuda.synchronize() @@ -1645,7 +1789,9 @@ def test_bench_matmul(batch, seq, model, hidden): for i in range(100): out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) torch.cuda.synchronize() - print(f'igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + print( + f"igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) BA, statsB = F.vectorwise_quant(B, dim=1) CxB, SB = F.nvidia_transform(CB, to_order=formatB) @@ -1654,26 +1800,30 @@ def test_bench_matmul(batch, seq, model, hidden): for i in range(100): A2 = A.view(-1, A.shape[-1]).contiguous() CA, statsA = F.vectorwise_quant(A2, dim=1) - C32A, SA = F.nvidia_transform(CA, 'col32') + C32A, SA = F.nvidia_transform(CA, "col32") out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) - Cout, Sout = F.nvidia_transform(out32, 'row', state=Sout32) + Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) F.vectorwise_mm_dequant(Cout, statsA, statsB.t()) torch.cuda.synchronize() - print(f'vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + print( + f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) - BA, statsB = F.vectorwise_quant(B, dim=1, 
quant_type='linear') + BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear") CxB, SB = F.nvidia_transform(CB, to_order=formatB) torch.cuda.synchronize() t0 = time.time() for i in range(100): A2 = A.view(-1, A.shape[-1]).contiguous() - CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type='linear') - C32A, SA = F.nvidia_transform(CA, 'col32') + CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear") + C32A, SA = F.nvidia_transform(CA, "col32") out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) - Cout, Sout = F.nvidia_transform(out32, 'row', state=Sout32) - out = Cout*statsB*statsA*(1.0/(127*127)) + Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) + out = Cout * statsB * statsA * (1.0 / (127 * 127)) torch.cuda.synchronize() - print(f'linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + print( + f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) linear8bit(A) torch.cuda.synchronize() @@ -1681,8 +1831,9 @@ def test_bench_matmul(batch, seq, model, hidden): for i in range(100): linear8bit(A) torch.cuda.synchronize() - print(f'bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') - + print( + f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) linearMixedBit(A) torch.cuda.synchronize() @@ -1690,65 +1841,66 @@ def test_bench_matmul(batch, seq, model, hidden): for i in range(100): linearMixedBit(A) torch.cuda.synchronize() - print(f'bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + print( + f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) def test_zeropoint(): def min_max(x): maxA = torch.amax(x, dim=1, keepdim=True) minA = 
torch.amin(x, dim=1, keepdim=True) - midpoint = (maxA-minA)/2.0 - dyna = 252/(maxA-minA) - #dyna *= 0.98 - x = dyna*x - x = x - torch.round((dyna*(minA+midpoint))) + midpoint = (maxA - minA) / 2.0 + dyna = 252 / (maxA - minA) + # dyna *= 0.98 + x = dyna * x + x = x - torch.round((dyna * (minA + midpoint))) return x.to(torch.int8), minA, midpoint, dyna + batch = 2 seq = 2 model = 4 - hidden = 2*model - #batch = 4 - #seq = 2048 - #model = 1024 - #hidden = 8*model - A = torch.randn(batch*seq, model, device='cuda').half()-0.4 - B = torch.nn.Parameter(torch.randn(model, hidden, device='cuda').half()) - - #A[0] = 0 - #B[:, 0] = 0 - #A = A*(A>0) - #A[0, 0] = 0 - #A[0, 0] = 6.0 + hidden = 2 * model + # batch = 4 + # seq = 2048 + # model = 1024 + # hidden = 8*model + A = torch.randn(batch * seq, model, device="cuda").half() - 0.4 + B = torch.nn.Parameter(torch.randn(model, hidden, device="cuda").half()) + + # A[0] = 0 + # B[:, 0] = 0 + # A = A*(A>0) + # A[0, 0] = 0 + # A[0, 0] = 6.0 Ac, minA, midpoint, dyna = min_max(A) - #print(Ac[0, 0], 'zero') - #print(Ac, Ac.min(), Ac.max()) - Bc, maxB = F.vectorwise_quant(B, quant_type='linear') + # print(Ac[0, 0], 'zero') + # print(Ac, Ac.min(), Ac.max()) + Bc, maxB = F.vectorwise_quant(B, quant_type="linear") out = F.igemm(Ac, Bc) - out2 = torch.matmul(A,B) - offset = B.sum(0)*torch.round(dyna*(minA+midpoint))/dyna + out2 = torch.matmul(A, B) + offset = B.sum(0) * torch.round(dyna * (minA + midpoint)) / dyna out = out.float() - #print(out.shape, maxB.shape, scale.shape, offset.shape) - norm1 = maxB/127 - C4 = (out/dyna)*norm1+offset - + # print(out.shape, maxB.shape, scale.shape, offset.shape) + norm1 = maxB / 127 + C4 = (out / dyna) * norm1 + offset B1 = torch.nn.Parameter(B.clone()) B2 = torch.nn.Parameter(B.clone()) B3 = torch.nn.Parameter(B.clone()) B4 = torch.nn.Parameter(B.clone()) - C1 = torch.matmul(A, B1) - C2 = bnb.matmul_cublas(A, B2, None, 'linear') - C3 = bnb.matmul_cublas(A, B3, None, 'zeropoint') - C4 = 
bnb.matmul_cublas(A, B4, None, 'vector-zeropoint') + C2 = bnb.matmul_cublas(A, B2, None, "linear") + C3 = bnb.matmul_cublas(A, B3, None, "zeropoint") + C4 = bnb.matmul_cublas(A, B4, None, "vector-zeropoint") - err1 = torch.abs(C1-C2).mean().item() - err2 = torch.abs(C1-C3).mean().item() - err3 = torch.abs(C1-C4).mean().item() + err1 = torch.abs(C1 - C2).mean().item() + err2 = torch.abs(C1 - C3).mean().item() + err3 = torch.abs(C1 - C4).mean().item() print(err1, err2, err3) - #assert err1 > err2 + # assert err1 > err2 loss1 = C1.mean() loss2 = C2.mean() @@ -1765,40 +1917,38 @@ def test_zeropoint(): print(B2.grad) print(B3.grad) print(B4.grad) - err1 = torch.abs(B1.grad-B2.grad).mean().item() - err2 = torch.abs(B1.grad-B3.grad).mean().item() - err3 = torch.abs(B1.grad-B4.grad).mean().item() + err1 = torch.abs(B1.grad - B2.grad).mean().item() + err2 = torch.abs(B1.grad - B3.grad).mean().item() + err3 = torch.abs(B1.grad - B4.grad).mean().item() print(err1, err2, err3) - - def test_zp(): def quant_zp(x): dtype = x.dtype x = x.float() dyna = x.max() - x.min() - if dyna == 0: dyna = 1 - qx = 254./dyna + if dyna == 0: + dyna = 1 + qx = 254.0 / dyna minx = x.min() - #zpx = torch.round(minx* qx) - #zpx = 127 - torch.round(x.max()* qx) - zpx = torch.round(x.min()* qx) - 127 - x = (qx*x) + zpx + # zpx = torch.round(minx* qx) + # zpx = 127 - torch.round(x.max()* qx) + zpx = torch.round(x.min() * qx) - 127 + x = (qx * x) + zpx return x, qx, zpx + batch = 2 seq = 512 model = 1024 - hidden = 4*model - A = torch.randn(batch*seq, model, device='cuda').half()*0.1 - B = torch.randn(model, hidden, device='cuda').half()*0.1 - + hidden = 4 * model + A = torch.randn(batch * seq, model, device="cuda").half() * 0.1 + B = torch.randn(model, hidden, device="cuda").half() * 0.1 C0 = torch.matmul(A, B) - - #A, SA = F.vectorwise_quant(A, quant_type='linear') - #B, SB = F.vectorwise_quant(B, quant_type='linear') + # A, SA = F.vectorwise_quant(A, quant_type='linear') + # B, SB = 
F.vectorwise_quant(B, quant_type='linear') A = A.float() B = B.float() @@ -1806,69 +1956,68 @@ def test_zp(): C3 = bnb.matmul(A.half(), B.t().contiguous().half()) zp = 1 - #C2 = torch.matmul(A-zp, B) - #C2 += B.sum(0).view(1, -1)*zp - C2 = torch.matmul(A, B-zp) - C2 -= A.sum(1).view(-1, 1)*zp + # C2 = torch.matmul(A-zp, B) + # C2 += B.sum(0).view(1, -1)*zp + C2 = torch.matmul(A, B - zp) + C2 -= A.sum(1).view(-1, 1) * zp ca, cqa, cza = quant_zp(A) print(ca.min(), ca.max()) - print((ca-cza).min(), (ca-cza).max()) + print((ca - cza).min(), (ca - cza).max()) zp = 1 scale = 2.0 - C5 = torch.matmul((A*scale)-zp, B) - C5 += B.sum(0)*zp + C5 = torch.matmul((A * scale) - zp, B) + C5 += B.sum(0) * zp C5 /= scale CA, qa, zpa = quant_zp(A) C4 = torch.matmul(CA, B) - C4 -= B.sum(0)*zpa + C4 -= B.sum(0) * zpa C4 /= qa zpb = 1 zpa = 1 qa = 2 qb = 2 - C6 = torch.matmul((A*qa)+zpa, (B*qb)+zpb) - C6 -= (qb*B.sum(0).view(1, -1)*zpa) + (qa*A.sum(1).view(-1, 1)*zpb) - C6 -= zpa*zpb*A.shape[1] - C6 /= qa*qb + C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb) + C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb) + C6 -= zpa * zpb * A.shape[1] + C6 /= qa * qb CA, qa, zpa = quant_zp(A) CB, qb, zpb = quant_zp(B) C7 = torch.matmul(CA, CB) - C7 -= (qb*B.sum(0).view(1, -1)*zpa) + (qa*A.sum(1).view(-1, 1)*zpb) - C7 -= zpa*zpb*A.shape[1] - C7 /= qa*qb + C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb) + C7 -= zpa * zpb * A.shape[1] + C7 /= qa * qb - print('') - #print(C0.flatten()[:10]) + print("") + # print(C0.flatten()[:10]) print(C1.flatten()[:10]) print(C2.flatten()[:10]) print(C3.flatten()[:10]) print(C5.flatten()[:10]) print(C6.flatten()[:10]) print(C7.flatten()[:10]) - err1 = torch.abs(C1-C2).mean().item() - err2 = torch.abs(C1-C3).mean().item() - err3 = torch.abs(C1-C4).mean().item() - err4 = torch.abs(C1-C5).mean().item() - err5 = torch.abs(C1-C6).mean().item() - err6 = torch.abs(C1-C7).mean().item() + err1 = torch.abs(C1 - 
C2).mean().item() + err2 = torch.abs(C1 - C3).mean().item() + err3 = torch.abs(C1 - C4).mean().item() + err4 = torch.abs(C1 - C5).mean().item() + err5 = torch.abs(C1 - C6).mean().item() + err6 = torch.abs(C1 - C7).mean().item() print(err1, err2, err3, err4, err5, err6) - def test_extract_outliers(): for i in range(k): - shapeA = (4096, 4096*4) + shapeA = (4096, 4096 * 4) idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda() - #idx = torch.Tensor([0]).int().cuda() - A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) + # idx = torch.Tensor([0]).int().cuda() + A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) outliers1 = A[:, idx.long()] - CA, SA = F.transform(A, 'col_turing') + CA, SA = F.transform(A, "col_turing") outliers2 = F.extract_outliers(CA, SA, idx) @@ -1877,7 +2026,7 @@ def test_extract_outliers(): torch.testing.assert_allclose(outliers1, outliers2) - CA, SA = F.transform(A, 'col_ampere') + CA, SA = F.transform(A, "col_ampere") outliers2 = F.extract_outliers(CA, SA, idx) -- cgit v1.2.3