summaryrefslogtreecommitdiff
path: root/quicktest.py
diff options
context:
space:
mode:
Diffstat (limited to 'quicktest.py')
-rw-r--r--quicktest.py20
1 files changed, 14 insertions, 6 deletions
diff --git a/quicktest.py b/quicktest.py
index 29d045d..0fcda64 100644
--- a/quicktest.py
+++ b/quicktest.py
@@ -14,23 +14,31 @@ def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
torch.int8
)
elif dims == 3:
- A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
- torch.int8
- )
- B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
+ A = torch.randint(
+ -128, 127, size=(dim1, dim2, dim3), device="cuda"
+ ).to(torch.int8)
+ B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
+ torch.int8
+ )
C1 = torch.matmul(A.float(), B.t().float())
A2, SA = F.transform(A, "col32")
B2, SB = F.transform(B, "colx")
if dims == 2:
C2, SC = F.transform(
- torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device="cuda"),
+ torch.zeros(
+ A.shape[0], B.shape[0], dtype=torch.int32, device="cuda"
+ ),
"col32",
)
else:
C2, SC = F.transform(
torch.zeros(
- A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device="cuda"
+ A.shape[0],
+ A.shape[1],
+ B.shape[0],
+ dtype=torch.int32,
+ device="cuda",
),
"col32",
)