diff options
Diffstat (limited to 'bitsandbytes/nn')
-rw-r--r-- | bitsandbytes/nn/modules.py | 11 |
1 files changed, 4 insertions, 7 deletions
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 454dba5..24ecf39 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -235,9 +235,7 @@ class Linear8bitLt(nn.Linear): if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True - self.weight = Int8Params( - self.weight.data, has_fp16_weights=has_fp16_weights - ) + self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights) def init_8bit_state(self): self.state.CB = self.weight.CB @@ -250,13 +248,12 @@ class Linear8bitLt(nn.Linear): if self.weight.CB is not None: self.init_8bit_state() + if self.bias.dtype != torch.float16: + self.bias.data = self.bias.data.half() # assert not self.state.has_fp16_weights # if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None - out = bnb.matmul(x, self.weight, state=self.state) - - if self.bias is not None: - out += self.bias.unsqueeze(0).expand_as(out) + out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) if not self.state.has_fp16_weights and self.state.CB is not None: # we converted 8-bit row major to turing/ampere format in the first inference pass |