summaryrefslogtreecommitdiff
path: root/quicktest.py
diff options
context:
space:
mode:
Diffstat (limited to 'quicktest.py')
-rw-r--r--quicktest.py80
1 files changed, 47 insertions, 33 deletions
diff --git a/quicktest.py b/quicktest.py
index 2db6afa..29d045d 100644
--- a/quicktest.py
+++ b/quicktest.py
@@ -1,31 +1,45 @@
+from itertools import product
+
import torch
+
import bitsandbytes as bnb
import bitsandbytes.functional as F
-from itertools import product
def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
k = 25
for i in range(k):
if dims == 2:
- A = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
+ A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(
+ torch.int8
+ )
elif dims == 3:
- A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
- B = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
+ A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
+ torch.int8
+ )
+ B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
C1 = torch.matmul(A.float(), B.t().float())
- A2, SA = F.transform(A, 'col32')
- B2, SB = F.transform(B, 'colx')
+ A2, SA = F.transform(A, "col32")
+ B2, SB = F.transform(B, "colx")
if dims == 2:
- C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
+ C2, SC = F.transform(
+ torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device="cuda"),
+ "col32",
+ )
else:
- C2, SC = F.transform(torch.zeros(A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
+ C2, SC = F.transform(
+ torch.zeros(
+ A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device="cuda"
+ ),
+ "col32",
+ )
F.igemmlt(A2, B2, C2, SA, SB, SC)
- C3, S = F.transform(C2, 'row', state=SC)
- #torch.testing.assert_allclose(C1, C3.float())
- #print(C1)
- #print(C2)
- #print(C3)
+ C3, S = F.transform(C2, "row", state=SC)
+ # torch.testing.assert_allclose(C1, C3.float())
+ # print(C1)
+ # print(C2)
+ # print(C3)
allclose = torch.allclose(C1, C3.float())
if allclose:
print(C1)
@@ -33,29 +47,29 @@ def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
print(C3)
## transposed
- #A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
- #if dims == 2:
+ # A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
+ # if dims == 2:
# B = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
# C1 = torch.matmul(A.float(), B.float().t())
- #elif dims == 3:
+ # elif dims == 3:
# B = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
# C1 = torch.matmul(B.float(), A.t().float())
# C1 = C1.permute([2, 0, 1])
- #A2, SA = F.transform(A, 'col32')
- #B2, SB = F.transform(B, 'colx')
- #if dims == 2:
+ # A2, SA = F.transform(A, 'col32')
+ # B2, SB = F.transform(B, 'colx')
+ # if dims == 2:
# C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
- #else:
+ # else:
# C2 = torch.zeros(A.shape[0], B.shape[0], B.shape[1], dtype=torch.int32, device='cuda')
# state = (C2.shape, 'row', A.shape[0])
# C2, SC = F.transform(C2, 'col32', state=state)
- #F.igemmlt(A2, B2, C2, SA, SB, SC)
- #C3, S = F.transform(C2, 'row', state=SC, ld=[0])
- #torch.testing.assert_allclose(C1, C3.float())
+ # F.igemmlt(A2, B2, C2, SA, SB, SC)
+ # C3, S = F.transform(C2, 'row', state=SC, ld=[0])
+ # torch.testing.assert_allclose(C1, C3.float())
## weight update
- #if dims == 3:
+ # if dims == 3:
# A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
# B = torch.randint(-128, 127, size=(dim1, dim2, dim4), device='cuda').to(torch.int8)
# C1 = torch.matmul(B.view(-1, B.shape[-1]).t().float(), A.view(-1, A.shape[-1]).float())
@@ -73,18 +87,18 @@ dims = (2, 3)
ldb = [0]
n = 2
-dim1 = torch.randint(1,256, size=(n,)).tolist()
-dim2 = torch.randint(32,512, size=(n,)).tolist()
-dim3 = torch.randint(32,1024, size=(n,)).tolist()
-dim4 = torch.randint(32,1024, size=(n,)).tolist()
-values = list(product(dim1,dim2,dim3,dim4,dims, ldb))
+dim1 = torch.randint(1, 256, size=(n,)).tolist()
+dim2 = torch.randint(32, 512, size=(n,)).tolist()
+dim3 = torch.randint(32, 1024, size=(n,)).tolist()
+dim4 = torch.randint(32, 1024, size=(n,)).tolist()
+values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
for ldb in range(32, 4096, 32):
-#for ldb in [None]:
+ # for ldb in [None]:
val = test_igemmlt(2, 2, 2, 2, 2, ldb)
if val:
print(val, ldb)
else:
- print('nope', ldb)
-#for val in values:
- #test_igemmlt(*val)
+ print("nope", ldb)
+# for val in values:
+# test_igemmlt(*val)