summaryrefslogtreecommitdiff
path: root/quicktest.py
diff options
context:
space:
mode:
Diffstat (limited to 'quicktest.py')
-rw-r--r--quicktest.py112
1 files changed, 0 insertions, 112 deletions
diff --git a/quicktest.py b/quicktest.py
deleted file mode 100644
index 0fcda64..0000000
--- a/quicktest.py
+++ /dev/null
@@ -1,112 +0,0 @@
-from itertools import product
-
-import torch
-
-import bitsandbytes as bnb
-import bitsandbytes.functional as F
-
-
-def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
- k = 25
- for i in range(k):
- if dims == 2:
- A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(
- torch.int8
- )
- elif dims == 3:
- A = torch.randint(
- -128, 127, size=(dim1, dim2, dim3), device="cuda"
- ).to(torch.int8)
- B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
- torch.int8
- )
- C1 = torch.matmul(A.float(), B.t().float())
-
- A2, SA = F.transform(A, "col32")
- B2, SB = F.transform(B, "colx")
- if dims == 2:
- C2, SC = F.transform(
- torch.zeros(
- A.shape[0], B.shape[0], dtype=torch.int32, device="cuda"
- ),
- "col32",
- )
- else:
- C2, SC = F.transform(
- torch.zeros(
- A.shape[0],
- A.shape[1],
- B.shape[0],
- dtype=torch.int32,
- device="cuda",
- ),
- "col32",
- )
- F.igemmlt(A2, B2, C2, SA, SB, SC)
- C3, S = F.transform(C2, "row", state=SC)
- # torch.testing.assert_allclose(C1, C3.float())
- # print(C1)
- # print(C2)
- # print(C3)
- allclose = torch.allclose(C1, C3.float())
- if allclose:
- print(C1)
- print(C2)
- print(C3)
-
- ## transposed
- # A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
- # if dims == 2:
- # B = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
- # C1 = torch.matmul(A.float(), B.float().t())
- # elif dims == 3:
- # B = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
- # C1 = torch.matmul(B.float(), A.t().float())
- # C1 = C1.permute([2, 0, 1])
-
- # A2, SA = F.transform(A, 'col32')
- # B2, SB = F.transform(B, 'colx')
- # if dims == 2:
- # C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
- # else:
- # C2 = torch.zeros(A.shape[0], B.shape[0], B.shape[1], dtype=torch.int32, device='cuda')
- # state = (C2.shape, 'row', A.shape[0])
- # C2, SC = F.transform(C2, 'col32', state=state)
- # F.igemmlt(A2, B2, C2, SA, SB, SC)
- # C3, S = F.transform(C2, 'row', state=SC, ld=[0])
- # torch.testing.assert_allclose(C1, C3.float())
-
- ## weight update
- # if dims == 3:
- # A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
- # B = torch.randint(-128, 127, size=(dim1, dim2, dim4), device='cuda').to(torch.int8)
- # C1 = torch.matmul(B.view(-1, B.shape[-1]).t().float(), A.view(-1, A.shape[-1]).float())
-
- # A2, SA = F.transform(A.view(-1, A.shape[-1]).t().contiguous(), 'colx')
- # B2, SB = F.transform(B.view(-1, B.shape[-1]).t().contiguous(), 'col32')
- # C2 = torch.zeros(B.shape[-1], A.shape[-1], dtype=torch.int32, device='cuda')
- # C2, SC = F.transform(C2, 'col32')
- # F.igemmlt(B2, A2, C2, SB, SA, SC)
- # C3, S = F.transform(C2, 'row', state=SC)
- # torch.testing.assert_allclose(C1, C3.float())
-
-
-dims = (2, 3)
-ldb = [0]
-
-n = 2
-dim1 = torch.randint(1, 256, size=(n,)).tolist()
-dim2 = torch.randint(32, 512, size=(n,)).tolist()
-dim3 = torch.randint(32, 1024, size=(n,)).tolist()
-dim4 = torch.randint(32, 1024, size=(n,)).tolist()
-values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
-
-for ldb in range(32, 4096, 32):
- # for ldb in [None]:
- val = test_igemmlt(2, 2, 2, 2, 2, ldb)
- if val:
- print(val, ldb)
- else:
- print("nope", ldb)
-# for val in values:
-# test_igemmlt(*val)