# tests/test_modules.py
import math

import einops
import pytest
import torch

from itertools import product
from torch import nn

import bitsandbytes as bnb

class MockArgs(object):
    def __init__(self, initial_data):
        for key in initial_data:
            setattr(self, key, initial_data[key])

class MLP8bit(torch.nn.Module):
    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
        super(MLP8bit, self).__init__()
        self.fc1 = bnb.nn.Linear8bitLt(dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold)
        self.fc2 = bnb.nn.Linear8bitLt(dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def get_args():
    args = MockArgs({})
    args.quant_type = 'vector'
    args.use_8bit_training = 'full'
    args.clip_freq = 9999
    return args

def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
    idx = torch.isclose(a, b, rtol=rtol, atol=atol)
    sumval = (idx == 0).sum().item()
    if sumval > count:
        print(f'Too many values not close: assert {sumval} < {count}')
        torch.testing.assert_allclose(a, b, rtol, atol)
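
# Minimal usage sketch (illustrative, not part of the original suite): up to
# `count` elements may violate the tolerance before the strict check fires.
def _demo_assert_all_approx_close():
    a = torch.zeros(100)
    b = a.clone()
    b[0] = 1.0  # a single outlier element
    assert_all_approx_close(a, b, atol=1e-3, rtol=0, count=2)  # 1 mismatch, not > 2 -> passes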

class LinearFunction(torch.autograd.Function):

    @staticmethod
    def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
        round_func = LinearFunction.round_stochastic if stochastic else torch.round
        norm = math.sqrt(math.pi)/math.sqrt(2.0)
        #std = torch.abs(x).mean()*norm
        std = torch.std(x)
        max1 = std*trim_value
        x = x/max1*127
        x = round_func(x)
        x[x > 127] = 127
        x[x < -127] = -127
        x = x/127*max1

        return x

    @staticmethod
    def quant(x, quant_type, dim=1):
        if quant_type == 'linear':
            max1 = torch.abs(x).max().float()
            xq = torch.round(x/max1*127).to(torch.int8)
            return xq, max1
        elif quant_type == 'vector':
            max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
            xq = torch.round(x/max1*127).to(torch.int8)
            return xq, max1
        elif quant_type == 'min-max':
            maxA = torch.amax(x, dim=dim, keepdim=True).float()
            minA = torch.amin(x, dim=dim, keepdim=True).float()
            scale = (maxA-minA)/2.0
            xq = torch.round(127*(x-minA-scale)/scale).to(torch.int8)
            return xq, (minA.float(), scale.float())
        else: return None

    @staticmethod
    def dequant(xq, S1, S2, dtype, quant_type):
        if quant_type == 'linear':
            norm = S1*S2/(127*127)
            # double cast needed to prevent overflows
            return (xq.float()*norm).to(dtype)
        elif quant_type == 'vector':
            x = xq.float()
            if len(xq.shape) == 2 and len(S1.shape) == 3: S1 = S1.squeeze(0)
            if len(xq.shape) == 2 and len(S2.shape) == 3: S2 = S2.squeeze(0)
            #print(x.shape, S1.shape, S2.shape)
            if len(S1.shape) == 2:
                x *= S1.t()/127
            else:
                x *= S1/127
            x *= S2/127
            return x.to(dtype)
        else: return None

    @staticmethod
    def dequant_min_max(xq, A, B, SA, SB, dtype):
        offset = B.float().t().sum(0)*(SA[0]+SA[1])
        x = xq.float()
        if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0)
        if len(xq.shape) == 2 and len(SA.shape) == 3: SA = SA.squeeze(0)
        if len(SB.shape) == 2:
            x *= SB.t()/127
        else:
            x *= SB/127
        x *= SA[1]/127
        x += offset
        return x.to(dtype)


    @staticmethod
    def get_8bit_linear(x, stochastic=False):
        round_func = LinearFunction.round_stochastic if stochastic else torch.round
        max1 = torch.abs(x).max()
        x = x/max1*127
        x = round_func(x)/127*max1
        #x = torch.round(x)/128*max1
        return x

    @staticmethod
    def get_8bit_vector_wise(x, dim, stochastic=False):
        round_func = LinearFunction.round_stochastic if stochastic else torch.round
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        max1[max1==0] = 1.0
        x = (x*127)/max1
        x = round_func(x)/127*max1
        return x

    @staticmethod
    def round_stochastic(x):
        sign = torch.sign(x)
        absx = torch.abs(x)
        decimal = absx-torch.floor(absx)
        rdm = torch.rand_like(decimal)
        return sign*(torch.floor(absx)+(rdm < decimal).to(x.dtype))
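
    # Illustrative note (assumption, not in the original file): this rounding
    # is unbiased, E[round_stochastic(x)] == x; e.g. for x = 0.3 it returns
    # 1 with probability 0.3 and 0 with probability 0.7.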

    @staticmethod
    def fake_8bit_storage(w, exponent_bits):
        code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_quantile(w, args):
        code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
        #C = bnb.functional.quantize_no_absmax(code, w)
        #out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
        #print(out)
        #out = out.half()
        code /= torch.max(torch.abs(code))
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_stochastic(w):
        rand = torch.rand(1024, device=w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand)
        out = bnb.functional.dequantize_blockwise(absmax, C)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_with_max(w, topk=8):
        blocked_w = einops.rearrange(w.flatten(), '(h b) -> h b', b=256)
        max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
        idx = idx[:, :topk]
        max_val = max_val[:, :topk]

        mask = torch.zeros_like(blocked_w)
        mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val))
        mask = mask.bool()

        # 1. zero out max values
        # 2. quantize + dequantize
        # 3. write back max values
        # 4. copy matrix back to weight

        values = blocked_w[mask]
        blocked_w[mask] = 0

        code = bnb.functional.create_dynamic_map()
        code = code.to(w.device)
        # note: `code` is not passed below, so the default quantization map is used
        absmax, C = bnb.functional.quantize_blockwise(blocked_w.data)
        bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w)

        blocked_w[mask] = values

        unblocked_w = blocked_w.flatten().view(w.shape)

        w.copy_(unblocked_w)
        return unblocked_w


    @staticmethod
    def forward(ctx, x, weight, bias=None, args=None):
        if args.use_8bit_training != 'off':
            weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
            outputq = bnb.functional.igemm(x8, weight8.t())
            output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
            #if torch.rand(1) < 0.01:
                #output32 = torch.matmul(x, weight.t())
                #err = torch.abs(output-output32).float()
                #relerr = err/(torch.abs(output32).float()+1e-8)
                #print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
        else:
            #output = torch.matmul(x, weight.t())
            output = torch.einsum('bsi,oi->bso', x, weight)

        ctx.save_for_backward(x, weight, bias)
        ctx.args = args

        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        x, weight, bias = ctx.saved_tensors
        args = ctx.args
        stochastic = False
        grad_input = grad_weight = grad_bias = None
        if bias is not None and ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0)

        # weight and x are already 8bit
        # -> transform grad_output to 8-bit
        if args.use_8bit_training == 'forward+wgrad':
            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = bnb.functional.igemm(grad_output8, x8)
            grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)

            #grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)

            grad_input = grad_output.matmul(weight)
        elif args.use_8bit_training == 'full':
            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
            bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
            grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)

            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
            weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
            grad_input8 = bnb.functional.igemm(grad_output8, weight8)
            grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type)

        else:
            grad_input = grad_output.matmul(weight)
            grad_weight = torch.einsum('bsi,bso->oi', x, grad_output)

        return grad_input, grad_weight, grad_bias, None
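
# Round-trip sketch (illustrative, not part of the original tests) for the
# 'vector' quantization scheme used above: quantize both operands to int8,
# multiply with igemm, then dequantize with the saved scales. Requires CUDA,
# like the tests below; the loose error bound is an assumption.
def _demo_vector_quant_roundtrip():
    x = torch.randn(4, 8, 32, device='cuda').half()
    w = torch.randn(64, 32, device='cuda').half()
    x8, Sx = LinearFunction.quant(x, 'vector', dim=2)
    w8, Sw = LinearFunction.quant(w, 'vector', dim=1)
    out8 = bnb.functional.igemm(x8, w8.t())
    out = LinearFunction.dequant(out8, Sw, Sx, x.dtype, 'vector')
    ref = torch.einsum('bsi,oi->bso', x, w)
    assert torch.abs(out.float() - ref.float()).mean() < 0.5  # loose sanity bound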

class Linear8bit(nn.Module):
    def __init__(self, input_features, output_features, bias=True, args=None):
        super(Linear8bit, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.args = args

        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter('bias', None)

        torch.nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, x):
        self.args.training = self.training

        return LinearFunction.apply(x, self.weight, self.bias, self.args)



def test_linear8bit():
    l0 = torch.nn.Linear(32, 64).cuda().half()
    l1 = bnb.nn.Linear8bit(32, 64, args=get_args()).cuda().half()
    l2 = Linear8bit(32, 64, args=get_args()).cuda().half()
    l3 = bnb.nn.Linear8bitLt(32, 64).cuda().half()

    l0.weight.data = l2.weight.data.clone()
    l0.bias.data = l2.bias.data.clone()

    l1.weight.data = l2.weight.data.clone()
    l1.bias.data = l2.bias.data.clone()

    l3.weight.data = l2.weight.data.clone()
    l3.bias.data = l2.bias.data.clone()

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        t = torch.randn(16, 8, 64, device='cuda').half()
        b2 = b1.clone()
        b3 = b1.clone()
        b0 = b1.clone()

        o0 = l0(b0)
        o1 = l1(b1)
        o2 = l2(b2)
        o3 = l3(b3)

        assert_all_approx_close(o1, o2, atol=0.013, rtol=0.05, count=1)
        assert_all_approx_close(o3, o2, atol=0.013, rtol=0.05, count=1)

        loss0 = torch.nn.functional.mse_loss(o0, t)
        loss1 = torch.nn.functional.mse_loss(o1, t)
        loss2 = torch.nn.functional.mse_loss(o2, t)
        loss3 = torch.nn.functional.mse_loss(o3, t)

        loss0.backward()
        loss1.backward()
        loss2.backward()
        loss3.backward()

        assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
        assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
        assert_all_approx_close(l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)
        assert_all_approx_close(l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)

        err1 = torch.abs(l0.weight.grad-l1.weight.grad).mean().item()
        err2 = torch.abs(l0.weight.grad-l2.weight.grad).mean().item()
        err3 = torch.abs(l0.weight.grad-l3.weight.grad).mean().item()

        assert err1*0.8 < err2
        assert err2*0.8 < err3
        assert err3*0.8 < err1

        l0.weight.grad = None
        l1.weight.grad = None
        l2.weight.grad = None
        l3.weight.grad = None
        l0.bias.grad = None
        l1.bias.grad = None
        l2.bias.grad = None
        l3.bias.grad = None


threshold = [0.0, 3.0]
values = threshold
names = [f'threshold_{v}' for v in values]
@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_inference(threshold):
    l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
    assert l1.weight.device.type == 'cuda'
    assert l1.weight.dtype == torch.float16

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = l1(b1)
        if i == 1:
            assert l1.state.CxB is not None
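
# Background note (summary of the LLM.int8() scheme this layer implements):
# with threshold > 0, activation outliers above the threshold stay in fp16
# while the rest take the int8 path; state.CxB caches the quantized weight
# after the first forward pass, which is what the assert above checks.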

def test_linear8bitlt_accumulated_gradient():
    l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
    l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
    l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
    l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
    l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
    l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
    opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
    opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)

    acc_steps = 10

    # run past acc_steps so the accumulation branch below actually executes
    for i in range(acc_steps * 2):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = l1(b1)
        o2 = l2(b1)
        loss1 = o1.mean()
        loss2 = o2.mean()
        loss1.backward()
        loss2.backward()
        if i == 2:
            assert l1[0].state.CxB is not None
            assert l1[1].state.CxB is not None

        if i > 0 and i % acc_steps == 0:
            opt1.step()
            opt1.zero_grad(True)
            opt2.step()
            opt2.zero_grad(True)
            assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2)
            assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2)
            # we do this copy because otherwise we have small divergences over time that add up
            l1[0].weight.data.copy_(l2[0].weight.data)
            l1[1].weight.data.copy_(l2[1].weight.data)
        else:
            torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad)
            torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad)


threshold = [0.0, 2.0]
values = threshold
names = [f'threshold_{v}' for v in values]
@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_no_fp16_weights(threshold):
    l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
    assert l1.weight.dtype == torch.int8

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = l1(b1)
        assert o1.dtype == torch.float16

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda()
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0: assert mlp.fc1.state.idx is not None
        if threshold > 0: assert mlp.fc2.state.idx is not None

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0: assert mlp.fc1.state.idx is not None
        if threshold > 0: assert mlp.fc2.state.idx is not None

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0: assert mlp.fc1.state.idx is not None
        if threshold > 0: assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8


    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to('cuda')

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0: assert mlp.fc1.state.idx is not None
        if threshold > 0: assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == 'cuda'
    assert mlp.fc2.weight.device.type == 'cuda'

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(torch.float16).to('cuda')

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0: assert mlp.fc1.state.idx is not None
        if threshold > 0: assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == 'cuda'
    assert mlp.fc2.weight.device.type == 'cuda'