summaryrefslogtreecommitdiff
path: root/quicktest.py
blob: 29d045db2c296b462dff0cf2a111ee850314004a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from itertools import product

import torch

import bitsandbytes as bnb
import bitsandbytes.functional as F


def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
    """Compare ``F.igemmlt`` against a float32 reference matmul on random int8 data.

    Runs 25 trials. Each trial draws an int8 ``A`` of shape ``(dim1, dim3)``
    (``dims == 2``) or ``(dim1, dim2, dim3)`` (``dims == 3``) and an int8 ``B``
    of shape ``(dim4, dim3)``. It computes the reference ``A @ B.T`` in
    float32, then runs ``F.igemmlt`` on operands converted to the "col32"
    (A, output) and "colx" (B) layouts. Finally it converts the output back
    to "row" and compares the two results with ``torch.allclose``.

    Args:
        dim1, dim2, dim3, dim4: Problem sizes as described above; ``dim2`` is
            only used when ``dims == 3``.
        dims: Rank of ``A``; must be 2 or 3.
        ldb: Currently unused. NOTE(review): presumably meant to be forwarded
            as a leading-dimension argument (the commented-out transposed
            variant below passes ``ld=[0]`` to ``F.transform``) — confirm.

    Returns:
        True if every trial matched the reference, False otherwise.

    Raises:
        ValueError: if ``dims`` is not 2 or 3.
    """
    if dims not in (2, 3):
        raise ValueError(f"dims must be 2 or 3, got {dims!r}")

    k = 25
    all_close = True
    for _ in range(k):
        a_shape = (dim1, dim3) if dims == 2 else (dim1, dim2, dim3)
        # torch.randint's upper bound is exclusive; 128 is required to sample
        # the full int8 range [-128, 127].
        A = torch.randint(-128, 128, size=a_shape, device="cuda").to(torch.int8)
        B = torch.randint(-128, 128, size=(dim4, dim3), device="cuda").to(torch.int8)
        # Reference result: contract A's last dim with B's rows, in float32.
        C1 = torch.matmul(A.float(), B.t().float())

        A2, SA = F.transform(A, "col32")
        B2, SB = F.transform(B, "colx")
        # int32 output buffer with A's leading dims and one column per row of
        # B; the same expression covers both the 2D and the 3D case.
        C2, SC = F.transform(
            torch.zeros(*A.shape[:-1], B.shape[0], dtype=torch.int32, device="cuda"),
            "col32",
        )
        F.igemmlt(A2, B2, C2, SA, SB, SC)
        C3, S = F.transform(C2, "row", state=SC)

        allclose = torch.allclose(C1, C3.float())
        all_close = all_close and allclose
        # Debug output: dump the matrices of trials that matched.
        if allclose:
            print(C1)
            print(C2)
            print(C3)

        ## transposed
        # A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
        # if dims == 2:
        #    B = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
        #    C1 = torch.matmul(A.float(), B.float().t())
        # elif dims == 3:
        #    B = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
        #    C1 = torch.matmul(B.float(), A.t().float())
        #    C1 = C1.permute([2, 0, 1])

        # A2, SA = F.transform(A, 'col32')
        # B2, SB = F.transform(B, 'colx')
        # if dims == 2:
        #    C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
        # else:
        #    C2 = torch.zeros(A.shape[0], B.shape[0], B.shape[1], dtype=torch.int32, device='cuda')
        #    state = (C2.shape, 'row', A.shape[0])
        #    C2, SC = F.transform(C2, 'col32', state=state)
        # F.igemmlt(A2, B2, C2, SA, SB, SC)
        # C3, S = F.transform(C2, 'row', state=SC, ld=[0])
        # torch.testing.assert_allclose(C1, C3.float())

        ## weight update
        # if dims == 3:
        #    A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
        #    B = torch.randint(-128, 127, size=(dim1, dim2, dim4), device='cuda').to(torch.int8)
        #    C1 = torch.matmul(B.view(-1, B.shape[-1]).t().float(), A.view(-1, A.shape[-1]).float())

        #    A2, SA = F.transform(A.view(-1, A.shape[-1]).t().contiguous(), 'colx')
        #    B2, SB = F.transform(B.view(-1, B.shape[-1]).t().contiguous(), 'col32')
        #    C2 = torch.zeros(B.shape[-1], A.shape[-1], dtype=torch.int32, device='cuda')
        #    C2, SC = F.transform(C2, 'col32')
        #    F.igemmlt(B2, A2, C2, SB, SA, SC)
        #    C3, S = F.transform(C2, 'row', state=SC)
        #    torch.testing.assert_allclose(C1, C3.float())

    return all_close


# Shape configurations for the randomized sweep (disabled at the bottom).
dims = (2, 3)
ldb = [0]

n = 2
dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32, 1024, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))

# Sweep candidate leading dimensions (multiples of 32) on a tiny 2D problem
# and report which ones produce a truthy result. A distinct loop variable is
# used so the module-level ``ldb`` configuration list above is not clobbered.
for ld in range(32, 4096, 32):
    # Alternative sweep: for ld in [None]:
    val = test_igemmlt(2, 2, 2, 2, 2, ld)
    if val:
        print(val, ld)
    else:
        print("nope", ld)

# Randomized sweep over every shape combination in ``values``; uncomment to run.
# for val in values:
#     test_igemmlt(*val)