Diffstat (limited to 'quicktest.py')
-rw-r--r--  quicktest.py  90
1 file changed, 90 insertions, 0 deletions
diff --git a/quicktest.py b/quicktest.py
new file mode 100644
index 0000000..2db6afa
--- /dev/null
+++ b/quicktest.py
@@ -0,0 +1,90 @@
+import torch
+import bitsandbytes as bnb
+import bitsandbytes.functional as F
+
+from itertools import product
+
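+# Quick manual check of the int8 matmul path: multiply random int8 matrices
+# with F.igemmlt and compare against a float torch.matmul reference. The
+# comments about the 'col32'/'colx' layouts below are inferred from how
+# F.transform is used in this script, not from documented behaviour.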
+def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
+    k = 25
+    for i in range(k):
+        if dims == 2:
+            A = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
+        elif dims == 3:
+            A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
+        B = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
+        C1 = torch.matmul(A.float(), B.t().float())
+
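+        # F.transform repacks the tensors into the tile layouts the int8 kernel
+        # expects: 'col32' for A and for the int32 accumulator C, and 'colx'
+        # (presumably the architecture-specific column format) for B. This is
+        # inferred from usage here rather than from a documented API.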
+        A2, SA = F.transform(A, 'col32')
+        B2, SB = F.transform(B, 'colx')
+        if dims == 2:
+            C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
+        else:
+            C2, SC = F.transform(torch.zeros(A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
+        F.igemmlt(A2, B2, C2, SA, SB, SC)
+        C3, S = F.transform(C2, 'row', state=SC)
+        #torch.testing.assert_allclose(C1, C3.float())
+        #print(C1)
+        #print(C2)
+        #print(C3)
+        allclose = torch.allclose(C1, C3.float())
+        if allclose:
+            print(C1)
+            print(C2)
+            print(C3)
+            # report success to the caller, which sweeps over ldb values
+            return True
+
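+        # The commented-out sections below appear to sketch two further cases,
+        # kept for later debugging: the same product with the operand roles
+        # swapped ("transposed"), and a weight-update-style reduction over the
+        # flattened batch/sequence dimensions.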
+        ## transposed
+        #A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
+        #if dims == 2:
+        #    B = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
+        #    C1 = torch.matmul(A.float(), B.float().t())
+        #elif dims == 3:
+        #    B = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
+        #    C1 = torch.matmul(B.float(), A.t().float())
+        #    C1 = C1.permute([2, 0, 1])
+
+        #A2, SA = F.transform(A, 'col32')
+        #B2, SB = F.transform(B, 'colx')
+        #if dims == 2:
+        #    C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
+        #else:
+        #    C2 = torch.zeros(A.shape[0], B.shape[0], B.shape[1], dtype=torch.int32, device='cuda')
+        #    state = (C2.shape, 'row', A.shape[0])
+        #    C2, SC = F.transform(C2, 'col32', state=state)
+        #F.igemmlt(A2, B2, C2, SA, SB, SC)
+        #C3, S = F.transform(C2, 'row', state=SC, ld=[0])
+        #torch.testing.assert_allclose(C1, C3.float())
+
+        ## weight update
+        #if dims == 3:
+        #    A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
+        #    B = torch.randint(-128, 127, size=(dim1, dim2, dim4), device='cuda').to(torch.int8)
+        #    C1 = torch.matmul(B.view(-1, B.shape[-1]).t().float(), A.view(-1, A.shape[-1]).float())
+
+        #    A2, SA = F.transform(A.view(-1, A.shape[-1]).t().contiguous(), 'colx')
+        #    B2, SB = F.transform(B.view(-1, B.shape[-1]).t().contiguous(), 'col32')
+        #    C2 = torch.zeros(B.shape[-1], A.shape[-1], dtype=torch.int32, device='cuda')
+        #    C2, SC = F.transform(C2, 'col32')
+        #    F.igemmlt(B2, A2, C2, SB, SA, SC)
+        #    C3, S = F.transform(C2, 'row', state=SC)
+        #    torch.testing.assert_allclose(C1, C3.float())
+
+    # none of the k random trials matched the float reference
+    return False
+
+
+dims = (2, 3)
+ldb = [0]
+
+n = 2
+dim1 = torch.randint(1, 256, size=(n,)).tolist()
+dim2 = torch.randint(32, 512, size=(n,)).tolist()
+dim3 = torch.randint(32, 1024, size=(n,)).tolist()
+dim4 = torch.randint(32, 1024, size=(n,)).tolist()
+values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
+
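+# Sweep candidate ldb (leading dimension) values with a tiny fixed shape.
+# Note: test_igemmlt does not currently use its ldb argument, so every
+# iteration exercises the same configuration; the commented-out loop below
+# would instead cover the random shapes collected in `values`.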
+for ldb in range(32, 4096, 32):
+#for ldb in [None]:
+    val = test_igemmlt(2, 2, 2, 2, 2, ldb)
+    if val:
+        print(val, ldb)
+    else:
+        print('nope', ldb)
+#for val in values:
+#    test_igemmlt(*val)