author    Tim Dettmers <tim.dettmers@gmail.com>    2022-08-17 03:45:57 -0700
committer Tim Dettmers <tim.dettmers@gmail.com>    2022-08-17 03:45:57 -0700
commit    9d60b3c5279641ba936facd710c722ebe52fcf40 (patch)
tree      afc62a7ae35224acba5a03150623ef5a82830599
parent    b00cc9137fce41359318741106df92747aa14796 (diff)
Fixed a bug in Linear8bitLt when the bias is None.
-rw-r--r--   bitsandbytes/nn/modules.py    6
-rw-r--r--   setup.py                      2
-rw-r--r--   tests/test_modules.py        23
3 files changed, 27 insertions, 4 deletions
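
For context, a minimal reproduction sketch of the failure this commit addresses (assuming a CUDA device and the previous 0.32.0 release; the layer sizes and input shape are arbitrary): a layer built with bias=False leaves self.bias as None, so the old check self.bias.dtype != torch.float16 dereferenced None on the first forward pass.

import torch
import bitsandbytes as bnb

# A bias-free 8-bit linear layer, as exercised by the new test added below.
layer = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, bias=False).cuda()
x = torch.randn(16, 32, device="cuda").half()

# On 0.32.0 this raised AttributeError ('NoneType' object has no attribute 'dtype');
# with the None guard added in this commit it returns a half-precision output.
out = layer(x)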
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 24ecf39..b222f54 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -248,10 +248,10 @@ class Linear8bitLt(nn.Linear):
         if self.weight.CB is not None:
             self.init_8bit_state()
-        if self.bias.dtype != torch.float16:
+
+        # weights are cast automatically as Int8Params, but the bias has to be cast manually
+        if self.bias is not None and self.bias.dtype != torch.float16:
             self.bias.data = self.bias.data.half()
-        # assert not self.state.has_fp16_weights
-        # if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None
         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
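
The rationale behind the guard, per the new comment: the weight is an Int8Params wrapper and is cast when the module moves to the GPU, while the bias remains a plain fp32 nn.Parameter and is cast to fp16 lazily on the first forward pass, or skipped entirely when there is no bias. A standalone sketch of that None-safe lazy-cast pattern (illustrative only, the class name is hypothetical and this is not bitsandbytes code):

import torch
import torch.nn as nn

class LazyHalfBias(nn.Module):
    """Toy module showing the guard above: cast an fp32 bias to fp16 once,
    and skip the cast entirely when the module was built without a bias."""

    def __init__(self, out_features, bias=True):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None

    def forward(self, x):
        # None-safe lazy cast: only touch the bias if it exists and is not fp16 yet.
        if self.bias is not None and self.bias.dtype != torch.float16:
            self.bias.data = self.bias.data.half()
        return x if self.bias is None else x + self.bias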
diff --git a/setup.py b/setup.py
index 2b25720..ef33f8a 100644
--- a/setup.py
+++ b/setup.py
@@ -18,7 +18,7 @@ def read(fname):
 setup(
     name=f"bitsandbytes",
-    version=f"0.32.0",
+    version=f"0.32.1",
     author="Tim Dettmers",
     author_email="dettmers@cs.washington.edu",
     description="8-bit optimizers and matrix multiplication routines.",
diff --git a/tests/test_modules.py b/tests/test_modules.py
index 7faadb8..c0b3311 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -549,3 +549,26 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.dtype == torch.int8
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"
+
+
+def test_linear8bitlt_fp32_bias():
+    # casts model to fp16 -> int8 automatically
+    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
+    assert l1.weight.dtype == torch.int8
+    assert l1.bias.dtype == torch.float32
+
+    for i in range(100):
+        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        # casts bias to fp16 on the first forward pass
+        o1 = l1(b1)
+        assert l1.bias.dtype == torch.float16
+
+    # casts model to fp16 -> int8 automatically
+    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, bias=False).cuda()
+    assert l1.weight.dtype == torch.int8
+    assert l1.bias is None
+
+    for i in range(100):
+        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        o1 = l1(b1)
+        assert l1.bias is None
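
A quick way to observe the behavior the new test checks (assuming a CUDA-capable install of this release; shapes mirror the test but are otherwise arbitrary): the bias starts in fp32 and is cast to fp16 by the first forward pass, while a layer built with bias=False keeps bias as None throughout. The test itself can be run with pytest, e.g. pytest tests/test_modules.py -k fp32_bias.

import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
print(layer.bias.dtype)   # torch.float32 -- the bias is untouched until the first forward
out = layer(torch.randn(16, 8, 32, device="cuda").half())
print(layer.bias.dtype)   # torch.float16 -- cast lazily inside forward()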