From c771b3a75a6ebbfbfc398a028a477246b0799cf0 Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Fri, 22 Jul 2022 14:41:05 -0700
Subject: Most tests passing.

---
 tests/test_modules.py | 478 +++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 453 insertions(+), 25 deletions(-)

(limited to 'tests/test_modules.py')

diff --git a/tests/test_modules.py b/tests/test_modules.py
index a0379cb..a2c950b 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -1,42 +1,470 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
 import pytest
 import torch
+
+import math
+from itertools import product
+from torch import nn
+
 import bitsandbytes as bnb
 
+class MockArgs(object):
+    def __init__(self, initial_data):
+        for key in initial_data:
+            setattr(self, key, initial_data[key])
+
+class MLP8bit(torch.nn.Module):
+    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
+        super(MLP8bit, self).__init__()
+        self.fc1 = bnb.nn.Linear8bitLt(dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold)
+        self.fc2 = bnb.nn.Linear8bitLt(dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.fc2(x)
+        return x
+
+
+def get_args():
+    args = MockArgs([])
+    args.quant_type = 'vector'
+    args.use_8bit_training = 'full'
+    args.clip_freq = 9999
+    return args
+
+def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
+    idx = torch.isclose(a, b, rtol, atol)
+    sumval = (idx==0).sum().item()
+    if sumval > count:
+        print(f'Too many values not close: assert {sumval} < {count}')
+        torch.testing.assert_allclose(a, b, rtol, atol)
+
+class LinearFunction(torch.autograd.Function):
+
+    @staticmethod
+    def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
+        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
+        norm = math.sqrt(math.pi)/math.sqrt(2.0)
+        #std = torch.abs(x).mean()*norm
+        std = torch.std(x)
+        max1 = std*trim_value
+        x = x/max1*127
+        x = round_func(x)
+        x[x > 127] = 127
+        x[x < -127] = -127
+        x = x/127*max1
+
+        return x
+
+    def quant(x, quant_type, dim=1):
+        if quant_type == 'linear':
+            max1 = torch.abs(x).max().float()
+            xq = torch.round(x/max1*127).to(torch.int8)
+            return xq, max1
+        elif quant_type == 'vector':
+            max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
+            xq = torch.round(x/max1*127).to(torch.int8)
+            return xq, max1
+        elif quant_type == 'min-max':
+            maxA = torch.amax(x, dim=dim, keepdim=True).float()
+            minA = torch.amin(x, dim=dim, keepdim=True).float()
+            scale = (maxA-minA)/2.0
+            xq = torch.round(127*(x-minA-scale)/scale).to(torch.int8)
+            return xq, (minA.float(), scale.float())
+        else: return None
+
+    def dequant(xq, S1, S2, dtype, quant_type):
+        if quant_type == 'linear':
+            norm = S1*S2/(127*127)
+            # double cast needed to prevent overflows
+            return (xq.float()*norm).to(dtype)
+        elif quant_type == 'vector':
+            x = xq.float()
+            if len(xq.shape) == 2 and len(S1.shape) == 3: S1 = S1.squeeze(0)
+            if len(xq.shape) == 2 and len(S2.shape) == 3: S2 = S2.squeeze(0)
+            #print(x.shape, S1.shape, S2.shape)
+            if len(S1.shape) == 2:
+                x *= S1.t()/127
+            else:
+                x *= S1/127
+            x *= S2/127
+            return x.to(dtype)
+        else: return None
+
+    def dequant_min_max(xq, A, B, SA, SB, dtype):
+        offset = B.float().t().sum(0)*(SA[0]+SA[1])
+        x = xq.float()
+        if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0)
+        if len(xq.shape) == 2 and len(SA.shape) == 3: SA = SA.squeeze(0)
+        if len(SB.shape) == 2:
+            x *= SB.t()/127
+        else:
+            x *= SB/127
+        x *= SA[1]/127
+        x +=offset
+        return x.to(dtype)
+
+
+    def get_8bit_linear(x, stochastic=False):
+        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
+        max1 = torch.abs(x).max()
+        x = x/max1*127
+        x = round_func(x)/127*max1
+        #x = torch.round(x)/128*max1
+        return x
+
+    @staticmethod
+    def get_8bit_vector_wise(x, dim, stochastic=False):
+        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
+        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
+        max1[max1==0] = 1.0
+        x = (x*127)/max1
+        x = round_func(x)/127*max1
+        return x
+
+    @staticmethod
+    def round_stoachastic(x):
+        sign = torch.sign(x)
+        absx = torch.abs(x)
+        decimal = absx-torch.floor(absx)
+        rdm = torch.rand_like(decimal)
+        return sign*(torch.floor(absx)+(rdm < decimal).to(x.dtype))
+
+    @staticmethod
+    def fake_8bit_storage(w, exponent_bits):
+        code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device)
+        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
+        out = bnb.functional.dequantize_blockwise(absmax, C, code)
+        out = out.half()
+        w.copy_(out)
+        return out
+
+    @staticmethod
+    def fake_8bit_storage_quantile(w, args):
+        code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
+        #C = bnb.functional.quantize_no_absmax(code, w)
+        #out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
+        #print(out)
+        #out = out.half()
+        code /= torch.max(torch.abs(code))
+        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
+        out = bnb.functional.dequantize_blockwise(absmax, C, code)
+        out = out.half()
+        w.copy_(out)
+        return out
+
+    @staticmethod
+    def fake_8bit_storage_stoachstic(w):
+        rand = torch.rand(1024, device=w.device)
+        absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand)
+        out = bnb.functional.dequantize_blockwise(absmax, C)
+        out = out.half()
+        w.copy_(out)
+        return out
+
+    @staticmethod
+    def fake_8bit_storage_with_max(w, topk=8):
+        blocked_w = einops.rearrange(w.flatten(), '(h b) -> h b', b=256)
+        max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
+        idx = idx[:, :topk]
+        max_val = max_val[:, :topk]
+
+        mask = torch.zeros_like(blocked_w)
+        mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val))
+        mask = mask.bool()
+
+        # 1. zero out max values
+        # 2. quantize + dequantize
+        # 3. write back max values
+        # 4. copy matrix back to weight
+        values = blocked_w[mask]
+        blocked_w[mask] = 0
+
+        code = bnb.functional.create_dynamic_map()
+        code = code.to(w.device)
+        absmax, C = bnb.functional.quantize_blockwise(blocked_w.data)
+        bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w)
+
+        blocked_w[mask] = values
+
+        unblocked_w = blocked_w.flatten().view(w.shape)
+
+        w.copy_(unblocked_w)
+        return unblocked_w
+
+
+    @staticmethod
+    def forward(ctx, x, weight, bias=None, args=None):
+        if args.use_8bit_training != 'off':
+            weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
+            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
+            outputq = bnb.functional.igemm(x8, weight8.t())
+            output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
+            #if torch.rand(1) < 0.01:
+                #output32 = torch.matmul(x, weight.t())
+                #err = torch.abs(output-output32).float()
+                #relerr = err/(torch.abs(output32).float()+1e-8)
+                #print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
+        else:
+            #output = torch.matmul(x, weight.t())
+            output = torch.einsum('bsi,oi->bso', x, weight)
+
+        ctx.save_for_backward(x, weight, bias)
+        ctx.args = args
+
+        if bias is not None:
+            output += bias.unsqueeze(0).expand_as(output)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x, weight, bias = ctx.saved_tensors
+        args = ctx.args
+        stochastic = False
+        grad_input = grad_weight = grad_bias = None
+        if bias is not None and ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0)
+
+        # weight and x are already 8bit
+        # -> transform grad_output to 8-bit
+        if args.use_8bit_training == 'forward+wgrad':
+            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
+            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
+            grad_weight8 = bnb.functional.igemm(grad_output8, x8)
+            grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
+
+            #grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
+
+            grad_input = grad_output.matmul(weight)
+        elif args.use_8bit_training == 'full':
+            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
+            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
+            grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
+            bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
+            grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
+
+            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
+            weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
+            grad_input8 = bnb.functional.igemm(grad_output8, weight8)
+            grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type)
+
+        else:
+            grad_input = grad_output.matmul(weight)
+            grad_weight = torch.einsum('bsi,bso->oi', x, grad_output)
 
-@pytest.mark.parametrize("embcls", [bnb.nn.Embedding, bnb.nn.StableEmbedding], ids=['Embedding', 'StableEmbedding'])
-def test_embeddings(embcls):
-    bnb.optim.GlobalOptimManager.get_instance().initialize()
-    emb1 = torch.nn.Embedding(100, 512).cuda()
-    emb2 = embcls(100, 512).cuda()
+        return grad_input, grad_weight, grad_bias, None
 
-    adam1 = bnb.optim.Adam8bit(emb1.parameters())
-    adam2 = bnb.optim.Adam8bit(emb2.parameters())
 
+class Linear8bit(nn.Module):
+    def __init__(self, input_features, output_features, bias=True, args=None):
+        super(Linear8bit, self).__init__()
+        self.input_features = input_features
+        self.output_features = output_features
+        self.args = args
 
-    batches = torch.randint(1, 100, size=(100, 4, 32)).cuda()
 
+        self.weight = nn.Parameter(torch.empty(output_features, input_features))
+        if bias:
+            self.bias = nn.Parameter(torch.empty(output_features))
+        else:
+            self.register_parameter('bias', None)
+        torch.nn.init.xavier_uniform_(self.weight)
+        if self.bias is not None:
+            torch.nn.init.zeros_(self.bias)
+
+    def forward(self, x):
+        self.args.training = self.training
+
+        return LinearFunction.apply(x, self.weight, self.bias, self.args)
+
+
+
+def test_linear8bit():
+    l0 = torch.nn.Linear(32, 64).cuda().half()
+    l1 = bnb.nn.Linear8bit(32,64, args=get_args()).cuda().half()
+    l2 = Linear8bit(32, 64, args=get_args()).cuda().half()
+    l3 = bnb.nn.Linear8bitLt(32,64).cuda().half()
+
+    l0.weight.data = l2.weight.data.clone()
+    l0.bias.data = l2.bias.data.clone()
+
+    l1.weight.data = l2.weight.data.clone()
+    l1.bias.data = l2.bias.data.clone()
+
+    l3.weight.data = l2.weight.data.clone()
+    l3.bias.data = l2.bias.data.clone()
+
+    for i in range(100):
+        b1 = torch.randn(16, 8, 32, device='cuda').half()
+        t = torch.randn(16, 8, 64, device='cuda').half()
+        b2 = b1.clone()
+        b3 = b1.clone()
+        b0 = b1.clone()
+
+        o0 = l0(b0)
+        o1 = l1(b1)
+        o2 = l2(b2)
+        o3 = l3(b3)
+
+        assert_all_approx_close(o1, o2, atol=0.013, rtol=0.05, count=1)
+        assert_all_approx_close(o3, o2, atol=0.013, rtol=0.05, count=1)
+
+        loss0 = torch.nn.functional.mse_loss(o0, t)
+        loss1 = torch.nn.functional.mse_loss(o1, t)
+        loss2 = torch.nn.functional.mse_loss(o2, t)
+        loss3 = torch.nn.functional.mse_loss(o3, t)
+
+        loss0.backward()
+        loss1.backward()
+        loss2.backward()
+        loss3.backward()
+
+        assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
+        assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
+        assert_all_approx_close(l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)
+        assert_all_approx_close(l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)
+
+        err1 = torch.abs(l0.weight.grad-l1.weight.grad).mean().item()
+        err2 = torch.abs(l0.weight.grad-l2.weight.grad).mean().item()
+        err3 = torch.abs(l0.weight.grad-l3.weight.grad).mean().item()
+
+        assert err1*0.8 < err2
+        assert err2*0.8 < err3
+        assert err3*0.8 < err1
+
+        l0.weight.grad = None
+        l1.weight.grad = None
+        l2.weight.grad = None
+        l3.weight.grad = None
+        l0.bias.grad = None
+        l1.bias.grad = None
+        l2.bias.grad = None
+        l3.bias.grad = None
+
+
+threshold = [0.0, 3.0]
+values = threshold
+names = ['threshold_{0}'.format(vals) for vals in values]
+@pytest.mark.parametrize("threshold", values, ids=names)
+def test_linear8bitlt_inference(threshold):
+    l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold).cuda().half()
+    assert l1.weight.device.type == 'cuda'
+    assert l1.weight.dtype == torch.float16
+
+    l1.eval()
     for i in range(100):
-        batch = batches[i]
+        b1 = torch.randn(16, 8, 32, device='cuda').half()
+        o1 = l1(b1)
+        if i == 1:
+            assert l1.state.CxB is not None
+
+def test_linear8bitlt_accumulated_gradient():
+    l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32,32).cuda().half() for i in range(2)])
+    l2 = torch.nn.Sequential(*[torch.nn.Linear(32,32).cuda().half() for i in range(2)])
+    l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
+    l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
+    l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
+    l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
+    opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
+    opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)
+
+    acc_steps = 10
+
+
-        embedded1 = emb1(batch)
-        embedded2 = emb2(batch)
+    for i in range(10):
+        b1 = torch.randn(16, 8, 32, device='cuda').half()
+        o1 = l1(b1)
+        o2 = l2(b1)
+        loss1 = o1.mean()
+        loss2 = o2.mean()
+        loss1.backward()
+        loss2.backward()
+        if i == 2:
+            assert l1[0].state.CxB is not None
+            assert l1[1].state.CxB is not None
 
-        l1 = embedded1.mean()
-        l2 = embedded2.mean()
 
+        if i > 0 and i % acc_steps == 0:
+            opt1.step()
+            opt1.zero_grad(True)
+            opt2.step()
+            opt2.zero_grad(True)
+            assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2)
+            assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2)
+            # we do this copy because otherwise we have small divergences over time that add up
+            l1[0].weight.data.copy_(l2[0].weight.data)
+            l1[1].weight.data.copy_(l2[1].weight.data)
+        else:
+            torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad)
+            torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad)
 
-        l1.backward()
-        l2.backward()
 
-        adam1.step()
-        adam2.step()
 
+threshold = [0.0, 2.0]
+values = threshold
+names = ['threshold_{0}'.format(vals) for vals in values]
+@pytest.mark.parametrize("threshold", values, ids=names)
+def test_linear8bitlt_no_fp16_weights(threshold):
+    l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold, has_fp16_weights=False).cuda().half()
+    assert l1.weight.dtype == torch.int8
 
-        adam1.zero_grad()
-        adam2.zero_grad()
 
+    l1.eval()
+    for i in range(100):
+        b1 = torch.randn(16, 8, 32, device='cuda').half()
+        o1 = l1(b1)
+        assert o1.dtype == torch.float16
+
+    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda()
+    assert mlp.fc1.weight.dtype == torch.int8
+    assert mlp.fc2.weight.dtype == torch.int8
 
-        assert adam1.state[emb1.weight]['state1'].dtype == torch.uint8
-        assert adam2.state[emb2.weight]['state1'].dtype == torch.float32
 
+    for i in range(100):
+        b1 = torch.randn(16, 8, 32, device='cuda').half()
+        o1 = mlp(b1)
+        assert o1.dtype == torch.float16
+        if threshold > 0: assert mlp.fc1.state.idx is not None
+        if threshold > 0: assert mlp.fc2.state.idx is not None
+    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
+    assert mlp.fc1.weight.dtype == torch.int8
+    assert mlp.fc2.weight.dtype == torch.int8
+
+    for i in range(100):
+        b1 = torch.randn(16, 8, 32, device='cuda').half()
+        o1 = mlp(b1)
+        assert o1.dtype == torch.float16
+        if threshold > 0: assert mlp.fc1.state.idx is not None
+        if threshold > 0: assert mlp.fc2.state.idx is not None
+    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
+
+    for i in range(100):
+        b1 = torch.randn(16, 8, 32, device='cuda').half()
+        o1 = mlp(b1)
+        assert o1.dtype == torch.float16
+        if threshold > 0: assert mlp.fc1.state.idx is not None
+        if threshold > 0: assert mlp.fc2.state.idx is not None
+    assert mlp.fc1.weight.dtype == torch.int8
+    assert mlp.fc2.weight.dtype == torch.int8
+
+
+    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to('cuda')
+
+    for i in range(100):
+        b1 = torch.randn(16, 8, 32, device='cuda').half()
+        o1 = mlp(b1)
+        assert o1.dtype == torch.float16
+        if threshold > 0: assert mlp.fc1.state.idx is not None
+        if threshold > 0: assert mlp.fc2.state.idx is not None
+    assert mlp.fc1.weight.dtype == torch.int8
+    assert mlp.fc2.weight.dtype == torch.int8
+    assert mlp.fc1.weight.device.type == 'cuda'
+    assert mlp.fc2.weight.device.type == 'cuda'
+
+    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(torch.float16).to('cuda')
+
+    for i in range(100):
+        b1 = torch.randn(16, 8, 32, device='cuda').half()
+        o1 = mlp(b1)
+        assert o1.dtype == torch.float16
+        if threshold > 0: assert mlp.fc1.state.idx is not None
+        if threshold > 0: assert mlp.fc2.state.idx is not None
+    assert mlp.fc1.weight.dtype == torch.int8
+    assert mlp.fc2.weight.dtype == torch.int8
+    assert mlp.fc1.weight.device.type == 'cuda'
+    assert mlp.fc2.weight.device.type == 'cuda'
-- 
cgit v1.2.3
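For context, the vector-wise (per-row absmax) int8 scheme that the `LinearFunction.quant`/`dequant` helpers in this patch implement can be reproduced with plain PyTorch. The snippet below is a minimal standalone sketch, not part of the commit: it assumes only a recent PyTorch install, runs on CPU, and stands in for `bnb.functional.igemm` with a float32 matmul (exact here, because every accumulated int8 product stays far below 2**24).

```python
# Standalone sketch (not from the patch): per-row absmax int8 quantization as in
# LinearFunction.quant/dequant with quant_type='vector'.
import torch

def quant_vector(x, dim=1):
    # scale each row by its absolute maximum and round to int8
    max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
    xq = torch.round(x / max1 * 127).to(torch.int8)
    return xq, max1

def dequant_vector(cq, s_a, s_b):
    # undo both scales: C ~= (Aq @ Bq.t()) * (s_a/127) * (s_b/127)
    return cq * (s_a / 127) * (s_b / 127)

A = torch.randn(8, 32)
B = torch.randn(64, 32)            # weight-style layout: (out_features, in_features)
Aq, sA = quant_vector(A)           # sA has shape (8, 1)
Bq, sB = quant_vector(B)           # sB has shape (64, 1)
Cq = Aq.float() @ Bq.float().t()   # stands in for the int8 igemm; exact in fp32 here
C = dequant_vector(Cq, sA, sB.t()) # (8, 64), approximates A @ B.t()
err = (C - A @ B.t()).abs().mean().item()
print(f"mean abs error of the int8 round trip: {err:.4f}")
```

The tests above exercise the same idea end to end: they compare the quantized forward and backward passes against an fp16 reference within loose tolerances via `assert_all_approx_close`, rather than expecting exact equality.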