From bfa0e33294f2b1dc25e65a33be2397f989824298 Mon Sep 17 00:00:00 2001 From: Titus von Koeller Date: Mon, 1 Aug 2022 03:31:48 -0700 Subject: ran black and isort for coherent code formatting --- quicktest.py | 80 +++++++++++++++++++++++++++++++++++------------------------- 1 file changed, 47 insertions(+), 33 deletions(-) (limited to 'quicktest.py') diff --git a/quicktest.py b/quicktest.py index 2db6afa..29d045d 100644 --- a/quicktest.py +++ b/quicktest.py @@ -1,31 +1,45 @@ +from itertools import product + import torch + import bitsandbytes as bnb import bitsandbytes.functional as F -from itertools import product def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb): k = 25 for i in range(k): if dims == 2: - A = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8) + A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to( + torch.int8 + ) elif dims == 3: - A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8) - B = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8) + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to( + torch.int8 + ) + B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8) C1 = torch.matmul(A.float(), B.t().float()) - A2, SA = F.transform(A, 'col32') - B2, SB = F.transform(B, 'colx') + A2, SA = F.transform(A, "col32") + B2, SB = F.transform(B, "colx") if dims == 2: - C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32') + C2, SC = F.transform( + torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device="cuda"), + "col32", + ) else: - C2, SC = F.transform(torch.zeros(A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device='cuda'), 'col32') + C2, SC = F.transform( + torch.zeros( + A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device="cuda" + ), + "col32", + ) F.igemmlt(A2, B2, C2, SA, SB, SC) - C3, S = F.transform(C2, 'row', state=SC) - #torch.testing.assert_allclose(C1, C3.float()) - #print(C1) - #print(C2) - #print(C3) + C3, S = F.transform(C2, "row", state=SC) + # torch.testing.assert_allclose(C1, C3.float()) + # print(C1) + # print(C2) + # print(C3) allclose = torch.allclose(C1, C3.float()) if allclose: print(C1) @@ -33,29 +47,29 @@ def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb): print(C3) ## transposed - #A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8) - #if dims == 2: + # A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8) + # if dims == 2: # B = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8) # C1 = torch.matmul(A.float(), B.float().t()) - #elif dims == 3: + # elif dims == 3: # B = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8) # C1 = torch.matmul(B.float(), A.t().float()) # C1 = C1.permute([2, 0, 1]) - #A2, SA = F.transform(A, 'col32') - #B2, SB = F.transform(B, 'colx') - #if dims == 2: + # A2, SA = F.transform(A, 'col32') + # B2, SB = F.transform(B, 'colx') + # if dims == 2: # C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32') - #else: + # else: # C2 = torch.zeros(A.shape[0], B.shape[0], B.shape[1], dtype=torch.int32, device='cuda') # state = (C2.shape, 'row', A.shape[0]) # C2, SC = F.transform(C2, 'col32', state=state) - #F.igemmlt(A2, B2, C2, SA, SB, SC) - #C3, S = F.transform(C2, 'row', state=SC, ld=[0]) - #torch.testing.assert_allclose(C1, C3.float()) + # F.igemmlt(A2, B2, C2, SA, SB, SC) + # C3, S = F.transform(C2, 'row', state=SC, ld=[0]) + # torch.testing.assert_allclose(C1, C3.float()) ## weight update - #if dims == 3: + # if dims == 3: # A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8) # B = torch.randint(-128, 127, size=(dim1, dim2, dim4), device='cuda').to(torch.int8) # C1 = torch.matmul(B.view(-1, B.shape[-1]).t().float(), A.view(-1, A.shape[-1]).float()) @@ -73,18 +87,18 @@ dims = (2, 3) ldb = [0] n = 2 -dim1 = torch.randint(1,256, size=(n,)).tolist() -dim2 = torch.randint(32,512, size=(n,)).tolist() -dim3 = torch.randint(32,1024, size=(n,)).tolist() -dim4 = torch.randint(32,1024, size=(n,)).tolist() -values = list(product(dim1,dim2,dim3,dim4,dims, ldb)) +dim1 = torch.randint(1, 256, size=(n,)).tolist() +dim2 = torch.randint(32, 512, size=(n,)).tolist() +dim3 = torch.randint(32, 1024, size=(n,)).tolist() +dim4 = torch.randint(32, 1024, size=(n,)).tolist() +values = list(product(dim1, dim2, dim3, dim4, dims, ldb)) for ldb in range(32, 4096, 32): -#for ldb in [None]: + # for ldb in [None]: val = test_igemmlt(2, 2, 2, 2, 2, ldb) if val: print(val, ldb) else: - print('nope', ldb) -#for val in values: - #test_igemmlt(*val) + print("nope", ldb) +# for val in values: +# test_igemmlt(*val) -- cgit v1.2.3