From 59a615b3869eb8488a748e2aa51224a5e3d366bb Mon Sep 17 00:00:00 2001
From: Titus von Koeller
Date: Tue, 2 Aug 2022 21:26:50 -0700
Subject: factored cuda_setup.main out into smaller modules and functions

---
 quicktest.py | 112 -----------------------------------------------------------
 1 file changed, 112 deletions(-)
 delete mode 100644 quicktest.py

diff --git a/quicktest.py b/quicktest.py
deleted file mode 100644
index 0fcda64..0000000
--- a/quicktest.py
+++ /dev/null
@@ -1,112 +0,0 @@
-from itertools import product
-
-import torch
-
-import bitsandbytes as bnb
-import bitsandbytes.functional as F
-
-
-def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
-    k = 25
-    for i in range(k):
-        if dims == 2:
-            A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(
-                torch.int8
-            )
-        elif dims == 3:
-            A = torch.randint(
-                -128, 127, size=(dim1, dim2, dim3), device="cuda"
-            ).to(torch.int8)
-        B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
-            torch.int8
-        )
-        C1 = torch.matmul(A.float(), B.t().float())
-
-        A2, SA = F.transform(A, "col32")
-        B2, SB = F.transform(B, "colx")
-        if dims == 2:
-            C2, SC = F.transform(
-                torch.zeros(
-                    A.shape[0], B.shape[0], dtype=torch.int32, device="cuda"
-                ),
-                "col32",
-            )
-        else:
-            C2, SC = F.transform(
-                torch.zeros(
-                    A.shape[0],
-                    A.shape[1],
-                    B.shape[0],
-                    dtype=torch.int32,
-                    device="cuda",
-                ),
-                "col32",
-            )
-        F.igemmlt(A2, B2, C2, SA, SB, SC)
-        C3, S = F.transform(C2, "row", state=SC)
-        # torch.testing.assert_allclose(C1, C3.float())
-        # print(C1)
-        # print(C2)
-        # print(C3)
-        allclose = torch.allclose(C1, C3.float())
-        if allclose:
-            print(C1)
-            print(C2)
-            print(C3)
-
-        ## transposed
-        # A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
-        # if dims == 2:
-        #    B = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
-        #    C1 = torch.matmul(A.float(), B.float().t())
-        # elif dims == 3:
-        #    B = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
-        #    C1 = torch.matmul(B.float(), A.t().float())
-        #    C1 = C1.permute([2, 0, 1])
-
-        # A2, SA = F.transform(A, 'col32')
-        # B2, SB = F.transform(B, 'colx')
-        # if dims == 2:
-        #    C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
-        # else:
-        #    C2 = torch.zeros(A.shape[0], B.shape[0], B.shape[1], dtype=torch.int32, device='cuda')
-        #    state = (C2.shape, 'row', A.shape[0])
-        #    C2, SC = F.transform(C2, 'col32', state=state)
-        # F.igemmlt(A2, B2, C2, SA, SB, SC)
-        # C3, S = F.transform(C2, 'row', state=SC, ld=[0])
-        # torch.testing.assert_allclose(C1, C3.float())
-
-        ## weight update
-        # if dims == 3:
-        #    A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
-        #    B = torch.randint(-128, 127, size=(dim1, dim2, dim4), device='cuda').to(torch.int8)
-        #    C1 = torch.matmul(B.view(-1, B.shape[-1]).t().float(), A.view(-1, A.shape[-1]).float())
-
-        #    A2, SA = F.transform(A.view(-1, A.shape[-1]).t().contiguous(), 'colx')
-        #    B2, SB = F.transform(B.view(-1, B.shape[-1]).t().contiguous(), 'col32')
-        #    C2 = torch.zeros(B.shape[-1], A.shape[-1], dtype=torch.int32, device='cuda')
-        #    C2, SC = F.transform(C2, 'col32')
-        #    F.igemmlt(B2, A2, C2, SB, SA, SC)
-        #    C3, S = F.transform(C2, 'row', state=SC)
-        #    torch.testing.assert_allclose(C1, C3.float())
-
-
-dims = (2, 3)
-ldb = [0]
-
-n = 2
-dim1 = torch.randint(1, 256, size=(n,)).tolist()
-dim2 = torch.randint(32, 512, size=(n,)).tolist()
-dim3 = torch.randint(32, 1024, size=(n,)).tolist()
-dim4 = torch.randint(32, 1024, size=(n,)).tolist()
-values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
-
-for ldb in range(32, 4096, 32):
-    # for ldb in [None]:
-    val = test_igemmlt(2, 2, 2, 2, 2, ldb)
-    if val:
-        print(val, ldb)
-    else:
-        print("nope", ldb)
-# for val in values:
-#    test_igemmlt(*val)
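
Note: for reference, the core round trip that quicktest.py exercised is sketched below. This is a minimal reconstruction based only on the deleted code above, assuming the bitsandbytes.functional API as of this commit (the "col32"/"colx"/"row" layouts for F.transform, and F.igemmlt taking int8 tiled operands with int32 accumulation) and an available CUDA device; the shapes are illustrative, not taken from the test.

    import torch
    import bitsandbytes.functional as F

    # int8 operands; the fp32 reference result is A @ B^T
    A = torch.randint(-128, 127, size=(64, 128), device="cuda").to(torch.int8)
    B = torch.randint(-128, 127, size=(32, 128), device="cuda").to(torch.int8)
    C_ref = torch.matmul(A.float(), B.t().float())

    A2, SA = F.transform(A, "col32")  # inputs into the col32 tiled layout
    B2, SB = F.transform(B, "colx")   # weights into the GPU-specific tiled layout
    C2, SC = F.transform(
        torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device="cuda"),
        "col32",
    )
    F.igemmlt(A2, B2, C2, SA, SB, SC)         # tiled int8 GEMM, int32 accumulate
    C3, _ = F.transform(C2, "row", state=SC)  # untile the result to row-major
    assert torch.allclose(C_ref, C3.float())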