From bfa0e33294f2b1dc25e65a33be2397f989824298 Mon Sep 17 00:00:00 2001
From: Titus von Koeller
Date: Mon, 1 Aug 2022 03:31:48 -0700
Subject: ran black and isort for coherent code formatting

---
 tests/test_modules.py | 297 +++++++++++++++++++++++++++++---------------------
 1 file changed, 175 insertions(+), 122 deletions(-)

(limited to 'tests/test_modules.py')

diff --git a/tests/test_modules.py b/tests/test_modules.py
index a2c950b..6b8d641 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -1,21 +1,27 @@
+from itertools import product
+
 import pytest
 import torch
-
-from itertools import product
 from torch import nn
 
 import bitsandbytes as bnb
 
+
 class MockArgs(object):
     def __init__(self, initial_data):
         for key in initial_data:
             setattr(self, key, initial_data[key])
 
+
 class MLP8bit(torch.nn.Module):
     def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
         super(MLP8bit, self).__init__()
-        self.fc1 = bnb.nn.Linear8bitLt(dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold)
-        self.fc2 = bnb.nn.Linear8bitLt(dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold)
+        self.fc1 = bnb.nn.Linear8bitLt(
+            dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
+        )
+        self.fc2 = bnb.nn.Linear8bitLt(
+            dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
+        )
 
     def forward(self, x):
         x = self.fc1(x)
@@ -25,108 +31,114 @@ class MLP8bit(torch.nn.Module):
 
 def get_args():
     args = MockArgs([])
-    args.quant_type = 'vector'
-    args.use_8bit_training = 'full'
+    args.quant_type = "vector"
+    args.use_8bit_training = "full"
     args.clip_freq = 9999
     return args
 
+
 def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
     idx = torch.isclose(a, b, rtol, atol)
-    sumval = (idx==0).sum().item()
+    sumval = (idx == 0).sum().item()
    if sumval > count:
-        print(f'Too many values not close: assert {sumval} < {count}')
+        print(f"Too many values not close: assert {sumval} < {count}")
         torch.testing.assert_allclose(a, b, rtol, atol)
 
-class LinearFunction(torch.autograd.Function):
 
+class LinearFunction(torch.autograd.Function):
     @staticmethod
     def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
         round_func = LinearFunction.round_stoachastic if stochastic else torch.round
-        norm = math.sqrt(math.pi)/math.sqrt(2.0)
-        #std = torch.abs(x).mean()*norm
+        norm = math.sqrt(math.pi) / math.sqrt(2.0)
+        # std = torch.abs(x).mean()*norm
         std = torch.std(x)
-        max1 = std*trim_value
-        x = x/max1*127
+        max1 = std * trim_value
+        x = x / max1 * 127
         x = round_func(x)
         x[x > 127] = 127
         x[x < -127] = -127
-        x = x/127*max1
+        x = x / 127 * max1
         return x
 
     def quant(x, quant_type, dim=1):
-        if quant_type == 'linear':
+        if quant_type == "linear":
             max1 = torch.abs(x).max().float()
-            xq = torch.round(x/max1*127).to(torch.int8)
+            xq = torch.round(x / max1 * 127).to(torch.int8)
             return xq, max1
-        elif quant_type == 'vector':
+        elif quant_type == "vector":
             max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
-            xq = torch.round(x/max1*127).to(torch.int8)
+            xq = torch.round(x / max1 * 127).to(torch.int8)
             return xq, max1
-        elif quant_type == 'min-max':
+        elif quant_type == "min-max":
             maxA = torch.amax(x, dim=dim, keepdim=True).float()
             minA = torch.amin(x, dim=dim, keepdim=True).float()
-            scale = (maxA-minA)/2.0
-            xq = torch.round(127*(x-minA-scale)/scale).to(torch.int8)
+            scale = (maxA - minA) / 2.0
+            xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8)
             return xq, (minA.float(), scale.float())
-        else: return None
+        else:
+            return None
 
     def dequant(xq, S1, S2, dtype, quant_type):
-        if quant_type == 'linear':
-            norm = S1*S2/(127*127)
+        if quant_type == "linear":
+            norm = S1 * S2 / (127 * 127)
             # double cast needed to prevent overflows
-            return (xq.float()*norm).to(dtype)
-        elif quant_type == 'vector':
+            return (xq.float() * norm).to(dtype)
+        elif quant_type == "vector":
             x = xq.float()
-            if len(xq.shape) == 2 and len(S1.shape) == 3: S1 = S1.squeeze(0)
-            if len(xq.shape) == 2 and len(S2.shape) == 3: S2 = S2.squeeze(0)
-            #print(x.shape, S1.shape, S2.shape)
+            if len(xq.shape) == 2 and len(S1.shape) == 3:
+                S1 = S1.squeeze(0)
+            if len(xq.shape) == 2 and len(S2.shape) == 3:
+                S2 = S2.squeeze(0)
+            # print(x.shape, S1.shape, S2.shape)
             if len(S1.shape) == 2:
-                x *= S1.t()/127
+                x *= S1.t() / 127
             else:
-                x *= S1/127
-            x *= S2/127
+                x *= S1 / 127
+            x *= S2 / 127
             return x.to(dtype)
-        else: return None
+        else:
+            return None
 
     def dequant_min_max(xq, A, B, SA, SB, dtype):
-        offset = B.float().t().sum(0)*(SA[0]+SA[1])
+        offset = B.float().t().sum(0) * (SA[0] + SA[1])
         x = xq.float()
-        if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0)
-        if len(xq.shape) == 2 and len(SA.shape) == 3: SA = SA.squeeze(0)
+        if len(xq.shape) == 2 and len(SB.shape) == 3:
+            SB = SB.squeeze(0)
+        if len(xq.shape) == 2 and len(SA.shape) == 3:
+            SA = SA.squeeze(0)
         if len(SB.shape) == 2:
-            x *= SB.t()/127
+            x *= SB.t() / 127
         else:
-            x *= SB/127
-        x *= SA[1]/127
-        x +=offset
+            x *= SB / 127
+        x *= SA[1] / 127
+        x += offset
         return x.to(dtype)
 
-
     def get_8bit_linear(x, stochastic=False):
         round_func = LinearFunction.round_stoachastic if stochastic else torch.round
         max1 = torch.abs(x).max()
-        x = x/max1*127
-        x = round_func(x)/127*max1
-        #x = torch.round(x)/128*max1
+        x = x / max1 * 127
+        x = round_func(x) / 127 * max1
+        # x = torch.round(x)/128*max1
         return x
 
     @staticmethod
     def get_8bit_vector_wise(x, dim, stochastic=False):
         round_func = LinearFunction.round_stoachastic if stochastic else torch.round
         max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
-        max1[max1==0] = 1.0
-        x = (x*127)/max1
-        x = round_func(x)/127*max1
+        max1[max1 == 0] = 1.0
+        x = (x * 127) / max1
+        x = round_func(x) / 127 * max1
         return x
 
     @staticmethod
     def round_stoachastic(x):
         sign = torch.sign(x)
         absx = torch.abs(x)
-        decimal = absx-torch.floor(absx)
+        decimal = absx - torch.floor(absx)
         rdm = torch.rand_like(decimal)
-        return sign*(torch.floor(absx)+(rdm < decimal).to(x.dtype))
+        return sign * (torch.floor(absx) + (rdm < decimal).to(x.dtype))
 
     @staticmethod
     def fake_8bit_storage(w, exponent_bits):
@@ -140,10 +152,10 @@ class LinearFunction(torch.autograd.Function):
     @staticmethod
     def fake_8bit_storage_quantile(w, args):
         code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
-        #C = bnb.functional.quantize_no_absmax(code, w)
-        #out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
-        #print(out)
-        #out = out.half()
+        # C = bnb.functional.quantize_no_absmax(code, w)
+        # out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
+        # print(out)
+        # out = out.half()
         code /= torch.max(torch.abs(code))
         absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
         out = bnb.functional.dequantize_blockwise(absmax, C, code)
@@ -162,7 +174,7 @@ class LinearFunction(torch.autograd.Function):
 
     @staticmethod
     def fake_8bit_storage_with_max(w, topk=8):
-        blocked_w = einops.rearrange(w.flatten(), '(h b) -> h b', b=256)
+        blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256)
         max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
         idx = idx[:, :topk]
         max_val = max_val[:, :topk]
@@ -191,22 +203,21 @@ class LinearFunction(torch.autograd.Function):
         w.copy_(unblocked_w)
         return unblocked_w
 
-
     @staticmethod
     def forward(ctx, x, weight, bias=None, args=None):
-        if args.use_8bit_training != 'off':
+        if args.use_8bit_training != "off":
             weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
             x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
             outputq = bnb.functional.igemm(x8, weight8.t())
             output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
-            #if torch.rand(1) < 0.01:
-            #output32 = torch.matmul(x, weight.t())
-            #err = torch.abs(output-output32).float()
-            #relerr = err/(torch.abs(output32).float()+1e-8)
-            #print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
+            # if torch.rand(1) < 0.01:
+            # output32 = torch.matmul(x, weight.t())
+            # err = torch.abs(output-output32).float()
+            # relerr = err/(torch.abs(output32).float()+1e-8)
+            # print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
         else:
-            #output = torch.matmul(x, weight.t())
-            output = torch.einsum('bsi,oi->bso', x, weight)
+            # output = torch.matmul(x, weight.t())
+            output = torch.einsum("bsi,oi->bso", x, weight)
 
         ctx.save_for_backward(x, weight, bias)
         ctx.args = args
@@ -221,37 +232,49 @@ class LinearFunction(torch.autograd.Function):
         args = ctx.args
         stochastic = False
         grad_input = grad_weight = grad_bias = None
-        if bias is not None and ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0)
+        if bias is not None and ctx.needs_input_grad[2]:
+            grad_bias = grad_output.sum(0)
 
         # weight and x are already 8bit
         # -> transform grad_output to 8-bit
-        if args.use_8bit_training == 'forward+wgrad':
-            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
+        if args.use_8bit_training == "forward+wgrad":
+            grad_output8, S1 = LinearFunction.quant(
+                grad_output, args.quant_type, dim=[0, 1]
+            )
             x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
             grad_weight8 = bnb.functional.igemm(grad_output8, x8)
-            grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
+            grad_weight = LinearFunction.dequant(
+                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
+            )
 
-            #grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
+            # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
 
 
             grad_input = grad_output.matmul(weight)
-        elif args.use_8bit_training == 'full':
-            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
+        elif args.use_8bit_training == "full":
+            grad_output8, S1 = LinearFunction.quant(
+                grad_output, args.quant_type, dim=[0, 1]
+            )
             x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
             grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
             bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
-            grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
+            grad_weight = LinearFunction.dequant(
+                grad_weight8, S1, S2, grad_output.dtype, args.quant_type
+            )
 
             grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
             weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
             grad_input8 = bnb.functional.igemm(grad_output8, weight8)
-            grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type)
+            grad_input = LinearFunction.dequant(
+                grad_input8, S1, S3, grad_output.dtype, args.quant_type
+            )
         else:
             grad_input = grad_output.matmul(weight)
-            grad_weight = torch.einsum('bsi,bso->oi', x, grad_output)
+            grad_weight = torch.einsum("bsi,bso->oi", x, grad_output)
 
         return grad_input, grad_weight, grad_bias, None
 
+
 class Linear8bit(nn.Module):
     def __init__(self, input_features, output_features, bias=True, args=None):
         super(Linear8bit, self).__init__()
@@ -263,7 +286,7 @@ class Linear8bit(nn.Module):
         if bias:
             self.bias = nn.Parameter(torch.empty(output_features))
         else:
-            self.register_parameter('bias', None)
+            self.register_parameter("bias", None)
 
         torch.nn.init.xavier_uniform_(self.weight)
         if self.bias is not None:
@@ -275,12 +298,11 @@ class Linear8bit(nn.Module):
         return LinearFunction.apply(x, self.weight, self.bias, self.args)
 
 
-
 def test_linear8bit():
     l0 = torch.nn.Linear(32, 64).cuda().half()
-    l1 = bnb.nn.Linear8bit(32,64, args=get_args()).cuda().half()
+    l1 = bnb.nn.Linear8bit(32, 64, args=get_args()).cuda().half()
     l2 = Linear8bit(32, 64, args=get_args()).cuda().half()
-    l3 = bnb.nn.Linear8bitLt(32,64).cuda().half()
+    l3 = bnb.nn.Linear8bitLt(32, 64).cuda().half()
 
     l0.weight.data = l2.weight.data.clone()
     l0.bias.data = l2.bias.data.clone()
@@ -292,8 +314,8 @@ def test_linear8bit():
     l3.bias.data = l2.bias.data.clone()
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device='cuda').half()
-        t = torch.randn(16, 8, 64, device='cuda').half()
+        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        t = torch.randn(16, 8, 64, device="cuda").half()
         b2 = b1.clone()
         b3 = b1.clone()
         b0 = b1.clone()
@@ -318,16 +340,20 @@ def test_linear8bit():
         assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
         assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
 
-        assert_all_approx_close(l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)
-        assert_all_approx_close(l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)
+        assert_all_approx_close(
+            l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
+        )
+        assert_all_approx_close(
+            l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
+        )
 
-        err1 = torch.abs(l0.weight.grad-l1.weight.grad).mean().item()
-        err2 = torch.abs(l0.weight.grad-l2.weight.grad).mean().item()
-        err3 = torch.abs(l0.weight.grad-l3.weight.grad).mean().item()
+        err1 = torch.abs(l0.weight.grad - l1.weight.grad).mean().item()
+        err2 = torch.abs(l0.weight.grad - l2.weight.grad).mean().item()
+        err3 = torch.abs(l0.weight.grad - l3.weight.grad).mean().item()
 
-        assert err1*0.8 < err2
-        assert err2*0.8 < err3
-        assert err3*0.8 < err1
+        assert err1 * 0.8 < err2
+        assert err2 * 0.8 < err3
+        assert err3 * 0.8 < err1
 
         l0.weight.grad = None
         l1.weight.grad = None
@@ -341,23 +367,28 @@ def test_linear8bit():
 
 threshold = [0.0, 3.0]
 values = threshold
-names = ['threshold_{0}'.format(vals) for vals in values]
+names = ["threshold_{0}".format(vals) for vals in values]
+
+
 @pytest.mark.parametrize("threshold", values, ids=names)
 def test_linear8bitlt_inference(threshold):
-    l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold).cuda().half()
-    assert l1.weight.device.type == 'cuda'
+    l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
+    assert l1.weight.device.type == "cuda"
     assert l1.weight.dtype == torch.float16
     l1.eval()
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device='cuda').half()
+        b1 = torch.randn(16, 8, 32, device="cuda").half()
         o1 = l1(b1)
         if i == 1:
             assert l1.state.CxB is not None
 
+
 def test_linear8bitlt_accumulated_gradient():
-    l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32,32).cuda().half() for i in range(2)])
-    l2 = torch.nn.Sequential(*[torch.nn.Linear(32,32).cuda().half() for i in range(2)])
+    l1 = torch.nn.Sequential(
+        *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
+    )
+    l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
     l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
     l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
     l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
     l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
@@ -367,9 +398,8 @@ def test_linear8bitlt_accumulated_gradient():
 
     acc_steps = 10
 
-
     for i in range(10):
-        b1 = torch.randn(16, 8, 32, device='cuda').half()
+        b1 = torch.randn(16, 8, 32, device="cuda").half()
         o1 = l1(b1)
         o2 = l2(b1)
         loss1 = o1.mean()
@@ -385,8 +415,12 @@ def test_linear8bitlt_accumulated_gradient():
             opt1.zero_grad(True)
             opt2.step()
             opt2.zero_grad(True)
-            assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2)
-            assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2)
+            assert_all_approx_close(
+                l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2
+            )
+            assert_all_approx_close(
+                l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2
+            )
             # we do this copy because otherwise we have small divergences over time that add up
             l1[0].weight.data.copy_(l2[0].weight.data)
             l1[1].weight.data.copy_(l2[1].weight.data)
@@ -397,15 +431,21 @@ def test_linear8bitlt_accumulated_gradient():
 
 threshold = [0.0, 2.0]
 values = threshold
-names = ['threshold_{0}'.format(vals) for vals in values]
+names = ["threshold_{0}".format(vals) for vals in values]
+
+
 @pytest.mark.parametrize("threshold", values, ids=names)
 def test_linear8bitlt_no_fp16_weights(threshold):
-    l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold, has_fp16_weights=False).cuda().half()
+    l1 = (
+        bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
+        .cuda()
+        .half()
+    )
     assert l1.weight.dtype == torch.int8
 
     l1.eval()
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device='cuda').half()
+        b1 = torch.randn(16, 8, 32, device="cuda").half()
         o1 = l1(b1)
         assert o1.dtype == torch.float16
 
@@ -414,57 +454,70 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.dtype == torch.int8
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device='cuda').half()
+        b1 = torch.randn(16, 8, 32, device="cuda").half()
         o1 = mlp(b1)
         assert o1.dtype == torch.float16
-        if threshold > 0: assert mlp.fc1.state.idx is not None
-        if threshold > 0: assert mlp.fc2.state.idx is not None
+        if threshold > 0:
+            assert mlp.fc1.state.idx is not None
+        if threshold > 0:
+            assert mlp.fc2.state.idx is not None
 
     mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device='cuda').half()
+        b1 = torch.randn(16, 8, 32, device="cuda").half()
         o1 = mlp(b1)
         assert o1.dtype == torch.float16
-        if threshold > 0: assert mlp.fc1.state.idx is not None
-        if threshold > 0: assert mlp.fc2.state.idx is not None
+        if threshold > 0:
+            assert mlp.fc1.state.idx is not None
+        if threshold > 0:
+            assert mlp.fc2.state.idx is not None
 
     mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device='cuda').half()
+        b1 = torch.randn(16, 8, 32, device="cuda").half()
         o1 = mlp(b1)
         assert o1.dtype == torch.float16
-        if threshold > 0: assert mlp.fc1.state.idx is not None
-        if threshold > 0: assert mlp.fc2.state.idx is not None
+        if threshold > 0:
+            assert mlp.fc1.state.idx is not None
+        if threshold > 0:
+            assert mlp.fc2.state.idx is not None
 
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
 
-
-    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to('cuda')
+    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to("cuda")
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device='cuda').half()
+        b1 = torch.randn(16, 8, 32, device="cuda").half()
         o1 = mlp(b1)
         assert o1.dtype == torch.float16
-        if threshold > 0: assert mlp.fc1.state.idx is not None
-        if threshold > 0: assert mlp.fc2.state.idx is not None
+        if threshold > 0:
+            assert mlp.fc1.state.idx is not None
+        if threshold > 0:
+            assert mlp.fc2.state.idx is not None
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
-    assert mlp.fc1.weight.device.type == 'cuda'
-    assert mlp.fc2.weight.device.type == 'cuda'
+    assert mlp.fc1.weight.device.type == "cuda"
+    assert mlp.fc2.weight.device.type == "cuda"
 
-    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(torch.float16).to('cuda')
+    mlp = (
+        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+        .to(torch.float16)
+        .to("cuda")
+    )
 
     for i in range(100):
-        b1 = torch.randn(16, 8, 32, device='cuda').half()
+        b1 = torch.randn(16, 8, 32, device="cuda").half()
        o1 = mlp(b1)
         assert o1.dtype == torch.float16
-        if threshold > 0: assert mlp.fc1.state.idx is not None
-        if threshold > 0: assert mlp.fc2.state.idx is not None
+        if threshold > 0:
+            assert mlp.fc1.state.idx is not None
+        if threshold > 0:
+            assert mlp.fc2.state.idx is not None
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
-    assert mlp.fc1.weight.device.type == 'cuda'
-    assert mlp.fc2.weight.device.type == 'cuda'
--
cgit v1.2.3