summaryrefslogtreecommitdiff
path: root/tests/test_functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_functional.py')
-rw-r--r--tests/test_functional.py1574
1 files changed, 906 insertions, 668 deletions
diff --git a/tests/test_functional.py b/tests/test_functional.py
index bfc3e28..ab7d672 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1,25 +1,29 @@
-import pytest
import math
import random
import time
-import torch
-import bitsandbytes as bnb
-import einops
-
from itertools import product
+import einops
+import pytest
+import torch
+
+import bitsandbytes as bnb
from bitsandbytes import functional as F
-torch.set_printoptions(precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
+torch.set_printoptions(
+ precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
+)
k = 20
+
def assert_all_approx_close(a, b, rtol, atol, count):
idx = torch.isclose(a, b, rtol, atol)
- sumval = (idx==0).sum().item()
+ sumval = (idx == 0).sum().item()
if sumval > count:
- print(f'Too many values not close: assert {sumval} < {count}')
+ print(f"Too many values not close: assert {sumval} < {count}")
torch.testing.assert_allclose(a, b, rtol, atol)
+
class FFN(torch.nn.Module):
def __init__(self, input_features, hidden_size, bias=True):
super(FFN, self).__init__()
@@ -35,13 +39,14 @@ class FFN(torch.nn.Module):
x = self.fc2(x)
return x
+
class Timer(object):
def __init__(self):
self.starts = {}
self.ends = {}
self.agg = {}
- def tick(self, name='default'):
+ def tick(self, name="default"):
if name not in self.starts:
self.starts[name] = torch.cuda.Event(enable_timing=True)
self.ends[name] = torch.cuda.Event(enable_timing=True)
@@ -49,66 +54,72 @@ class Timer(object):
else:
ms = self.tock(name, evict=True, print_ms=False)
- def tock(self, name='default', evict=True, print_ms=True):
+ def tock(self, name="default", evict=True, print_ms=True):
if name in self.ends:
self.ends[name].record()
torch.cuda.synchronize()
ms = self.starts[name].elapsed_time(self.ends[name])
- if name not in self.agg: self.agg[name] = 0.0
+ if name not in self.agg:
+ self.agg[name] = 0.0
self.agg[name] += ms
if evict:
self.starts.pop(name)
self.ends.pop(name)
if print_ms and name in self.agg:
- print('{0} took: {1:.5f}s'.format(name, self.agg[name]/1000.0))
+ print("{0} took: {1:.5f}s".format(name, self.agg[name] / 1000.0))
return self.agg[name]
def reset(self):
- self.starts = {}
+ self.starts = {}
self.ends = {}
self.agg = {}
- print('Resetting benchmark data')
+ print("Resetting benchmark data")
+
def setup():
pass
+
def teardown():
pass
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['float', 'half'])
+
+@pytest.mark.parametrize(
+ "dtype", [torch.float32, torch.float16], ids=["float", "half"]
+)
def test_estimate_quantiles(dtype):
- A = torch.rand(1024, 1024, device='cuda')
+ A = torch.rand(1024, 1024, device="cuda")
A = A.to(dtype)
code = F.estimate_quantiles(A)
- percs = torch.linspace(1/512, 511/512, 256, device=A.device)
+ percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2)
- A = torch.randn(1024, 1024, device='cuda')
+ A = torch.randn(1024, 1024, device="cuda")
A = A.to(dtype)
code = F.estimate_quantiles(A)
quantiles = torch.quantile(A.float(), percs)
- diff = torch.abs(code-quantiles)
+ diff = torch.abs(code - quantiles)
assert (diff > 5e-02).sum().item() == 0
def test_quantile_quantization():
for i in range(100):
- A1 = torch.randn(1024, 1024, device='cuda')
+ A1 = torch.randn(1024, 1024, device="cuda")
code = F.estimate_quantiles(A1)
C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code)
- diff = torch.abs(A1-A2).mean().item()
+ diff = torch.abs(A1 - A2).mean().item()
assert diff < 0.0075
- A1 = torch.rand(1024, 1024, device='cuda')
+ A1 = torch.rand(1024, 1024, device="cuda")
code = F.estimate_quantiles(A1)
C = F.quantize_no_absmax(A1, code)
A2 = F.dequantize_no_absmax(C, code)
- diff = torch.abs(A1-A2).mean().item()
+ diff = torch.abs(A1 - A2).mean().item()
torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0)
assert diff < 0.001
@@ -117,22 +128,22 @@ def test_dynamic_quantization():
diffs = []
reldiffs = []
for i in range(100):
- A1 = torch.randn(1024, 1024, device='cuda')
+ A1 = torch.randn(1024, 1024, device="cuda")
C, S = F.quantize(A1)
A2 = F.dequantize(C, S)
- diff = torch.abs(A1-A2)
- reldiff = diff/torch.abs(A1+1e-8)
+ diff = torch.abs(A1 - A2)
+ reldiff = diff / torch.abs(A1 + 1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
assert diff.mean().item() < 0.0135
- #print(sum(diffs)/len(diffs))
- #print(sum(reldiffs)/len(reldiffs))
+ # print(sum(diffs)/len(diffs))
+ # print(sum(reldiffs)/len(reldiffs))
for i in range(100):
- A1 = torch.rand(1024, 1024, device='cuda')
+ A1 = torch.rand(1024, 1024, device="cuda")
C, S = F.quantize(A1)
A2 = F.dequantize(C, S)
- diff = torch.abs(A1-A2).mean().item()
+ diff = torch.abs(A1 - A2).mean().item()
torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
assert diff < 0.004
@@ -141,56 +152,62 @@ def test_dynamic_blockwise_quantization():
diffs = []
reldiffs = []
for i in range(100):
- A1 = torch.randn(1024, 1024, device='cuda')
+ A1 = torch.randn(1024, 1024, device="cuda")
C, S = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, S)
- diff = torch.abs(A1-A2)
- reldiff = diff/torch.abs(A1+1e-8)
+ diff = torch.abs(A1 - A2)
+ reldiff = diff / torch.abs(A1 + 1e-8)
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
assert diffs[-1] < 0.011
- #print(sum(diffs)/len(diffs))
- #print(sum(reldiffs)/len(reldiffs))
+ # print(sum(diffs)/len(diffs))
+ # print(sum(reldiffs)/len(reldiffs))
diffs = []
for i in range(100):
- A1 = torch.rand(1024, 1024, device='cuda')
+ A1 = torch.rand(1024, 1024, device="cuda")
C, S = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, S)
- diff = torch.abs(A1-A2).mean().item()
+ diff = torch.abs(A1 - A2).mean().item()
assert diff < 0.0033
diffs.append(diff)
torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
- #print(sum(diffs)/len(diffs))
+ # print(sum(diffs)/len(diffs))
+
def test_dynamic_blockwise_stochastic_quantization():
diffs = []
reldiffs = []
rand = torch.rand(1024).cuda()
for i in range(100):
- A1 = torch.randn(1024, 1024, device='cuda')
+ A1 = torch.randn(1024, 1024, device="cuda")
C1, S1 = F.quantize_blockwise(A1, rand=rand)
C2, S2 = F.quantize_blockwise(A1)
# a maximunm distance of quantized values of 1
torch.testing.assert_allclose(C1, C2, atol=1, rtol=0)
- fraction_smaller = (C1<C2).float().sum()/C1.numel()
- fraction_larger = (C1>C2).float().sum()/C1.numel()
- torch.testing.assert_allclose(fraction_larger, fraction_smaller, atol=0.01, rtol=0)
+ fraction_smaller = (C1 < C2).float().sum() / C1.numel()
+ fraction_larger = (C1 > C2).float().sum() / C1.numel()
+ torch.testing.assert_allclose(
+ fraction_larger, fraction_smaller, atol=0.01, rtol=0
+ )
-
-@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=['float', 'half'])
+@pytest.mark.parametrize(
+ "gtype", [torch.float32, torch.float16], ids=["float", "half"]
+)
def test_percentile_clipping(gtype):
- gnorm_vec1 = torch.zeros(100, device='cuda')
- gnorm_vec2 = torch.zeros(100, device='cuda')
+ gnorm_vec1 = torch.zeros(100, device="cuda")
+ gnorm_vec2 = torch.zeros(100, device="cuda")
n = 4
step = 0
- percentile=5
+ percentile = 5
for i in range(k):
step += 1
- g = torch.randn(n, n, dtype=gtype, device='cuda')
- gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile)
- assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2/gnorm1
+ g = torch.randn(n, n, dtype=gtype, device="cuda")
+ gnorm1, clip2, gnorm_scale = F.percentile_clipping(
+ g, gnorm_vec2, step, percentile=percentile
+ )
+ assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1
gnorm2 = torch.norm(g.float())
if step == 1:
@@ -208,74 +225,98 @@ def test_percentile_clipping(gtype):
def quant(x):
max1 = torch.abs(x).max()
- x = torch.round(x/max1*127)
+ x = torch.round(x / max1 * 127)
return max1, x.to(torch.int8)
+
def dequant(c, maxC):
- return c.float()*(maxC/127)
+ return c.float() * (maxC / 127)
+
def mm_dequant(maxA, maxB, C):
- return C.float()*(maxA/127)*(maxB/127)
+ return C.float() * (maxA / 127) * (maxB / 127)
+
def quant_multi(x, dim):
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
- max1[max1==0] = 1.0
- x = torch.round(x/max1*127)
+ max1[max1 == 0] = 1.0
+ x = torch.round(x / max1 * 127)
return max1, x.to(torch.int8)
+
def quant_multi_chunk(x, dim, chunk_size=32):
- if dim==1:
- x_chunked = einops.rearrange(x, '(c a) b -> c a b', c=chunk_size)
- max1 = torch.amax(torch.abs(x_chunked), dim=dim+1, keepdim=True)
+ if dim == 1:
+ x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size)
+ max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
max1 = torch.tile(max1, (1, 1, x.shape[1]))
max1 = max1.view(x.shape)
- elif dim==0:
- x_chunked = einops.rearrange(x, 'a (b c) -> a b c', c=chunk_size)
+ elif dim == 0:
+ x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size)
max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
max1 = torch.tile(max1, (x.shape[0], 1, 1))
max1 = max1.view(x.shape)
- max1[max1==0] = 1.0
- x = torch.round(x/max1*127)
+ max1[max1 == 0] = 1.0
+ x = torch.round(x / max1 * 127)
return max1, x.to(torch.int8)
+
def quant_minmax(A):
minA = A.min()
maxA = A.max()
-def mean(xx):
- return sum(xx)/float(len(xx))
-#dim1 = torch.randint(1,1024*4, size=(4,)).tolist()
-#dim2 = torch.randint(1,1024*4, size=(4,)).tolist()
-dim1 = [1024*2]
-dim2 = [1024*16]
-methods = [(lambda x, dim: quant(x), lambda x, dim: quant(x), dequant, dequant, mm_dequant)]
+def mean(xx):
+ return sum(xx) / float(len(xx))
+
+
+# dim1 = torch.randint(1,1024*4, size=(4,)).tolist()
+# dim2 = torch.randint(1,1024*4, size=(4,)).tolist()
+dim1 = [1024 * 2]
+dim2 = [1024 * 16]
+methods = [
+ (
+ lambda x, dim: quant(x),
+ lambda x, dim: quant(x),
+ dequant,
+ dequant,
+ mm_dequant,
+ )
+]
methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant))
-#methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
-method_names = ['linear', 'vectorwise']
+# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
+method_names = ["linear", "vectorwise"]
batched = [False, True]
-values = list(product(dim1,dim2, methods, batched))
-values_names = list(product(dim1,dim2, method_names, batched))
-names = ['dim1_{0}_dim2_{1}_quant_{2}_batched_{3}'.format(*vals) for vals in values_names]
-@pytest.mark.parametrize("dim1, dim2, quant_methods, batched", values, ids=names)
+values = list(product(dim1, dim2, methods, batched))
+values_names = list(product(dim1, dim2, method_names, batched))
+names = [
+ "dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals)
+ for vals in values_names
+]
+
+
+@pytest.mark.parametrize(
+ "dim1, dim2, quant_methods, batched", values, ids=names
+)
def test_approx_igemm(dim1, dim2, quant_methods, batched):
dim1 = dim1 - (dim1 % 32)
dim2 = dim2 - (dim2 % 32)
errors = []
relerrors = []
- print('')
+ print("")
for i in range(5):
if batched:
- A = torch.normal(0, 0.5, size=(32, dim1, dim2//32), device='cuda')
- B = torch.normal(0, 0.5, size=(32, dim2//32, dim1), device='cuda')
+ A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
+ B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
maxA, Ac = quant_methods[0](A, 2)
maxB, Bc = quant_methods[1](B, 1)
else:
- A = torch.normal(0, 0.5, size=(dim1, dim2), device='cuda')
- B = torch.normal(0, 0.5, size=(dim2, dim1), device='cuda')
+ A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
+ B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
maxA, Ac = quant_methods[0](A, 1)
maxB, Bc = quant_methods[1](B, 0)
- torch.testing.assert_allclose(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05)
+ torch.testing.assert_allclose(
+ quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
+ )
if batched:
out2 = torch.bmm(A, B)
C = torch.bmm(Ac.float(), Bc.float())
@@ -284,43 +325,53 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
C = F.igemm(Ac, Bc)
out = quant_methods[4](maxA, maxB, C)
std = out2.std()
- out/= std
- out2/= std
- err = torch.abs(out-out2)
- relerr = err/torch.abs(out2)
+ out /= std
+ out2 /= std
+ err = torch.abs(out - out2)
+ relerr = err / torch.abs(out2)
errors.append(err.mean().item())
relerrors.append(relerr.mean().item())
print(mean(errors))
print(mean(relerrors))
-
-
-
-
def test_stable_embedding():
layer = bnb.nn.StableEmbedding(1024, 1024)
layer.reset_parameters()
-
n = 2
-hidden_dim = torch.randint(32,256, size=(n,)).tolist()
-batch_dim = torch.randint(16,256, size=(n,)).tolist()
-seq_dim = torch.randint(16,256, size=(n,)).tolist()
+hidden_dim = torch.randint(32, 256, size=(n,)).tolist()
+batch_dim = torch.randint(16, 256, size=(n,)).tolist()
+seq_dim = torch.randint(16, 256, size=(n,)).tolist()
transpose = [(False, False), (False, True), (True, False), (True, True)]
-values = list(product(hidden_dim,batch_dim, transpose, seq_dim))
-names = ['hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}'.format(*vals) for vals in values]
-@pytest.mark.parametrize("hidden_dim, batch_dim, transpose, seq_dim", values, ids=names)
+values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
+names = [
+ "hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}".format(*vals)
+ for vals in values
+]
+
+
+@pytest.mark.parametrize(
+ "hidden_dim, batch_dim, transpose, seq_dim", values, ids=names
+)
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
hidden_dim = hidden_dim - (hidden_dim % 32)
batch_dim = batch_dim - (batch_dim % 16)
seq_dim = seq_dim - (seq_dim % 16)
for i in range(k):
- shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
- shapeB = ((32*random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32*random.randint(1, 4)))
- A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
- B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8)
+ shapeA = (
+ (batch_dim, hidden_dim)
+ if not transpose[0]
+ else (hidden_dim, batch_dim)
+ )
+ shapeB = (
+ (32 * random.randint(1, 4), hidden_dim)
+ if transpose[1]
+ else (hidden_dim, 32 * random.randint(1, 4))
+ )
+ A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
+ B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]:
out2 = torch.matmul(A.float(), B.float())
out = F.igemm(A, B)
@@ -338,9 +389,13 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
for i in range(k):
shapeA = (batch_dim, seq_dim, hidden_dim)
- shapeB = ((32*random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32*random.randint(1, 4)))
- A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
- B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8)
+ shapeB = (
+ (32 * random.randint(1, 4), hidden_dim)
+ if transpose[1]
+ else (hidden_dim, 32 * random.randint(1, 4))
+ )
+ A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
+ B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]:
out2 = torch.matmul(A.float(), B.float())
out = F.igemm(A, B)
@@ -352,40 +407,57 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
n = 3
-seq_dim = torch.randint(32,512, size=(n,)).tolist()
-hidden_dim = torch.randint(32,1024*4, size=(n,)).tolist()
-batch_dim = torch.randint(2,16, size=(n,)).tolist()
-values = list(product(seq_dim,hidden_dim,batch_dim))
-names = ['seq_dim{0}_hidden_dim{1}_batch_dim{2}'.format(*vals) for vals in values]
+seq_dim = torch.randint(32, 512, size=(n,)).tolist()
+hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
+batch_dim = torch.randint(2, 16, size=(n,)).tolist()
+values = list(product(seq_dim, hidden_dim, batch_dim))
+names = [
+ "seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values
+]
+
+
@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
seq_dim = seq_dim - (seq_dim % 32)
hidden_dim = hidden_dim - (hidden_dim % 32)
batch_dim = batch_dim - (batch_dim % 2)
for i in range(25):
- A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device='cuda').to(torch.int8)
- B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device='cuda').to(torch.int8)
- out2 = torch.einsum('bsi, bso->io', A.float(), B.float())
- iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device)
+ A = torch.randint(
+ -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
+ ).to(torch.int8)
+ B = torch.randint(
+ -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda"
+ ).to(torch.int8)
+ out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
+ iout = torch.empty(
+ A.shape[2], B.shape[2], dtype=torch.int32, device=A.device
+ )
out = F.igemm(A, B, out=iout)
torch.testing.assert_allclose(out.float(), out2)
+
n = 2
-seq_dim = torch.randint(32,512, size=(n,)).tolist()
-hidden_dim = torch.randint(32,1024*4, size=(n,)).tolist()
-batch_dim = torch.randint(2,16, size=(n,)).tolist()
+seq_dim = torch.randint(32, 512, size=(n,)).tolist()
+hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
+batch_dim = torch.randint(2, 16, size=(n,)).tolist()
transpose = [False, True]
-values = list(product(seq_dim,hidden_dim,batch_dim, transpose))
-names = ['seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}'.format(*vals) for vals in values]
-@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim, transpose", values, ids=names)
-def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
+values = list(product(seq_dim, hidden_dim, batch_dim, transpose))
+names = [
+ "seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}".format(*vals)
+ for vals in values
+]
+
+@pytest.mark.parametrize(
+ "seq_dim, hidden_dim, batch_dim, transpose", values, ids=names
+)
+def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
def min_max(x):
maxA = torch.amax(x, dim=2, keepdim=True)
minA = torch.amin(x, dim=2, keepdim=True)
- scale = (maxA-minA)/2.0
- return (127*(x-minA-scale)/scale).to(torch.int8), minA, scale
+ scale = (maxA - minA) / 2.0
+ return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale
seq_dim = seq_dim - (seq_dim % 16)
hidden_dim = hidden_dim - (hidden_dim % 16)
@@ -395,30 +467,32 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
errs2 = []
relerrs2 = []
for i in range(k):
- A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device='cuda')
+ A = torch.normal(
+ 0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
+ )
if transpose:
- B = torch.normal(0, 0.5, size=(256, hidden_dim), device='cuda')
+ B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
else:
- B = torch.normal(0, 0.5, size=(hidden_dim, 256), device='cuda')
+ B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
Ac, minA, scale = min_max(A)
if transpose:
maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
out = F.igemm(Ac, Bc.t())
- out2 = torch.matmul(A,B.t())
- offset = B.t().sum(0)*(minA+scale)
+ out2 = torch.matmul(A, B.t())
+ offset = B.t().sum(0) * (minA + scale)
out = out.float()
- out = (out*maxB.t()*scale/(127*127))+offset
+ out = (out * maxB.t() * scale / (127 * 127)) + offset
maxA, Ac = quant_multi(A, dim=2)
out3 = F.igemm(Ac, Bc.t())
out3 = mm_dequant(maxA, maxB.t(), out3)
else:
maxB, Bc = quant_multi(B, dim=0)
- offset = B.sum(0)*(minA+scale)
+ offset = B.sum(0) * (minA + scale)
out = F.igemm(Ac, Bc)
- out2 = torch.matmul(A,B)
+ out2 = torch.matmul(A, B)
out = out.float()
- out = (out*maxB*scale/(127*127))+offset
+ out = (out * maxB * scale / (127 * 127)) + offset
maxA, Ac = quant_multi(A, dim=2)
out3 = F.igemm(Ac, Bc)
@@ -429,31 +503,37 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
out /= std
out3 /= std
- err = torch.abs(out-out2)
- relerr = err/(torch.abs(out2)+1e-7)
+ err = torch.abs(out - out2)
+ relerr = err / (torch.abs(out2) + 1e-7)
- err2 = torch.abs(out3-out2)
- relerr2 = err2/(torch.abs(out2)+1e-7)
+ err2 = torch.abs(out3 - out2)
+ relerr2 = err2 / (torch.abs(out2) + 1e-7)
errs.append(err.mean().item())
relerrs.append(relerr.mean().item())
errs2.append(err2.mean().item())
relerrs2.append(relerr2.mean().item())
- #print(mean(errs))
- #print(mean(relerrs))
- #print(mean(errs2))
- #print(mean(relerrs2))
+ # print(mean(errs))
+ # print(mean(relerrs))
+ # print(mean(errs2))
+ # print(mean(relerrs2))
assert mean(errs) < 0.015
assert mean(relerrs) < 0.3
+
n = 2
-dim1 = torch.randint(1,64, size=(n,)).tolist()
-dim2 = torch.randint(32,128, size=(n,)).tolist()
-dim3 = torch.randint(32,256, size=(n,)).tolist()
-dim4 = torch.randint(32,256, size=(n,)).tolist()
+dim1 = torch.randint(1, 64, size=(n,)).tolist()
+dim2 = torch.randint(32, 128, size=(n,)).tolist()
+dim3 = torch.randint(32, 256, size=(n,)).tolist()
+dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)]
-values = list(product(dim1,dim2,dim3,dim4,transpose))
-names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}'.format(*vals) for vals in values]
+values = list(product(dim1, dim2, dim3, dim4, transpose))
+names = [
+ "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals)
+ for vals in values
+]
+
+
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names)
def test_ibmm(dim1, dim2, dim3, dim4, transpose):
dim2 = dim2 - (dim2 % 16)
@@ -462,8 +542,8 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
for i in range(k):
shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
- A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
- B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8)
+ A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
+ B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]:
out2 = torch.bmm(A.float(), B.float())
@@ -475,150 +555,203 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
out = F.igemm(A.permute([0, 2, 1]), B)
elif transpose[0] and transpose[1]:
- out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float())
+ out2 = torch.bmm(
+ A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
+ )
out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
torch.testing.assert_allclose(out.float(), out2.float())
+
n = 1
-dim1 = torch.randint(1,64, size=(n,)).tolist()
-dim2 = torch.randint(32,128, size=(n,)).tolist()
-dim3 = torch.randint(32,256, size=(n,)).tolist()
-values = list(product(dim1,dim2,dim3))
-names = ['dim1_{0}_dim2_{1}_dim3_{2}'.format(*vals) for vals in values]
+dim1 = torch.randint(1, 64, size=(n,)).tolist()
+dim2 = torch.randint(32, 128, size=(n,)).tolist()
+dim3 = torch.randint(32, 256, size=(n,)).tolist()
+values = list(product(dim1, dim2, dim3))
+names = ["dim1_{0}_dim2_{1}_dim3_{2}".format(*vals) for vals in values]
+
+
@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
def test_vector_quant(dim1, dim2, dim3):
dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
for i in range(k):
- A = torch.randn(size=(dim2, dim3), device='cuda')
+ A = torch.randn(size=(dim2, dim3), device="cuda")
qA, SA = F.vectorwise_quant(A, dim=0)
A1 = F.vectorwise_dequant(qA, SA)
torch.testing.assert_allclose(A1, A, atol=0.01, rtol=0.1)
-
n = 2
-dim1 = torch.randint(2,256, size=(n,)).tolist()
-dim2 = torch.randint(2,256, size=(n,)).tolist()
-dim3 = torch.randint(2,256, size=(n,)).tolist()
-#dim1, dim2 = (256,), (256,)
+dim1 = torch.randint(2, 256, size=(n,)).tolist()
+dim2 = torch.randint(2, 256, size=(n,)).tolist()
+dim3 = torch.randint(2, 256, size=(n,)).tolist()
+# dim1, dim2 = (256,), (256,)
dtype = [torch.int8, torch.int32]
-a_order = ['row']
-out_order = ['col', 'row', 'col32']
+a_order = ["row"]
+out_order = ["col", "row", "col32"]
transpose = [False]
dims = [2, 3]
-values = list(product(dim1,dim2,dim3, dims,dtype, a_order, out_order, transpose))
-
-names = ['dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}'.format(*vals) for vals in values]
-@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names)
-def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
- if dims == 3 and out_order != 'col32': return
- if dtype == torch.int32 and out_order != 'col32': return
+values = list(
+ product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
+)
+
+names = [
+ "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format(
+ *vals
+ )
+ for vals in values
+]
+
+
+@pytest.mark.parametrize(
+ "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
+ values,
+ ids=names,
+)
+def test_nvidia_transform(
+ dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose
+):
+ if dims == 3 and out_order != "col32":
+ return
+ if dtype == torch.int32 and out_order != "col32":
+ return
func = F.get_transform_func(dtype, orderA, orderOut, transpose)
if dims == 2:
- A = torch.randint(-128, 127, size=(dim1, dim2), device='cuda').to(dtype)
+ A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
elif dims == 3:
- A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(dtype)
+ A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
+ dtype
+ )
out, S = F.nvidia_transform(A, to_order=orderOut)
- if orderOut == 'row':
+ if orderOut == "row":
torch.testing.assert_allclose(A.flatten(), out.flatten())
- elif orderOut == 'col':
+ elif orderOut == "col":
torch.testing.assert_allclose(A.t().flatten(), out.flatten())
- elif orderOut == 'col32':
+ elif orderOut == "col32":
if dims == 2:
- n = A.shape[0]*(A.shape[1] + (32 - (A.shape[1]%32)))
+ n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
elif dims == 3:
- n = A.shape[0]*A.shape[1]*(A.shape[2] + (32 - (A.shape[2]%32)))
+ n = (
+ A.shape[0]
+ * A.shape[1]
+ * (A.shape[2] + (32 - (A.shape[2] % 32)))
+ )
assert out.numel() == n
- elif orderOut == 'col_turing':
+ elif orderOut == "col_turing":
# 32 col 8 row tiles
- n = (A.shape[0]+(8- A.shape[0]%8))*(A.shape[1] + (32 - (A.shape[1]%32)))
+ n = (A.shape[0] + (8 - A.shape[0] % 8)) * (
+ A.shape[1] + (32 - (A.shape[1] % 32))
+ )
assert out.numel() == n
total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
for row in range(A.shape[0]):
for col in range(A.shape[1]):
- i = row*A.shape[1]
+ i = row * A.shape[1]
j = col
coltile = (col // 32) + (1 if col % 32 != 0 else 0)
- rowtile = ((row // 8) + (1 if row % 8 != 0 else 0))*total_coltile
- offset = 32*8*(rowtile+coltile)
+ rowtile = (
+ (row // 8) + (1 if row % 8 != 0 else 0)
+ ) * total_coltile
+ offset = 32 * 8 * (rowtile + coltile)
col2 = col % 32
- row2 = (row%8)*32
+ row2 = (row % 8) * 32
+ assert A.flatten()[i + j] == A[row, col]
+ # assert A.flatten()[i+j] == out.flatten()[row2+col2]
+ # torch.testing.assert_allclose(A.flatten()[i+j], A[row, col])
+ # torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
- assert A.flatten()[i+j] == A[row, col]
- #assert A.flatten()[i+j] == out.flatten()[row2+col2]
- #torch.testing.assert_allclose(A.flatten()[i+j], A[row, col])
- #torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
-
- if orderOut == 'col32':
- out2, S = F.nvidia_transform(out, from_order=orderOut, to_order='row', state=S)
+ if orderOut == "col32":
+ out2, S = F.nvidia_transform(
+ out, from_order=orderOut, to_order="row", state=S
+ )
torch.testing.assert_allclose(A, out2)
n = 1
-dim1 = torch.randint(1,256, size=(n,)).tolist()
-dim2 = torch.randint(32,512, size=(n,)).tolist()
-dim3 = torch.randint(32,1024, size=(n,)).tolist()
-dim4 = torch.randint(32,1024, size=(n,)).tolist()
+dim1 = torch.randint(1, 256, size=(n,)).tolist()
+dim2 = torch.randint(32, 512, size=(n,)).tolist()
+dim3 = torch.randint(32, 1024, size=(n,)).tolist()
+dim4 = torch.randint(32, 1024, size=(n,)).tolist()
-#dim1 = [2]
-#dim2 = [2]
-#dim3 = [2]
-#dim4 = [2]
+# dim1 = [2]
+# dim2 = [2]
+# dim3 = [2]
+# dim4 = [2]
-dims = (2,3)
+dims = (2, 3)
ldb = [0]
-#ldb = list(range(256, 1*1024, 256))
-values = list(product(dim1,dim2,dim3,dim4,dims, ldb))
-names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}'.format(*vals) for vals in values]
+# ldb = list(range(256, 1*1024, 256))
+values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
+names = [
+ "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}".format(*vals)
+ for vals in values
+]
+
+
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names)
def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
for i in range(k):
if dims == 2:
- A = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
+ A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(
+ torch.int8
+ )
elif dims == 3:
- A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
- B = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
+ A = torch.randint(
+ -128, 127, size=(dim1, dim2, dim3), device="cuda"
+ ).to(torch.int8)
+ B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
+ torch.int8
+ )
C1 = torch.matmul(A.float(), B.t().float())
- A2, SA = F.transform(A, 'col32')
- B2, SB = F.transform(B, 'col_turing')
+ A2, SA = F.transform(A, "col32")
+ B2, SB = F.transform(B, "col_turing")
C2, SC = F.igemmlt(A2, B2, SA, SB)
- C3, S = F.nvidia_transform(C2, 'row', state=SC)
+ C3, S = F.nvidia_transform(C2, "row", state=SC)
torch.testing.assert_allclose(C1, C3.float())
# transpose
- B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
+ B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
+ torch.int8
+ )
C1 = torch.matmul(A.float(), B.float())
- B2t, SBt = F.transform(B, 'col_turing', transpose=True)
+ B2t, SBt = F.transform(B, "col_turing", transpose=True)
C2, SC = F.igemmlt(A2, B2t, SA, SBt)
- C3, S = F.nvidia_transform(C2, 'row', state=SC)
+ C3, S = F.nvidia_transform(C2, "row", state=SC)
torch.testing.assert_allclose(C1, C3.float())
+
dim1 = [32]
dim2 = [32]
dim3 = [32]
dim4 = [32]
dims = (2,)
-#ldb = list(range(256, 1*1024, 256))
-values = list(product(dim1,dim2,dim3,dim4,dims))
-names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}'.format(*vals) for vals in values]
+# ldb = list(range(256, 1*1024, 256))
+values = list(product(dim1, dim2, dim3, dim4, dims))
+names = [
+ "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals)
+ for vals in values
+]
+
+
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names)
def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
formatB = F.get_special_format_str()
for i in range(k):
if dims == 2:
- A = torch.normal(0, 0.5, size=(dim1, dim3), device='cuda').half()
+ A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
elif dims == 3:
- A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device='cuda').half()
- B = torch.randn((dim4, dim3), device='cuda').half()
+ A = torch.normal(
+ 0, 0.5, size=(dim1, dim2, dim3), device="cuda"
+ ).half()
+ B = torch.randn((dim4, dim3), device="cuda").half()
torch.nn.init.xavier_uniform_(B)
C1 = torch.matmul(A, B.t())
C2 = bnb.matmul(A, B.t())
@@ -627,50 +760,58 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
- C32A, SA = F.transform(CA, 'col32')
+ C32A, SA = F.transform(CA, "col32")
CxB, SB = F.transform(CB, to_order=formatB)
out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB)
output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt)
- #print('')
- #print(output.flatten()[:10])
- #print(C1.flatten()[:10])
- #print(C2.flatten()[:10])
+ # print('')
+ # print(output.flatten()[:10])
+ # print(C1.flatten()[:10])
+ # print(C2.flatten()[:10])
-
- #torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
+ # torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
# transpose
- #B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
- #C1 = torch.matmul(A.float(), B.float())
+ # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
+ # C1 = torch.matmul(A.float(), B.float())
+
+ # B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
+ # C2, SC = F.igemmlt(A2, B2t, SA, SBt)
+ # C3, S = F.transform(C2, 'row', state=SC)
+ # torch.testing.assert_allclose(C1, C3.float())
- #B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
- #C2, SC = F.igemmlt(A2, B2t, SA, SBt)
- #C3, S = F.transform(C2, 'row', state=SC)
- #torch.testing.assert_allclose(C1, C3.float())
batch_size = 2
seqdim = 512
-#values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
-values = [(batch_size, seqdim, 4*1024, 3*4*1024),(batch_size, seqdim, 5120, 3*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
+# values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
+values = [
+ (batch_size, seqdim, 4 * 1024, 3 * 4 * 1024),
+ (batch_size, seqdim, 5120, 3 * 5120),
+ (batch_size, seqdim, 12 * 1024, 4 * 12 * 1024),
+]
+
+
+# values = list(product(batch, seq, model, hidden))
+names = [
+ "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
+]
-#values = list(product(batch, seq, model, hidden))
-names = ['batch_{0}_seq_{1}_model_{2}_hidden_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_8bit_training(batch, seq, model, hidden):
formatB = F.get_special_format_str()
- A = torch.randn(batch, seq, model, device='cuda').half()
- grad = torch.randn(batch, seq, model, device='cuda').half()
- w1 = torch.randint(-128, 127, size=(hidden, model), device='cuda').half()
- w2 = torch.randint(-128, 127, size=(model, hidden), device='cuda').half()
- print('')
+ A = torch.randn(batch, seq, model, device="cuda").half()
+ grad = torch.randn(batch, seq, model, device="cuda").half()
+ w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half()
+ w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half()
+ print("")
- #torch.cuda.synchronize()
+ # torch.cuda.synchronize()
## warmup
- #for i in range(100):
+ # for i in range(100):
# torch.matmul(A, w1.t())
- #torch.cuda.synchronize()
+ # torch.cuda.synchronize()
dtype = torch.int8
A = A.view(-1, A.shape[-1]).contiguous()
@@ -679,77 +820,77 @@ def test_bench_8bit_training(batch, seq, model, hidden):
t0 = time.time()
for i in range(k):
- out1 = torch.matmul(A, w1.t()) # fc1
- #out2 = torch.matmul(out1, w2.t())# fc2
+ out1 = torch.matmul(A, w1.t()) # fc1
+ # out2 = torch.matmul(out1, w2.t())# fc2
- #d1 = torch.matmul(grad, w2) # delta1
- #d2 = torch.matmul(d1, w1) # delta2
+ # d1 = torch.matmul(grad, w2) # delta1
+ # d2 = torch.matmul(d1, w1) # delta2
- #grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
- #grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1
+ # grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
+ # grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1
torch.cuda.synchronize()
t16 = time.time() - t0
print(t16)
- #torch.cuda.empty_cache()
+ # torch.cuda.empty_cache()
- #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
- #Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
+ # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
+ # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
- #CTw1, Sw1 = F.transform2(Cw1, formatB)
- #CTw2, Sw2 = F.transform2(Cw2, formatB)
- #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
- #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
+ # CTw1, Sw1 = F.transform2(Cw1, formatB)
+ # CTw2, Sw2 = F.transform2(Cw2, formatB)
+ # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
+ # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
- #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
- #C32A, SA = F.transform2(CA, 'col32')
+ # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
+ # C32A, SA = F.transform2(CA, 'col32')
## fc1
- #out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
+ # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)
## fc2
- #Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
- #C32out1, Sout1 = F.transform2(Cout1, 'col32')
- #out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
+ # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
+ # C32out1, Sout1 = F.transform2(Cout1, 'col32')
+ # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)
## delta1
- #Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
- #C32grad, Sgrad = F.transform2(Cgrad, 'col32')
+ # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
+ # C32grad, Sgrad = F.transform2(Cgrad, 'col32')
##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)
## delta2
- #Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
- #C32d1, Sd1 = F.transform2(Cd1, 'col32')
+ # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
+ # C32d1, Sd1 = F.transform2(Cd1, 'col32')
##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)
## grad1
- #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
- #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
+ # C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
+ # CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)
## grad2
- #C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
- #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
+ # C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
+ # CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)
- #Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
+ # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
- #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
- #Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
+ # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
+ # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
- #CTw1, Sw1 = F.transform2(Cw1, formatB)
- #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
- #CTw2, Sw2 = F.transform2(Cw2, formatB)
- #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
- #torch.cuda.synchronize()
- #t0 = time.time()
- #for i in range(k):
+ # CTw1, Sw1 = F.transform2(Cw1, formatB)
+ # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
+ # CTw2, Sw2 = F.transform2(Cw2, formatB)
+ # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
+ # torch.cuda.synchronize()
+ # t0 = time.time()
+ # for i in range(k):
# #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# #CTw1, Sw1 = F.transform2(Cw1, formatB)
# #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
@@ -802,74 +943,78 @@ def test_bench_8bit_training(batch, seq, model, hidden):
# #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
# #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)
- #torch.cuda.synchronize()
- #t8 = time.time() - t0
- #print(t8)
-
-
-
+ # torch.cuda.synchronize()
+ # t8 = time.time() - t0
+ # print(t8)
n = 2
-dim1 = torch.randint(64,256, size=(n,)).tolist()
-dim4 = torch.randint(64,1024, size=(n,)).tolist()
+dim1 = torch.randint(64, 256, size=(n,)).tolist()
+dim4 = torch.randint(64, 1024, size=(n,)).tolist()
-#dim1 = [2*1024]
-#dim4 = [2*1024]
+# dim1 = [2*1024]
+# dim4 = [2*1024]
-#dim1 = [4]
-#dim4 = [4]
+# dim1 = [4]
+# dim4 = [4]
dims = (2,)
-#ldb = list(range(256, 1*1024, 256))
-formatB = ['col_turing', 'col_ampere']
-values = list(product(dim1,dim4,dims, formatB))
-names = ['dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}'.format(*vals) for vals in values]
+# ldb = list(range(256, 1*1024, 256))
+formatB = ["col_turing", "col_ampere"]
+values = list(product(dim1, dim4, dims, formatB))
+names = [
+ "dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}".format(*vals) for vals in values
+]
+
+
@pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names)
def test_dequant_mm(dim1, dim4, dims, formatB):
inner = torch.randint(1, 128, size=(1,)).item()
formatB = F.get_special_format_str()
for i in range(k):
- A = torch.randn(dim1, inner, device='cuda')
- B = torch.randn(dim4, inner, device='cuda')
+ A = torch.randn(dim1, inner, device="cuda")
+ B = torch.randn(dim4, inner, device="cuda")
C1 = torch.matmul(A.half(), B.t().half())
A1, maxA = F.vectorwise_quant(A, dim=1)
B1, maxB = F.vectorwise_quant(B, dim=1)
- A2, SA = F.nvidia_transform(A1, 'col32')
+ A2, SA = F.nvidia_transform(A1, "col32")
B2, SB = F.nvidia_transform(B1, formatB)
C2, SC = F.igemmlt(A2, B2, SA, SB)
- C3, S = F.nvidia_transform(C2, 'row', state=SC)
+ C3, S = F.nvidia_transform(C2, "row", state=SC)
C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
count = (torch.isclose(C1, C4, atol=0.01, rtol=0.1) == 0).sum().item()
n = C1.numel()
p = 0.06
- assert count/n < p, f'error in more than {p} of elements: {count}/{n}={count/n}'
+ assert (
+ count / n < p
+ ), f"error in more than {p} of elements: {count}/{n}={count/n}"
C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten())
torch.testing.assert_allclose(C5, C4)
- #print(C2)
-
+ # print(C2)
n = 2
-dim1 = [1*1024]
-dim2 = [1*1024]
-#dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
-#dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
+dim1 = [1 * 1024]
+dim2 = [1 * 1024]
+# dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
+# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
dims = (2,)
-#ldb = list(range(256, 1*1024, 256))
-values = list(product(dim1,dim2,dims))
-names = ['dim1_{0}_dim2_{1}_dims_{2}'.format(*vals) for vals in values]
+# ldb = list(range(256, 1*1024, 256))
+values = list(product(dim1, dim2, dims))
+names = ["dim1_{0}_dim2_{1}_dims_{2}".format(*vals) for vals in values]
+
+
@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
def test_colrow_absmax(dim1, dim2, dims):
for i in range(k):
threshold = 3.0
- A = torch.randn(dim1, dim2, device='cuda').half()
+ A = torch.randn(dim1, dim2, device="cuda").half()
A_truncated = A.clone()
A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0
if dims == 2:
@@ -880,37 +1025,51 @@ def test_colrow_absmax(dim1, dim2, dims):
else:
assert False
- row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold)
-
- A_blocked = einops.rearrange(torch.abs(A), '(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size', row_tiles=16, block_size=64*4)
- nnz_rows1_counts = (torch.abs(A_blocked)>=threshold).sum(3).flatten()
- nnz_block_ptr1 = torch.zeros(nnz_rows1_counts.shape[0]+1, dtype=nnz_rows1_counts.dtype, device=nnz_rows1_counts.device)
+ row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
+ A, threshold=threshold
+ )
+
+ A_blocked = einops.rearrange(
+ torch.abs(A),
+ "(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size",
+ row_tiles=16,
+ block_size=64 * 4,
+ )
+ nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten()
+ nnz_block_ptr1 = torch.zeros(
+ nnz_rows1_counts.shape[0] + 1,
+ dtype=nnz_rows1_counts.dtype,
+ device=nnz_rows1_counts.device,
+ )
nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)
torch.testing.assert_allclose(col_stats1_trunc, col_stats2)
torch.testing.assert_allclose(row_stats1_trunc, row_stats2)
torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2)
- row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0)
+ row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
+ A, threshold=0.0
+ )
torch.testing.assert_allclose(col_stats1, col_stats2)
torch.testing.assert_allclose(row_stats1, row_stats2)
assert nnz_block_ptr2 is None
-
n = 2
-#dim1 = [8*1024]
-#dim2 = [4*1024]
-dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
-dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
+# dim1 = [8*1024]
+# dim2 = [4*1024]
+dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
+dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
+
+values = list(product(dim1, dim2))
+names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
+
-values = list(product(dim1,dim2))
-names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_double_quant(dim1, dim2):
for i in range(k):
- A = torch.randn(dim1, dim2, device='cuda').half()
+ A = torch.randn(dim1, dim2, device="cuda").half()
out_col1, Scol = F.vectorwise_quant(A, dim=0)
out_row1, Srow = F.vectorwise_quant(A, dim=1)
@@ -920,18 +1079,25 @@ def test_double_quant(dim1, dim2):
torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0)
torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0)
-
n = CAt.numel()
- num_not_close_rows = (torch.isclose(CA, out_row1, atol=1)==0).sum().item()
- num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1)==0).sum().item()
+ num_not_close_rows = (
+ (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
+ )
+ num_not_close_cols = (
+ (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
+ )
# allow for 1:500 error due to rounding differences
- min_error = 1/500
- if num_not_close_cols > (min_error*n):
- print(f'Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}')
+ min_error = 1 / 500
+ if num_not_close_cols > (min_error * n):
+ print(
+ f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}"
+ )
assert False
- if num_not_close_rows > (min_error*n):
- print(f'Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}')
+ if num_not_close_rows > (min_error * n):
+ print(
+ f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}"
+ )
assert False
torch.testing.assert_allclose(Srow.flatten(), statsA)
@@ -939,21 +1105,23 @@ def test_double_quant(dim1, dim2):
n = 4
-dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
-dim4 = torch.randint(1,4*1024, size=(n,)).tolist()
-inner = torch.randint(1,4*1024, size=(n,)).tolist()
+dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
+dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
+inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim1 = [6]
dim4 = [4]
inner = [8]
values = list(zip(dim1, dim4, inner))
-names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
+names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
+
+
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
def test_integrated_igemmlt(dim1, dim4, inner):
for i in range(k):
- A = torch.randn(dim1, inner, device='cuda').half()
- B = torch.randn(dim4, inner, device='cuda').half()
+ A = torch.randn(dim1, inner, device="cuda").half()
+ B = torch.randn(dim4, inner, device="cuda").half()
out1 = torch.matmul(A.half(), B.t().half())
@@ -967,30 +1135,32 @@ def test_integrated_igemmlt(dim1, dim4, inner):
torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1)
torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1)
- A2, SA = F.nvidia_transform(C1a, 'col32')
- B2, SB = F.nvidia_transform(C2a, 'col_turing')
+ A2, SA = F.nvidia_transform(C1a, "col32")
+ B2, SB = F.nvidia_transform(C2a, "col_turing")
outC32, SC = F.igemmlt(A2, B2, SA, SB)
out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)
- A2, SA = F.nvidia_transform(A1, 'col32')
- B2, SB = F.nvidia_transform(B1, 'col_turing')
+ A2, SA = F.nvidia_transform(A1, "col32")
+ B2, SB = F.nvidia_transform(B1, "col_turing")
C2, SC = F.igemmlt(A2, B2, SA, SB)
- C3, S = F.nvidia_transform(C2, 'row', state=SC)
+ C3, S = F.nvidia_transform(C2, "row", state=SC)
out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
- err1 = torch.abs(out1-out2).mean().item()
- err2 = torch.abs(out1-out3).mean().item()
- assert err2 <= err1*1.01
+ err1 = torch.abs(out1 - out2).mean().item()
+ err2 = torch.abs(out1 - out3).mean().item()
+ assert err2 <= err1 * 1.01
n = 6
-dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
-dim4 = torch.randint(1,4*1024, size=(n,)).tolist()
-inner = torch.randint(1,4*1024, size=(n,)).tolist()
+dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
+dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
+inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(zip(dim1, dim4, inner))
-names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
+names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
+
+
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
@@ -999,79 +1169,81 @@ def test_igemmlt_row_scale(dim1, dim4, inner):
relerr1, relerr2 = [], []
scale = 1
for i in range(k):
- A = torch.randn(dim1, inner, device='cuda').half()
- B = torch.randn(dim4, inner, device='cuda').half()
+ A = torch.randn(dim1, inner, device="cuda").half()
+ B = torch.randn(dim4, inner, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
C1 = torch.matmul(A, B.t())
out1 = torch.matmul(A.half(), B.t().half())
-
C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
- CB, absmaxB = F.vectorwise_quant(B, quant_type='linear')
- A2, SA = F.nvidia_transform(C1a, 'col32')
+ CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
+ A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(CB, formatB)
A1, maxA = F.vectorwise_quant(A, dim=1)
- c = 10.0*inner*scale
- row_scale = torch.ones_like(maxA)/c
- outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
- C3, S = F.nvidia_transform(outC32, 'row', state=SC)
+ c = 10.0 * inner * scale
+ row_scale = torch.ones_like(maxA) / c
+ outC32, SC = F.igemmlt(
+ A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
+ )
+ C3, S = F.nvidia_transform(outC32, "row", state=SC)
maxval = torch.abs(C3).max()
if maxval == 127:
scale = 1.5
else:
- scale = maxval/120
- out3 = C3*maxA*absmaxB*c/(127*127)
+ scale = maxval / 120
+ out3 = C3 * maxA * absmaxB * c / (127 * 127)
C4 = torch.matmul(C1a.float(), CB.float().t())
-
C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
B2, SB = F.nvidia_transform(C2a, formatB)
outC32, SC = F.igemmlt(A2, B2, SA, SB)
out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)
- CA, SA = F.vectorwise_quant(A, dim=1, quant_type='vector')
- CB, SB = F.vectorwise_quant(B, dim=1, quant_type='linear')
+ CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector")
+ CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear")
C = torch.matmul(CA.float(), CB.t().float())
- out4 = C*SA*SB/(127*127)
- #out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)
+ out4 = C * SA * SB / (127 * 127)
+ # out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)
- #print('='*80)
- #print(out1)
- #print(out2)
- #print(out3)
+ # print('='*80)
+ # print(out1)
+ # print(out2)
+ # print(out3)
- #print(out1)
- #print(out2)
- #print(out3)
- err1.append(torch.abs(out1-out2).mean().item())
- err2.append(torch.abs(out1-out3).mean().item())
- err3.append(torch.abs(out1-out4).mean().item())
+ # print(out1)
+ # print(out2)
+ # print(out3)
+ err1.append(torch.abs(out1 - out2).mean().item())
+ err2.append(torch.abs(out1 - out3).mean().item())
+ err3.append(torch.abs(out1 - out4).mean().item())
- #assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
- print('')
- print(sum(err1)/len(err1))
- print(sum(err2)/len(err2))
- print(sum(err3)/len(err3))
+ # assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
+ print("")
+ print(sum(err1) / len(err1))
+ print(sum(err2) / len(err2))
+ print(sum(err3) / len(err3))
dim1 = [1024, 2048]
-inner = [12288*4, 4096*4]
+inner = [12288 * 4, 4096 * 4]
dim4 = [12288, 4096]
values = list(zip(dim1, dim4, inner))
-names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
+names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
+
+
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_row_scale_bench(dim1, dim4, inner):
err1, err2, err3 = [], [], []
relerr1, relerr2 = [], []
scale = 1
- A = torch.randn(dim1, inner, device='cuda').half()
- B = torch.randn(dim4, inner, device='cuda').half()
+ A = torch.randn(dim1, inner, device="cuda").half()
+ B = torch.randn(dim4, inner, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
# warmpup
for i in range(k):
@@ -1082,23 +1254,24 @@ def test_row_scale_bench(dim1, dim4, inner):
for i in range(k):
C1 = torch.matmul(A, B.t())
torch.cuda.synchronize()
- print('16', time.time()-t0)
+ print("16", time.time() - t0)
C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
- CB, absmaxB = F.vectorwise_quant(B, quant_type='linear')
- A2, SA = F.nvidia_transform(C1a, 'col32')
+ CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
+ A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(CB, formatB)
A1, maxA = F.vectorwise_quant(A, dim=1)
- c = 10.0*inner*scale
- row_scale = maxA/c
+ c = 10.0 * inner * scale
+ row_scale = maxA / c
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
- outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
+ outC32, SC = F.igemmlt(
+ A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
+ )
torch.cuda.synchronize()
- print('row-wise', time.time()-t0)
-
+ print("row-wise", time.time() - t0)
C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
B2, SB = F.nvidia_transform(C2a, formatB)
@@ -1107,32 +1280,47 @@ def test_row_scale_bench(dim1, dim4, inner):
for i in range(k):
outC32, SC = F.igemmlt(A2, B2, SA, SB)
torch.cuda.synchronize()
- print('vector-wise', time.time()-t0)
-
-
+ print("vector-wise", time.time() - t0)
n = 2
-dim1 = torch.randint(2,1024, size=(n,)).tolist()
-dim2 = torch.randint(2,1024, size=(n,)).tolist()
-#dim1 = [8*1024]
-#dim2 = [4*1024]
+dim1 = torch.randint(2, 1024, size=(n,)).tolist()
+dim2 = torch.randint(2, 1024, size=(n,)).tolist()
+# dim1 = [8*1024]
+# dim2 = [4*1024]
dim3 = [0]
dtype = [torch.int8]
-a_order = ['row']
-out_order = ['col32', 'col_turing', 'col_ampere']
+a_order = ["row"]
+out_order = ["col32", "col_turing", "col_ampere"]
transpose = [False, True]
dims = [2]
-values = list(product(dim1,dim2,dim3, dims,dtype, a_order, out_order, transpose))
-names = ['dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}'.format(*vals) for vals in values]
-@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names)
+values = list(
+ product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
+)
+names = [
+ "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format(
+ *vals
+ )
+ for vals in values
+]
+
+
+@pytest.mark.parametrize(
+ "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
+ values,
+ ids=names,
+)
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
for i in range(k):
if dims == 2:
- A = torch.randint(10, 99, size=(dim1, dim2), device='cuda').to(dtype)
+ A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(
+ dtype
+ )
elif dims == 3:
- A = torch.randint(10, 99, size=(dim1, dim2, dim3), device='cuda').to(dtype)
+ A = torch.randint(
+ 10, 99, size=(dim1, dim2, dim3), device="cuda"
+ ).to(dtype)
A.view(-1)[-1] = -1
if transpose:
@@ -1144,53 +1332,57 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
assert S1[0][0] == S2[0][0]
assert S1[0][1] == S2[0][1]
- #print(out1)
- #print(out2)
+ # print(out1)
+ # print(out2)
torch.testing.assert_allclose(out1, out2)
+
n = 2
-#dim1 = torch.randint(2,1024, size=(n,)).tolist()
-#dim2 = torch.randint(2,1024, size=(n,)).tolist()
+# dim1 = torch.randint(2,1024, size=(n,)).tolist()
+# dim2 = torch.randint(2,1024, size=(n,)).tolist()
dim1 = [1]
dim2 = [33]
dtype = [torch.int8]
-#a_order = ['col_turing', 'col_ampere']
-a_order = ['col_turing']
-out_order = ['row']
-values = list(product(dim1,dim2,dtype, a_order, out_order))
-names = ['dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}'.format(*vals) for vals in values]
-@pytest.mark.parametrize("dim1, dim2, dtype, orderA, orderOut", values, ids=names)
+# a_order = ['col_turing', 'col_ampere']
+a_order = ["col_turing"]
+out_order = ["row"]
+values = list(product(dim1, dim2, dtype, a_order, out_order))
+names = [
+ "dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}".format(*vals)
+ for vals in values
+]
+
+
+@pytest.mark.parametrize(
+ "dim1, dim2, dtype, orderA, orderOut", values, ids=names
+)
def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut):
for i in range(1):
- A = torch.randint(-127, 127, size=(dim1, dim2), device='cuda').to(dtype)
+ A = torch.randint(-127, 127, size=(dim1, dim2), device="cuda").to(dtype)
out2, S2 = F.transform(A, to_order=orderA)
- A2, S3 = F.transform(out2, from_order=orderA, to_order='row', state=S2)
+ A2, S3 = F.transform(out2, from_order=orderA, to_order="row", state=S2)
assert A2.shape[0] == A.shape[0]
assert A2.shape[1] == A.shape[1]
-
- print('')
+ print("")
print(A)
print(out2)
print(A2)
-
- #torch.testing.assert_allclose(A, A2)
-
-
+ # torch.testing.assert_allclose(A, A2)
def test_overflow():
formatB = F.get_special_format_str()
print(formatB)
for i in range(2):
- a = torch.arange(5, 15).cuda().to(torch.int8).view(-1,1 )
- b = torch.arange(5, 15).cuda().to(torch.int8).view(-1,1 )
+ a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
+ b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
- Ca, Sa = F.nvidia_transform(a, 'col32')
+ Ca, Sa = F.nvidia_transform(a, "col32")
Cb, Sb = F.nvidia_transform(b, formatB)
c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8)
@@ -1198,46 +1390,57 @@ def test_overflow():
n = 2
-dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
-dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
-#dim1 = [4]
-#dim2 = [5]
+dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
+dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
+# dim1 = [4]
+# dim2 = [5]
+
+values = list(product(dim1, dim2))
+names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
+
-values = list(product(dim1,dim2))
-names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_coo_double_quant(dim1, dim2):
threshold = 3.00
for i in range(k):
- A = torch.randn(dim1, dim2, device='cuda').half()
+ A = torch.randn(dim1, dim2, device="cuda").half()
- idx = (torch.abs(A) >= threshold)
+ idx = torch.abs(A) >= threshold
CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
- CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
+ CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
+ A, threshold=threshold
+ )
if coo_tensor is not None:
- A1 = A*idx
+ A1 = A * idx
A2 = torch.zeros_like(A)
- A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values
+ A2[
+ coo_tensor.rowidx.long(), coo_tensor.colidx.long()
+ ] = coo_tensor.values
torch.testing.assert_allclose(A1, A2)
- A1 = A*(idx==0)
- A2 = (CA.float()*statsA.unsqueeze(1)/127).half()
- torch.testing.assert_allclose(A*(idx==0), A2, rtol=0.05, atol=1.5e-2)
+ A1 = A * (idx == 0)
+ A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
+ torch.testing.assert_allclose(
+ A * (idx == 0), A2, rtol=0.05, atol=1.5e-2
+ )
+
n = 2
-dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
-dim2 = torch.randint(1,1*1024, size=(n,)).tolist()
-#dim1 = [7]
-#dim2 = [11]
+dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
+dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
+# dim1 = [7]
+# dim2 = [11]
transposed_B = [False, True]
-values = list(product(dim1,dim2, transposed_B))
-names = ['dim1_{0}_dim2_{1}_transposed_B_{2}'.format(*vals) for vals in values]
+values = list(product(dim1, dim2, transposed_B))
+names = ["dim1_{0}_dim2_{1}_transposed_B_{2}".format(*vals) for vals in values]
+
+
@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
def test_spmm_coo(dim1, dim2, transposed_B):
threshold = 1.5
dim3 = torch.randint(32, 128, size=(1,)).item()
- #dim3 = 17
+ # dim3 = 17
for i in range(k):
A = torch.randn(dim1, dim2).cuda().half()
if transposed_B:
@@ -1249,8 +1452,10 @@ def test_spmm_coo(dim1, dim2, transposed_B):
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
- cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
- A2 = A*idx
+ cooA = F.COOSparseTensor(
+ A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
+ )
+ A2 = A * idx
if transposed_B:
out2 = F.spmm_coo(cooA, B.t())
@@ -1262,18 +1467,17 @@ def test_spmm_coo(dim1, dim2, transposed_B):
assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)
-
def test_spmm_bench():
batch = 2
- model = 1024*1
- hidden = model*4
+ model = 1024 * 1
+ hidden = model * 4
seq = 1024
- dim1 = batch*seq
+ dim1 = batch * seq
dim2 = model
dim3 = hidden
threshold = 4
- A = torch.randn(dim1, dim2, device='cuda').half()
- B = torch.randn(dim2, dim3, device='cuda').half()
+ A = torch.randn(dim1, dim2, device="cuda").half()
+ B = torch.randn(dim2, dim3, device="cuda").half()
for i in range(10):
C1 = bnb.matmul(A, B)
@@ -1282,14 +1486,16 @@ def test_spmm_bench():
for i in range(k):
C1 = bnb.matmul(A, B)
torch.cuda.synchronize()
- t8 = time.time()-t0
+ t8 = time.time() - t0
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
- print(nnz/idx.numel())
+ print(nnz / idx.numel())
rows, cols = torch.where(idx)
values = A[idx]
- cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
+ cooA = F.COOSparseTensor(
+ A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
+ )
for i in range(10):
out2 = F.spmm_coo(cooA, B)
@@ -1299,20 +1505,22 @@ def test_spmm_bench():
for i in range(k):
out2 = F.spmm_coo(cooA, B)
torch.cuda.synchronize()
- tsp = time.time()-t0
+ tsp = time.time() - t0
print(tsp, t8)
- print(tsp/t8)
+ print(tsp / t8)
n = 2
-dim1 = torch.randint(256,1*1024, size=(n,)).tolist()
-dim2 = torch.randint(256,1*1024, size=(n,)).tolist()
-values = list(product(dim1,dim2))
-names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values]
+dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
+dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
+values = list(product(dim1, dim2))
+names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
+
+
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_integrated_sparse_decomp(dim1, dim2):
threshold = 3.0
- formatB = 'col_turing'
+ formatB = "col_turing"
for i in range(k):
A = torch.randn(dim1, dim2).cuda().half()
w1 = torch.randn(dim1, dim2).cuda().half()
@@ -1322,13 +1530,15 @@ def test_integrated_sparse_decomp(dim1, dim2):
CTw1, Sw1 = F.transform(Cw1, formatB)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
- C32A, SA = F.transform(CA, 'col32')
+ C32A, SA = F.transform(CA, "col32")
out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
- CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
- C32A, SA = F.transform(CA, 'col32')
+ CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
+ A, threshold=threshold
+ )
+ C32A, SA = F.transform(CA, "col32")
out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
@@ -1338,8 +1548,8 @@ def test_integrated_sparse_decomp(dim1, dim2):
out4 = F.spmm_coo(coo_tensor, w1.t())
out5 = out3 + out4
- err1 = torch.abs(out1-out2).mean().item()
- err2 = torch.abs(out1-out5).mean().item()
+ err1 = torch.abs(out1 - out2).mean().item()
+ err2 = torch.abs(out1 - out5).mean().item()
assert err2 < err1
@@ -1350,91 +1560,99 @@ def test_matmuls():
c2 = bnb.matmul(a, b)
c3 = bnb.matmul(a, b)
- err1 = torch.abs(c1-c2).mean().item()
- err2 = torch.abs(c1-c3).mean().item()
+ err1 = torch.abs(c1 - c2).mean().item()
+ err2 = torch.abs(c1 - c3).mean().item()
assert err1 < 0.2
assert err2 < 0.2
-
n = 2
-#dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
-#dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
-dim1 = [1*2048]
+# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
+# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
+dim1 = [1 * 2048]
dim2 = [12288]
-#dim1 = [32]
-#dim2 = [32]
-#dtype = [torch.float16, torch.int8]
+# dim1 = [32]
+# dim2 = [32]
+# dtype = [torch.float16, torch.int8]
dtype = [torch.float16]
-out_function = ['zeros', 'ones']
-values = list(product(dim1,dim2, dtype, out_function))
-names = ['dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}'.format(*vals) for vals in values]
+out_function = ["zeros", "ones"]
+values = list(product(dim1, dim2, dtype, out_function))
+names = [
+ "dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values
+]
+
+
@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
out_func = getattr(torch, out_func)
threshold = 3.3
- #threshold = 2.8
- #threshold = 0.0
- A = torch.randn(dim1, dim2, device='cuda').half()
+ # threshold = 2.8
+ # threshold = 0.0
+ A = torch.randn(dim1, dim2, device="cuda").half()
if dtype == torch.float16:
- B = torch.randn(dim2, dim2*4, device='cuda').half()
+ B = torch.randn(dim2, dim2 * 4, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
else:
- B = torch.randn(dim2, dim2*4, device='cuda').half()
+ B = torch.randn(dim2, dim2 * 4, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
- B, SB = F.vectorwise_quant(B, quant_type='linear')
- #B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)
+ B, SB = F.vectorwise_quant(B, quant_type="linear")
+ # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)
- print('')
+ print("")
idx = torch.abs(A) >= threshold
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
- cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
- A2 = A*idx
+ cooA = F.COOSparseTensor(
+ A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
+ )
+ A2 = A * idx
out1 = torch.matmul(A2.half(), B.half())
out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
out1 += out.clone()
out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
- #print(B)
- #print(out1)
- #print(out2)
- p = 200/(2048*12288*4)
+ # print(B)
+ # print(out1)
+ # print(out2)
+ p = 200 / (2048 * 12288 * 4)
n = out1.numel()
- count = math.ceil(p*n)
+ count = math.ceil(p * n)
std = out1.std()
out1 /= std
out2 /= std
- assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)
- #assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)
+ assert_all_approx_close(
+ out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count
+ )
+ # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)
idx_col = torch.randint(0, A2.shape[-1], size=(15,))
- #torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001)
+ # torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001)
- #Bt = torch.randn(dim2*4, dim2, device='cuda').half()
- #torch.cuda.synchronize()
- #t0 = time.time()
- #print(A2.shape, B.shape)
- #for i in range(100):
+ # Bt = torch.randn(dim2*4, dim2, device='cuda').half()
+ # torch.cuda.synchronize()
+ # t0 = time.time()
+ # print(A2.shape, B.shape)
+ # for i in range(100):
# #out3 = F.spmm_coo(cooA, Bt.t())
# #out2 = F.spmm_coo(cooA, B)
# #out2 = F.spmm_coo_very_sparse(cooA, B)
# #out1 = torch.matmul(A, Bt.t())
- #torch.cuda.synchronize()
- #print(time.time() - t0)
+ # torch.cuda.synchronize()
+ # print(time.time() - t0)
+
def test_layout():
- a1 = torch.rand(16, 64, device='cuda', dtype=torch.float16)
- a1 = torch.arange(16* 64, device='cuda').reshape(16, 64).byte()
- a2, s2 = F.transform(a1, 'col_turing')
+ a1 = torch.rand(16, 64, device="cuda", dtype=torch.float16)
+ a1 = torch.arange(16 * 64, device="cuda").reshape(16, 64).byte()
+ a2, s2 = F.transform(a1, "col_turing")
print(a2.shape)
- print(a1.flatten()[8*64:8*64+32])
+ print(a1.flatten()[8 * 64 : 8 * 64 + 32])
for i in range(4):
- print(a2.flatten()[i*8*32:i*8*32+32], 0)
+ print(a2.flatten()[i * 8 * 32 : i * 8 * 32 + 32], 0)
def test_coo2csr():
@@ -1444,14 +1662,16 @@ def test_coo2csr():
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
- cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
- A2 = A*idx
+ cooA = F.COOSparseTensor(
+ A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
+ )
+ A2 = A * idx
csrA = F.coo2csr(cooA)
counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
assert counts.numel() == A.shape[0]
- torch.testing.assert_allclose(counts, (A2!=0).sum(1))
- idx = (A2!=0)
+ torch.testing.assert_allclose(counts, (A2 != 0).sum(1))
+ idx = A2 != 0
torch.testing.assert_allclose(A2[idx], csrA.values)
@@ -1462,41 +1682,43 @@ def test_coo2csc():
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
- cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
- A2 = A*idx
+ cooA = F.COOSparseTensor(
+ A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
+ )
+ A2 = A * idx
cscA = F.coo2csc(cooA)
counts = cscA.colptr[1:] - cscA.colptr[:-1]
assert counts.numel() == A.shape[1]
- torch.testing.assert_allclose(counts, (A2!=0).sum(0))
+ torch.testing.assert_allclose(counts, (A2 != 0).sum(0))
# torch uses row-major -> use transpose to transfer to col-major
- idx = (A2.t()!=0)
+ idx = A2.t() != 0
torch.testing.assert_allclose(A2.t()[idx], cscA.values)
-
n = 2
-#dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
-#dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
-dim1 = [1*2048]
-#dim2 = [12288]
+# dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
+# dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
+dim1 = [1 * 2048]
+# dim2 = [12288]
dim2 = [2048]
-#dim1 = [2]
-#dim2 = [2]
+# dim1 = [2]
+# dim2 = [2]
dtype = [torch.int8]
-values = list(product(dim1,dim2, dtype))
-names = ['dim1_{0}_dim2_{1}_dtype_{2}'.format(*vals) for vals in values]
+values = list(product(dim1, dim2, dtype))
+names = ["dim1_{0}_dim2_{1}_dtype_{2}".format(*vals) for vals in values]
+
+
@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
def test_spmm_coo_dequant(dim1, dim2, dtype):
threshold = 6.0
- #threshold = 2.8
- #threshold = 0.0
- A = torch.randn(dim1, dim2, device='cuda').half()
- B = torch.empty(dim2, dim2*4, device='cuda', dtype=torch.float16)
+ # threshold = 2.8
+ # threshold = 0.0
+ A = torch.randn(dim1, dim2, device="cuda").half()
+ B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16)
torch.nn.init.xavier_uniform_(B)
Bt = B.t().contiguous()
-
CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
rowidx = torch.randint(0, A.shape[-1], size=(15,))
@@ -1507,12 +1729,14 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
nnz = (idx == 1).sum().item()
rows, cols = torch.where(idx)
values = A[idx]
- cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
- A2 = A*idx
+ cooA = F.COOSparseTensor(
+ A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values
+ )
+ A2 = A * idx
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
out1 = torch.matmul(A2, B.half())
out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
- out3 = out3*statsBt.half()/127
+ out3 = out3 * statsBt.half() / 127
values, counts = torch.unique(cooA.rowidx, return_counts=True)
offset = counts.cumsum(0).int()
@@ -1521,56 +1745,54 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001)
- p = 200/(2048*12288*4)
+ p = 200 / (2048 * 12288 * 4)
n = out1.numel()
- count = math.ceil(p*n)
+ count = math.ceil(p * n)
assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)
-
-
- #torch.cuda.synchronize()
- #t0 = time.time()
- #for i in range(100):
+ # torch.cuda.synchronize()
+ # t0 = time.time()
+ # for i in range(100):
# out2 = F.spmm_coo_very_sparse(cooA, B)
- #torch.cuda.synchronize()
- #print('fp16', time.time() - t0)
+ # torch.cuda.synchronize()
+ # print('fp16', time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
- out2 = F.spmm_coo(cooA, B)
+ out2 = F.spmm_coo(cooA, B)
torch.cuda.synchronize()
- print('cusparse fp16', time.time() - t0)
+ print("cusparse fp16", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
- out2 = F.spmm_coo_very_sparse(cooA, CBt)
+ out2 = F.spmm_coo_very_sparse(cooA, CBt)
torch.cuda.synchronize()
- print('int8', time.time() - t0)
+ print("int8", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
- out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
+ out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
torch.cuda.synchronize()
- print('int8+dequant', time.time() - t0)
+ print("int8+dequant", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
- out2 = torch.matmul(A, B)
+ out2 = torch.matmul(A, B)
torch.cuda.synchronize()
- print('matmul', time.time() - t0)
+ print("matmul", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out1 = bnb.matmul(A, Bt)
out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
- out = out1+out2
+ out = out1 + out2
torch.cuda.synchronize()
- print('sparse+ matmul', time.time() - t0)
+ print("sparse+ matmul", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
@@ -1578,33 +1800,38 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
out1 = bnb.matmul(A, Bt)
torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1)
torch.cuda.synchronize()
- print('partial matmul', time.time() - t0)
+ print("partial matmul", time.time() - t0)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
out1 = bnb.matmul(A, Bt)
torch.cuda.synchronize()
- print('partial matmul', time.time() - t0)
+ print("partial matmul", time.time() - t0)
+
batch_size = 1
seqdim = 2048
values = []
-values.append((batch_size, seqdim, 768, 4*768))
-#values.append((batch_size, seqdim, 1024, 4*1024))
-#values.append((batch_size, seqdim, 1536, 4*1536))
-#values.append((batch_size, seqdim, 2048, 4*2048))
-#values.append((batch_size, seqdim, 2560, 4*2560))
-#values.append((batch_size, seqdim, 4096, 4*4096))
-#values.append((batch_size, seqdim, 5140, 4*5140))
-#values.append((batch_size, seqdim, 12288, 4*12288))
-names = ['batch_{0}_seq_{1}_model_{2}_hidden_{3}'.format(*vals) for vals in values]
+values.append((batch_size, seqdim, 768, 4 * 768))
+# values.append((batch_size, seqdim, 1024, 4*1024))
+# values.append((batch_size, seqdim, 1536, 4*1536))
+# values.append((batch_size, seqdim, 2048, 4*2048))
+# values.append((batch_size, seqdim, 2560, 4*2560))
+# values.append((batch_size, seqdim, 4096, 4*4096))
+# values.append((batch_size, seqdim, 5140, 4*5140))
+# values.append((batch_size, seqdim, 12288, 4*12288))
+names = [
+ "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
+]
+
+
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
def test_bench_matmul(batch, seq, model, hidden):
formatB = F.get_special_format_str()
- A = torch.randn(batch, seq, model, device='cuda').half()
- B = torch.empty(hidden, model, dtype=torch.float16, device='cuda')
+ A = torch.randn(batch, seq, model, device="cuda").half()
+ B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
torch.nn.init.xavier_uniform_(B)
linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
@@ -1613,31 +1840,37 @@ def test_bench_matmul(batch, seq, model, hidden):
outliers = torch.randint(0, model, size=(5,)).cuda()
A[:, :, outliers] = 8.0
- linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
+ linearMixedBit = (
+ bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
+ )
linearMixedBit.eval()
# warmup
for i in range(100):
torch.matmul(A, B.t())
torch.cuda.synchronize()
- print('')
+ print("")
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
torch.matmul(A, B.t())
torch.cuda.synchronize()
- print(f'pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
+ print(
+ f"pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+ )
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
bnb.matmul(A, B)
torch.cuda.synchronize()
- print(f'bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
+ print(
+ f"bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+ )
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
- C32A, SA = F.transform(CA, 'col32')
+ C32A, SA = F.transform(CA, "col32")
CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
CxB, SB = F.transform(CB, to_order=formatB)
torch.cuda.synchronize()
@@ -1645,7 +1878,9 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(100):
out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
torch.cuda.synchronize()
- print(f'igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
+ print(
+ f"igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+ )
BA, statsB = F.vectorwise_quant(B, dim=1)
CxB, SB = F.nvidia_transform(CB, to_order=formatB)
@@ -1654,26 +1889,30 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(100):
A2 = A.view(-1, A.shape[-1]).contiguous()
CA, statsA = F.vectorwise_quant(A2, dim=1)
- C32A, SA = F.nvidia_transform(CA, 'col32')
+ C32A, SA = F.nvidia_transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
- Cout, Sout = F.nvidia_transform(out32, 'row', state=Sout32)
+ Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
torch.cuda.synchronize()
- print(f'vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
+ print(
+ f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+ )
- BA, statsB = F.vectorwise_quant(B, dim=1, quant_type='linear')
+ BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
CxB, SB = F.nvidia_transform(CB, to_order=formatB)
torch.cuda.synchronize()
t0 = time.time()
for i in range(100):
A2 = A.view(-1, A.shape[-1]).contiguous()
- CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type='linear')
- C32A, SA = F.nvidia_transform(CA, 'col32')
+ CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
+ C32A, SA = F.nvidia_transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
- Cout, Sout = F.nvidia_transform(out32, 'row', state=Sout32)
- out = Cout*statsB*statsA*(1.0/(127*127))
+ Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
+ out = Cout * statsB * statsA * (1.0 / (127 * 127))
torch.cuda.synchronize()
- print(f'linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
+ print(
+ f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+ )
linear8bit(A)
torch.cuda.synchronize()
@@ -1681,8 +1920,9 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(100):
linear8bit(A)
torch.cuda.synchronize()
- print(f'bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
-
+ print(
+ f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+ )
linearMixedBit(A)
torch.cuda.synchronize()
@@ -1690,65 +1930,66 @@ def test_bench_matmul(batch, seq, model, hidden):
for i in range(100):
linearMixedBit(A)
torch.cuda.synchronize()
- print(f'bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
+ print(
+ f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
+ )
def test_zeropoint():
def min_max(x):
maxA = torch.amax(x, dim=1, keepdim=True)
minA = torch.amin(x, dim=1, keepdim=True)
- midpoint = (maxA-minA)/2.0
- dyna = 252/(maxA-minA)
- #dyna *= 0.98
- x = dyna*x
- x = x - torch.round((dyna*(minA+midpoint)))
+ midpoint = (maxA - minA) / 2.0
+ dyna = 252 / (maxA - minA)
+ # dyna *= 0.98
+ x = dyna * x
+ x = x - torch.round((dyna * (minA + midpoint)))
return x.to(torch.int8), minA, midpoint, dyna
+
batch = 2
seq = 2
model = 4
- hidden = 2*model
- #batch = 4
- #seq = 2048
- #model = 1024
- #hidden = 8*model
- A = torch.randn(batch*seq, model, device='cuda').half()-0.4
- B = torch.nn.Parameter(torch.randn(model, hidden, device='cuda').half())
-
- #A[0] = 0
- #B[:, 0] = 0
- #A = A*(A>0)
- #A[0, 0] = 0
- #A[0, 0] = 6.0
+ hidden = 2 * model
+ # batch = 4
+ # seq = 2048
+ # model = 1024
+ # hidden = 8*model
+ A = torch.randn(batch * seq, model, device="cuda").half() - 0.4
+ B = torch.nn.Parameter(torch.randn(model, hidden, device="cuda").half())
+
+ # A[0] = 0
+ # B[:, 0] = 0
+ # A = A*(A>0)
+ # A[0, 0] = 0
+ # A[0, 0] = 6.0
Ac, minA, midpoint, dyna = min_max(A)
- #print(Ac[0, 0], 'zero')
- #print(Ac, Ac.min(), Ac.max())
- Bc, maxB = F.vectorwise_quant(B, quant_type='linear')
+ # print(Ac[0, 0], 'zero')
+ # print(Ac, Ac.min(), Ac.max())
+ Bc, maxB = F.vectorwise_quant(B, quant_type="linear")
out = F.igemm(Ac, Bc)
- out2 = torch.matmul(A,B)
- offset = B.sum(0)*torch.round(dyna*(minA+midpoint))/dyna
+ out2 = torch.matmul(A, B)
+ offset = B.sum(0) * torch.round(dyna * (minA + midpoint)) / dyna
out = out.float()
- #print(out.shape, maxB.shape, scale.shape, offset.shape)
- norm1 = maxB/127
- C4 = (out/dyna)*norm1+offset
-
+ # print(out.shape, maxB.shape, scale.shape, offset.shape)
+ norm1 = maxB / 127
+ C4 = (out / dyna) * norm1 + offset
B1 = torch.nn.Parameter(B.clone())
B2 = torch.nn.Parameter(B.clone())
B3 = torch.nn.Parameter(B.clone())
B4 = torch.nn.Parameter(B.clone())
-
C1 = torch.matmul(A, B1)
- C2 = bnb.matmul_cublas(A, B2, None, 'linear')
- C3 = bnb.matmul_cublas(A, B3, None, 'zeropoint')
- C4 = bnb.matmul_cublas(A, B4, None, 'vector-zeropoint')
+ C2 = bnb.matmul_cublas(A, B2, None, "linear")
+ C3 = bnb.matmul_cublas(A, B3, None, "zeropoint")
+ C4 = bnb.matmul_cublas(A, B4, None, "vector-zeropoint")
- err1 = torch.abs(C1-C2).mean().item()
- err2 = torch.abs(C1-C3).mean().item()
- err3 = torch.abs(C1-C4).mean().item()
+ err1 = torch.abs(C1 - C2).mean().item()
+ err2 = torch.abs(C1 - C3).mean().item()
+ err3 = torch.abs(C1 - C4).mean().item()
print(err1, err2, err3)
- #assert err1 > err2
+ # assert err1 > err2
loss1 = C1.mean()
loss2 = C2.mean()
@@ -1765,40 +2006,38 @@ def test_zeropoint():
print(B2.grad)
print(B3.grad)
print(B4.grad)
- err1 = torch.abs(B1.grad-B2.grad).mean().item()
- err2 = torch.abs(B1.grad-B3.grad).mean().item()
- err3 = torch.abs(B1.grad-B4.grad).mean().item()
+ err1 = torch.abs(B1.grad - B2.grad).mean().item()
+ err2 = torch.abs(B1.grad - B3.grad).mean().item()
+ err3 = torch.abs(B1.grad - B4.grad).mean().item()
print(err1, err2, err3)
-
-
def test_zp():
def quant_zp(x):
dtype = x.dtype
x = x.float()
dyna = x.max() - x.min()
- if dyna == 0: dyna = 1
- qx = 254./dyna
+ if dyna == 0:
+ dyna = 1
+ qx = 254.0 / dyna
minx = x.min()
- #zpx = torch.round(minx* qx)
- #zpx = 127 - torch.round(x.max()* qx)
- zpx = torch.round(x.min()* qx) - 127
- x = (qx*x) + zpx
+ # zpx = torch.round(minx* qx)
+ # zpx = 127 - torch.round(x.max()* qx)
+ zpx = torch.round(x.min() * qx) - 127
+ x = (qx * x) + zpx
return x, qx, zpx
+
batch = 2
seq = 512
model = 1024
- hidden = 4*model
- A = torch.randn(batch*seq, model, device='cuda').half()*0.1
- B = torch.randn(model, hidden, device='cuda').half()*0.1
-
+ hidden = 4 * model
+ A = torch.randn(batch * seq, model, device="cuda").half() * 0.1
+ B = torch.randn(model, hidden, device="cuda").half() * 0.1
C0 = torch.matmul(A, B)
-
- #A, SA = F.vectorwise_quant(A, quant_type='linear')
- #B, SB = F.vectorwise_quant(B, quant_type='linear')
+ # A, SA = F.vectorwise_quant(A, quant_type='linear')
+ # B, SB = F.vectorwise_quant(B, quant_type='linear')
A = A.float()
B = B.float()
@@ -1806,69 +2045,68 @@ def test_zp():
C3 = bnb.matmul(A.half(), B.t().contiguous().half())
zp = 1
- #C2 = torch.matmul(A-zp, B)
- #C2 += B.sum(0).view(1, -1)*zp
- C2 = torch.matmul(A, B-zp)
- C2 -= A.sum(1).view(-1, 1)*zp
+ # C2 = torch.matmul(A-zp, B)
+ # C2 += B.sum(0).view(1, -1)*zp
+ C2 = torch.matmul(A, B - zp)
+ C2 -= A.sum(1).view(-1, 1) * zp
ca, cqa, cza = quant_zp(A)
print(ca.min(), ca.max())
- print((ca-cza).min(), (ca-cza).max())
+ print((ca - cza).min(), (ca - cza).max())
zp = 1
scale = 2.0
- C5 = torch.matmul((A*scale)-zp, B)
- C5 += B.sum(0)*zp
+ C5 = torch.matmul((A * scale) - zp, B)
+ C5 += B.sum(0) * zp
C5 /= scale
CA, qa, zpa = quant_zp(A)
C4 = torch.matmul(CA, B)
- C4 -= B.sum(0)*zpa
+ C4 -= B.sum(0) * zpa
C4 /= qa
zpb = 1
zpa = 1
qa = 2
qb = 2
- C6 = torch.matmul((A*qa)+zpa, (B*qb)+zpb)
- C6 -= (qb*B.sum(0).view(1, -1)*zpa) + (qa*A.sum(1).view(-1, 1)*zpb)
- C6 -= zpa*zpb*A.shape[1]
- C6 /= qa*qb
+ C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb)
+ C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
+ C6 -= zpa * zpb * A.shape[1]
+ C6 /= qa * qb
CA, qa, zpa = quant_zp(A)
CB, qb, zpb = quant_zp(B)
C7 = torch.matmul(CA, CB)
- C7 -= (qb*B.sum(0).view(1, -1)*zpa) + (qa*A.sum(1).view(-1, 1)*zpb)
- C7 -= zpa*zpb*A.shape[1]
- C7 /= qa*qb
+ C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb)
+ C7 -= zpa * zpb * A.shape[1]
+ C7 /= qa * qb
- print('')
- #print(C0.flatten()[:10])
+ print("")
+ # print(C0.flatten()[:10])
print(C1.flatten()[:10])
print(C2.flatten()[:10])
print(C3.flatten()[:10])
print(C5.flatten()[:10])
print(C6.flatten()[:10])
print(C7.flatten()[:10])
- err1 = torch.abs(C1-C2).mean().item()
- err2 = torch.abs(C1-C3).mean().item()
- err3 = torch.abs(C1-C4).mean().item()
- err4 = torch.abs(C1-C5).mean().item()
- err5 = torch.abs(C1-C6).mean().item()
- err6 = torch.abs(C1-C7).mean().item()
+ err1 = torch.abs(C1 - C2).mean().item()
+ err2 = torch.abs(C1 - C3).mean().item()
+ err3 = torch.abs(C1 - C4).mean().item()
+ err4 = torch.abs(C1 - C5).mean().item()
+ err5 = torch.abs(C1 - C6).mean().item()
+ err6 = torch.abs(C1 - C7).mean().item()
print(err1, err2, err3, err4, err5, err6)
-
def test_extract_outliers():
for i in range(k):
- shapeA = (4096, 4096*4)
+ shapeA = (4096, 4096 * 4)
idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
- #idx = torch.Tensor([0]).int().cuda()
- A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
+ # idx = torch.Tensor([0]).int().cuda()
+ A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
outliers1 = A[:, idx.long()]
- CA, SA = F.transform(A, 'col_turing')
+ CA, SA = F.transform(A, "col_turing")
outliers2 = F.extract_outliers(CA, SA, idx)
@@ -1877,7 +2115,7 @@ def test_extract_outliers():
torch.testing.assert_allclose(outliers1, outliers2)
- CA, SA = F.transform(A, 'col_ampere')
+ CA, SA = F.transform(A, "col_ampere")
outliers2 = F.extract_outliers(CA, SA, idx)