From 9d60b3c5279641ba936facd710c722ebe52fcf40 Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Wed, 17 Aug 2022 03:45:57 -0700
Subject: Fixed bug in Linear8bitLt when the bias is None.

---
 bitsandbytes/nn/modules.py |  6 +++---
 setup.py                   |  2 +-
 tests/test_modules.py      | 23 +++++++++++++++++++++++
 3 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 24ecf39..b222f54 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -248,10 +248,10 @@ class Linear8bitLt(nn.Linear):
         if self.weight.CB is not None:
             self.init_8bit_state()
-        if self.bias.dtype != torch.float16:
+
+        # weights are cast automatically as Int8Params, but the bias has to be cast manually
+        if self.bias is not None and self.bias.dtype != torch.float16:
             self.bias.data = self.bias.data.half()
 
-        # assert not self.state.has_fp16_weights
-        # if not self.state.has_fp16_weights:
         assert self.state.CB is not None or self.state.CxB is not None
 
         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
diff --git a/setup.py b/setup.py
index 2b25720..ef33f8a 100644
--- a/setup.py
+++ b/setup.py
@@ -18,7 +18,7 @@ def read(fname):
 
 setup(
     name=f"bitsandbytes",
-    version=f"0.32.0",
+    version=f"0.32.1",
     author="Tim Dettmers",
     author_email="dettmers@cs.washington.edu",
     description="8-bit optimizers and matrix multiplication routines.",
diff --git a/tests/test_modules.py b/tests/test_modules.py
index 7faadb8..c0b3311 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -549,3 +549,26 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.dtype == torch.int8
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"
+
+
+def test_linear8bitlt_fp32_bias():
+    # casts model to fp16 -> int8 automatically
+    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
+    assert l1.weight.dtype == torch.int8
+    assert l1.bias.dtype == torch.float32
+
+    for i in range(100):
+        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        # the first forward pass casts the fp32 bias to fp16
+        o1 = l1(b1)
+        assert l1.bias.dtype == torch.float16
+
+    # casts model to fp16 -> int8 automatically
+    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, bias=False).cuda()
+    assert l1.weight.dtype == torch.int8
+    assert l1.bias is None
+
+    for i in range(100):
+        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        o1 = l1(b1)
+        assert l1.bias is None
-- 
cgit v1.2.3
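
For context, a minimal usage sketch of the case this commit fixes (hypothetical example, assuming a CUDA device and a bitsandbytes build that includes this patch; the constructor arguments mirror the new test above): before the change, Linear8bitLt.forward() unconditionally read self.bias.dtype, so a layer built with bias=False failed on a None bias, while afterwards the bias cast is simply skipped when no bias exists.

    import torch
    import bitsandbytes as bnb

    # Layer without a bias: forward() previously dereferenced self.bias.dtype on None.
    layer = bnb.nn.Linear8bitLt(32, 64, bias=False, has_fp16_weights=False).cuda()

    x = torch.randn(16, 8, 32, device="cuda").half()
    out = layer(x)              # runs cleanly after the fix
    assert layer.bias is None   # the bias stays None, nothing is cast

The layer with a bias is unaffected: its fp32 bias is still cast to fp16 on the first forward pass, as the new test asserts.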