From ea7c14f8ef64924f2d0ff80df3cdabf2c7299848 Mon Sep 17 00:00:00 2001 From: Titus von Koeller Date: Mon, 1 Aug 2022 09:32:47 -0700 Subject: reran black with linelength 80 for greater readability --- tests/test_modules.py | 50 +++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 39 insertions(+), 11 deletions(-) (limited to 'tests/test_modules.py') diff --git a/tests/test_modules.py b/tests/test_modules.py index 6b8d641..7faadb8 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -48,7 +48,9 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10): class LinearFunction(torch.autograd.Function): @staticmethod def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0): - round_func = LinearFunction.round_stoachastic if stochastic else torch.round + round_func = ( + LinearFunction.round_stoachastic if stochastic else torch.round + ) norm = math.sqrt(math.pi) / math.sqrt(2.0) # std = torch.abs(x).mean()*norm std = torch.std(x) @@ -116,7 +118,9 @@ class LinearFunction(torch.autograd.Function): return x.to(dtype) def get_8bit_linear(x, stochastic=False): - round_func = LinearFunction.round_stoachastic if stochastic else torch.round + round_func = ( + LinearFunction.round_stoachastic if stochastic else torch.round + ) max1 = torch.abs(x).max() x = x / max1 * 127 x = round_func(x) / 127 * max1 @@ -125,7 +129,9 @@ class LinearFunction(torch.autograd.Function): @staticmethod def get_8bit_vector_wise(x, dim, stochastic=False): - round_func = LinearFunction.round_stoachastic if stochastic else torch.round + round_func = ( + LinearFunction.round_stoachastic if stochastic else torch.round + ) max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) max1[max1 == 0] = 1.0 x = (x * 127) / max1 @@ -209,7 +215,9 @@ class LinearFunction(torch.autograd.Function): weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2) outputq = bnb.functional.igemm(x8, weight8.t()) - output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type) + output = LinearFunction.dequant( + outputq, S1, S2, x.dtype, args.quant_type + ) # if torch.rand(1) < 0.01: # output32 = torch.matmul(x, weight.t()) # err = torch.abs(output-output32).float() @@ -261,7 +269,9 @@ class LinearFunction(torch.autograd.Function): grad_weight8, S1, S2, grad_output.dtype, args.quant_type ) - grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2) + grad_output8, S1 = LinearFunction.quant( + grad_output, args.quant_type, dim=2 + ) weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0) grad_input8 = bnb.functional.igemm(grad_output8, weight8) grad_input = LinearFunction.dequant( @@ -338,8 +348,12 @@ def test_linear8bit(): loss2.backward() loss3.backward() - assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2) - assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2) + assert_all_approx_close( + l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2 + ) + assert_all_approx_close( + l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2 + ) assert_all_approx_close( l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2 ) @@ -388,7 +402,9 @@ def test_linear8bitlt_accumulated_gradient(): l1 = torch.nn.Sequential( *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)] ) - l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]) + l2 = torch.nn.Sequential( + *[torch.nn.Linear(32, 32).cuda().half() for i in range(2)] + ) l2[0].weight = torch.nn.Parameter(l1[0].weight.clone()) l2[0].bias = torch.nn.Parameter(l1[0].bias.clone()) l2[1].weight = torch.nn.Parameter(l1[1].weight.clone()) @@ -462,7 +478,11 @@ def test_linear8bitlt_no_fp16_weights(threshold): if threshold > 0: assert mlp.fc2.state.idx is not None - mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half() + mlp = ( + MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False) + .cuda() + .half() + ) assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 @@ -475,7 +495,11 @@ def test_linear8bitlt_no_fp16_weights(threshold): if threshold > 0: assert mlp.fc2.state.idx is not None - mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda() + mlp = ( + MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False) + .half() + .cuda() + ) for i in range(100): b1 = torch.randn(16, 8, 32, device="cuda").half() @@ -488,7 +512,11 @@ def test_linear8bitlt_no_fp16_weights(threshold): assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 - mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to("cuda") + mlp = ( + MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False) + .half() + .to("cuda") + ) for i in range(100): b1 = torch.randn(16, 8, 32, device="cuda").half() -- cgit v1.2.3