from itertools import product

import torch

import bitsandbytes as bnb
import bitsandbytes.functional as F


def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
    """Compare ``F.igemmlt`` against a float ``torch.matmul`` reference.

    Runs 25 trials. Each trial draws a random int8 ``A`` (2-D or 3-D) and a
    random int8 ``B`` on the CUDA device, computes ``A @ B.T`` in floating
    point as the reference, runs the same product through
    ``F.transform`` / ``F.igemmlt``, and compares both with
    ``torch.allclose``.

    Args:
        dim1: Size of the first dimension of ``A``.
        dim2: Size of the middle dimension of ``A``; only read when
            ``dims == 3``.
        dim3: Shared inner dimension (last dimension of ``A`` and of ``B``).
        dim4: Size of the first dimension of ``B``.
        dims: Rank of ``A``; must be 2 or 3.
        ldb: Accepted so the driver below can sweep it, but not used by the
            active code path.

    Returns:
        ``True`` if every trial matched the reference, ``False`` otherwise.

    Raises:
        ValueError: If ``dims`` is neither 2 nor 3.
    """
    if dims not in (2, 3):
        # Previously this surfaced as an unbound-local error on ``A``.
        raise ValueError(f'dims must be 2 or 3, got {dims!r}')

    num_trials = 25
    a_shape = (dim1, dim3) if dims == 2 else (dim1, dim2, dim3)
    all_match = True

    for _ in range(num_trials):
        # torch.randint's upper bound is exclusive, so values lie in
        # [-128, 126].
        A = torch.randint(-128, 127, size=a_shape, device='cuda').to(torch.int8)
        B = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)

        # Floating-point reference result.
        C1 = torch.matmul(A.float(), B.t().float())

        # Re-layout the operands and a zeroed int32 output buffer into the
        # formats handed to F.igemmlt; each transform also returns a state
        # object that igemmlt consumes.
        A2, SA = F.transform(A, 'col32')
        B2, SB = F.transform(B, 'colx')
        out_shape = (*A.shape[:-1], B.shape[0])
        C2, SC = F.transform(
            torch.zeros(*out_shape, dtype=torch.int32, device='cuda'),
            'col32',
        )

        F.igemmlt(A2, B2, C2, SA, SB, SC)

        # Bring the result back to 'row' layout for comparison.
        C3, _ = F.transform(C2, 'row', state=SC)

        #torch.testing.assert_allclose(C1, C3.float())
        #print(C1)
        #print(C2)
        #print(C3)
        allclose = torch.allclose(C1, C3.float())
        if allclose:
            # NOTE(review): tensors are dumped on a *match*, not on a
            # mismatch; kept as-is -- confirm this is the intended diagnostic.
            print(C1)
            print(C2)
            print(C3)
        else:
            all_match = False

        # The two additional cases below are disabled; they are kept for
        # reference only.

        ## transposed
        #A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
        #if dims == 2:
        #    B = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
        #    C1 = torch.matmul(A.float(), B.float().t())
        #elif dims == 3:
        #    B = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
        #    C1 = torch.matmul(B.float(), A.t().float())
        #    C1 = C1.permute([2, 0, 1])

        #A2, SA = F.transform(A, 'col32')
        #B2, SB = F.transform(B, 'colx')
        #if dims == 2:
        #    C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
        #else:
        #    C2 = torch.zeros(A.shape[0], B.shape[0], B.shape[1], dtype=torch.int32, device='cuda')
        #    state = (C2.shape, 'row', A.shape[0])
        #    C2, SC = F.transform(C2, 'col32', state=state)
        #F.igemmlt(A2, B2, C2, SA, SB, SC)
        #C3, S = F.transform(C2, 'row', state=SC, ld=[0])
        #torch.testing.assert_allclose(C1, C3.float())

        ## weight update
        #if dims == 3:
        #    A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
        #    B = torch.randint(-128, 127, size=(dim1, dim2, dim4), device='cuda').to(torch.int8)
        #    C1 = torch.matmul(B.view(-1, B.shape[-1]).t().float(), A.view(-1, A.shape[-1]).float())

        #    A2, SA = F.transform(A.view(-1, A.shape[-1]).t().contiguous(), 'colx')
        #    B2, SB = F.transform(B.view(-1, B.shape[-1]).t().contiguous(), 'col32')
        #    C2 = torch.zeros(B.shape[-1], A.shape[-1], dtype=torch.int32, device='cuda')
        #    C2, SC = F.transform(C2, 'col32')
        #    F.igemmlt(B2, A2, C2, SB, SA, SC)
        #    C3, S = F.transform(C2, 'row', state=SC)
        #    torch.testing.assert_allclose(C1, C3.float())

    return all_match


# Parameter grid for the (currently disabled) full sweep at the bottom.
dims = (2, 3)
ldb = [0]

n = 2
dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32, 1024, size=(n,)).tolist()

values = list(product(dim1, dim2, dim3, dim4, dims, ldb))

# Sweep ldb over multiples of 32 on a tiny 2x2 problem and report, per value,
# whether every trial matched the reference. Note that this rebinds the
# module-level ``ldb`` defined above.
for ldb in range(32, 4096, 32):
#for ldb in [None]:
    val = test_igemmlt(2, 2, 2, 2, 2, ldb)
    if val:
        print(val, ldb)
    else:
        print('nope', ldb)

#for val in values:
    #test_igemmlt(*val)