summaryrefslogtreecommitdiff
path: root/bitsandbytes/nn
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes/nn')
-rw-r--r--bitsandbytes/nn/modules.py11
1 files changed, 4 insertions, 7 deletions
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 454dba5..24ecf39 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -235,9 +235,7 @@ class Linear8bitLt(nn.Linear):
if threshold > 0.0 and not has_fp16_weights:
self.state.use_pool = True
- self.weight = Int8Params(
- self.weight.data, has_fp16_weights=has_fp16_weights
- )
+ self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
def init_8bit_state(self):
self.state.CB = self.weight.CB
@@ -250,13 +248,12 @@ class Linear8bitLt(nn.Linear):
if self.weight.CB is not None:
self.init_8bit_state()
+ if self.bias.dtype != torch.float16:
+ self.bias.data = self.bias.data.half()
# assert not self.state.has_fp16_weights
# if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None
- out = bnb.matmul(x, self.weight, state=self.state)
-
- if self.bias is not None:
- out += self.bias.unsqueeze(0).expand_as(out)
+ out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
if not self.state.has_fp16_weights and self.state.CB is not None:
# we converted 8-bit row major to turing/ampere format in the first inference pass