Diffstat (limited to 'tests/test_modules.py')
 tests/test_modules.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+), 0 deletions(-)
diff --git a/tests/test_modules.py b/tests/test_modules.py
index 7faadb8..c0b3311 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -549,3 +549,26 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc2.weight.dtype == torch.int8
assert mlp.fc1.weight.device.type == "cuda"
assert mlp.fc2.weight.device.type == "cuda"
+
+
+def test_linear8bitlt_fp32_bias():
+    # casts model to fp16 -> int8 automatically
+    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
+    assert l1.weight.dtype == torch.int8
+    assert l1.bias.dtype == torch.float32
+
+    for i in range(100):
+        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        # the first forward pass casts the fp32 bias to the fp16 input dtype
+        o1 = l1(b1)
+        assert l1.bias.dtype == torch.float16
+
+    # casts model to fp16 -> int8 automatically
+    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, bias=False).cuda()
+    assert l1.weight.dtype == torch.int8
+    assert l1.bias is None
+
+    for i in range(100):
+        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        o1 = l1(b1)
+        assert l1.bias is None
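
For reference, the behavior this test pins down can be reproduced outside the test suite. A minimal sketch, assuming a CUDA-capable GPU and that bnb.nn.Linear8bitLt behaves as the assertions above describe (variable names here are illustrative):

import torch
import bitsandbytes as bnb

# With has_fp16_weights=False, .cuda() quantizes the weights to int8,
# while the bias initially stays in fp32.
layer = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
assert layer.weight.dtype == torch.int8
assert layer.bias.dtype == torch.float32

# An fp16 input triggers the bias cast asserted in the test above.
x = torch.randn(16, 8, 32, device="cuda").half()
out = layer(x)
assert layer.bias.dtype == torch.float16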