summaryrefslogtreecommitdiff
path: root/bitsandbytes
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes')
-rw-r--r--bitsandbytes/nn/modules.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 24ecf39..b222f54 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -248,10 +248,10 @@ class Linear8bitLt(nn.Linear):
if self.weight.CB is not None:
self.init_8bit_state()
- if self.bias.dtype != torch.float16:
+
+ # weights are cast automatically as Int8Params, but the bias has to be cast manually
+ if self.bias is not None and self.bias.dtype != torch.float16:
self.bias.data = self.bias.data.half()
- # assert not self.state.has_fp16_weights
- # if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)