Diffstat (limited to 'tests/test_modules.py')
-rw-r--r--  tests/test_modules.py  50
1 file changed, 39 insertions(+), 11 deletions(-)
diff --git a/tests/test_modules.py b/tests/test_modules.py
index 6b8d641..7faadb8 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -48,7 +48,9 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
 class LinearFunction(torch.autograd.Function):
     @staticmethod
     def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
-        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
+        round_func = (
+            LinearFunction.round_stoachastic if stochastic else torch.round
+        )
         norm = math.sqrt(math.pi) / math.sqrt(2.0)
         # std = torch.abs(x).mean()*norm
         std = torch.std(x)
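The round_stoachastic helper selected above (the typo in the identifier is in the source and is kept as-is) lies outside this hunk. As a hedged sketch only: stochastic rounding is typically implemented as floor plus a Bernoulli draw on the fractional part, which makes the rounding unbiased in expectation. The function name below is illustrative, not the file's actual helper.

import torch

def round_stochastic(x):
    # Round down, then round up with probability equal to the
    # fractional part, so E[round_stochastic(x)] == x.
    floor = torch.floor(x)
    frac = x - floor
    return floor + (torch.rand_like(x) < frac).to(x.dtype)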
@@ -116,7 +118,9 @@ class LinearFunction(torch.autograd.Function):
         return x.to(dtype)
 
     def get_8bit_linear(x, stochastic=False):
-        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
+        round_func = (
+            LinearFunction.round_stoachastic if stochastic else torch.round
+        )
         max1 = torch.abs(x).max()
         x = x / max1 * 127
         x = round_func(x) / 127 * max1
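The context lines above spell out the whole per-tensor absmax scheme: scale by the tensor's largest magnitude so values land in [-127, 127], round, then undo the scale. A self-contained sketch of that round trip (the function names are illustrative, not the library API):

import torch

def absmax_quantize(x):
    # One scale for the whole tensor: the largest magnitude maps to 127.
    scale = torch.abs(x).max()
    x8 = torch.round(x / scale * 127).to(torch.int8)
    return x8, scale

def absmax_dequantize(x8, scale, dtype=torch.float32):
    # Invert the scaling; the rounding error is the quantization noise.
    return x8.to(dtype) / 127 * scale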
@@ -125,7 +129,9 @@ class LinearFunction(torch.autograd.Function):
 
     @staticmethod
     def get_8bit_vector_wise(x, dim, stochastic=False):
-        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
+        round_func = (
+            LinearFunction.round_stoachastic if stochastic else torch.round
+        )
         max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
         max1[max1 == 0] = 1.0
         x = (x * 127) / max1
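get_8bit_vector_wise differs from the per-tensor variant only in granularity: one absmax scale per slice along dim, with all-zero slices guarded so the division stays safe. Packaged as a standalone sketch of the same pattern:

import torch

def vectorwise_quantize(x, dim):
    # Per-slice absmax scales preserve precision when magnitudes
    # vary strongly between rows or columns.
    max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
    max1[max1 == 0] = 1.0  # guard all-zero slices against division by zero
    x8 = torch.round(x * 127 / max1).to(torch.int8)
    return x8, max1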
@@ -209,7 +215,9 @@ class LinearFunction(torch.autograd.Function):
             weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
             x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
             outputq = bnb.functional.igemm(x8, weight8.t())
-            output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
+            output = LinearFunction.dequant(
+                outputq, S1, S2, x.dtype, args.quant_type
+            )
             # if torch.rand(1) < 0.01:
             # output32 = torch.matmul(x, weight.t())
             # err = torch.abs(output-output32).float()
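The forward path quantizes both operands, multiplies in int8 via bnb.functional.igemm (which accumulates in int32), and dequantizes with both scales. A torch-only sketch of the same flow for a 2D input, substituting a plain integer matmul for igemm (an assumed simplification, not the library call):

import torch

def int8_matmul_reference(x, weight):
    # Per-tensor absmax quantization of both operands.
    sx = torch.abs(x).max()
    sw = torch.abs(weight).max()
    x8 = torch.round(x / sx * 127).to(torch.int8)
    w8 = torch.round(weight / sw * 127).to(torch.int8)
    # igemm accumulates in int32; emulate that with an int32 matmul.
    out32 = x8.to(torch.int32) @ w8.t().to(torch.int32)
    # Undo both scales to return to the inputs' value range.
    return out32.to(x.dtype) * (sx * sw) / (127 * 127)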
@@ -261,7 +269,9 @@ class LinearFunction(torch.autograd.Function):
                 grad_weight8, S1, S2, grad_output.dtype, args.quant_type
             )
 
-            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
+            grad_output8, S1 = LinearFunction.quant(
+                grad_output, args.quant_type, dim=2
+            )
             weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
             grad_input8 = bnb.functional.igemm(grad_output8, weight8)
             grad_input = LinearFunction.dequant(
@@ -338,8 +348,12 @@ def test_linear8bit():
     loss2.backward()
     loss3.backward()
 
-    assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
-    assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
+    assert_all_approx_close(
+        l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2
+    )
+    assert_all_approx_close(
+        l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2
+    )
     assert_all_approx_close(
         l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
     )
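assert_all_approx_close (its signature appears in the first hunk header) is the tolerance helper these tests lean on. Its body is not part of this diff; a plausible reconstruction, assuming it tolerates up to count out-of-tolerance elements before failing, would be:

import torch

def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
    # Tolerate up to `count` mismatched elements before raising,
    # since a handful of outliers is expected after quantization.
    close = torch.isclose(a, b, rtol=rtol, atol=atol)
    if (~close).sum().item() > count:
        torch.testing.assert_close(a, b, rtol=rtol, atol=atol)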
@@ -388,7 +402,9 @@ def test_linear8bitlt_accumulated_gradient():
     l1 = torch.nn.Sequential(
         *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
     )
-    l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
+    l2 = torch.nn.Sequential(
+        *[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]
+    )
     l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
     l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
     l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
@@ -462,7 +478,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
         if threshold > 0:
             assert mlp.fc2.state.idx is not None
 
-    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
+    mlp = (
+        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+        .cuda()
+        .half()
+    )
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
@@ -475,7 +495,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
         if threshold > 0:
             assert mlp.fc2.state.idx is not None
 
-    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
+    mlp = (
+        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+        .half()
+        .cuda()
+    )
 
     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
@@ -488,7 +512,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
 
-    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to("cuda")
+    mlp = (
+        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+        .half()
+        .to("cuda")
+    )
 
     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
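MLP8bit itself is defined elsewhere in the test file; judging from the fc1/fc2 attributes and the threshold/has_fp16_weights keywords used above, a minimal sketch of such a helper (an assumption, not the file's exact definition) is:

import torch
import bitsandbytes as bnb

class MLP8bit(torch.nn.Module):
    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
        super().__init__()
        # Two int8 linear layers, matching the fc1/fc2 attributes
        # that the assertions above inspect.
        self.fc1 = bnb.nn.Linear8bitLt(
            dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
        )
        self.fc2 = bnb.nn.Linear8bitLt(
            dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
        )

    def forward(self, x):
        return self.fc2(self.fc1(x))

The three hunks above then check that .cuda().half(), .half().cuda(), and .half().to("cuda") all leave the frozen weights in int8, i.e. the order of the device/dtype moves must not silently re-promote quantized weights.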