From de354f7ded52bfa857089769225cdf1ee694bfd6 Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Tue, 16 Aug 2022 12:00:54 -0700
Subject: Added fused bias to matmullt.

---
 bitsandbytes/nn/modules.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

(limited to 'bitsandbytes/nn')

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 454dba5..24ecf39 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -235,9 +235,7 @@ class Linear8bitLt(nn.Linear):
         if threshold > 0.0 and not has_fp16_weights:
             self.state.use_pool = True
 
-        self.weight = Int8Params(
-            self.weight.data, has_fp16_weights=has_fp16_weights
-        )
+        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
 
     def init_8bit_state(self):
         self.state.CB = self.weight.CB
@@ -250,13 +248,12 @@ class Linear8bitLt(nn.Linear):
 
         if self.weight.CB is not None:
             self.init_8bit_state()
+        if self.bias.dtype != torch.float16:
+            self.bias.data = self.bias.data.half()
         # assert not self.state.has_fp16_weights
         # if not self.state.has_fp16_weights:
         assert self.state.CB is not None or self.state.CxB is not None
-        out = bnb.matmul(x, self.weight, state=self.state)
-
-        if self.bias is not None:
-            out += self.bias.unsqueeze(0).expand_as(out)
+        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
 
         if not self.state.has_fp16_weights and self.state.CB is not None:
             # we converted 8-bit row major to turing/ampere format in the first inference pass
--
cgit v1.2.3
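
Below is a minimal usage sketch of the fused-bias path this patch introduces: forward() now casts the bias to fp16 and hands it to bnb.matmul rather than broadcast-adding it to the output afterwards. The layer sizes, threshold value, and the final print are illustrative assumptions, not part of the patch, and the snippet assumes a CUDA build of bitsandbytes from around this commit.

import torch
import bitsandbytes as bnb

# Hypothetical example: an 8-bit linear layer with a bias. has_fp16_weights=False
# makes Int8Params quantize the weight when the module is moved to the GPU;
# threshold=6.0 enables the mixed-precision outlier decomposition.
layer = bnb.nn.Linear8bitLt(64, 32, bias=True, has_fp16_weights=False, threshold=6.0).cuda()

# With this patch, forward() converts the fp32 bias to fp16 and calls
# bnb.matmul(x, weight, bias=bias, state=state), fusing the bias add into matmullt.
x = torch.randn(4, 64, dtype=torch.float16, device="cuda")
out = layer(x)
print(out.shape, out.dtype)  # expected: torch.Size([4, 32]) torch.float16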