summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordbaranchuk <dmitrybaranchuk@gmail.com>2022-08-24 01:33:04 +0300
committerdbaranchuk <dmitrybaranchuk@gmail.com>2022-08-24 01:33:04 +0300
commitef2936a90d903d0f9a27e16ecb7f839f2c4d9ba1 (patch)
tree921b9dac7291900dff8d8f6bac0d125e34cace82
parent876387dc0c1c71ad9cd827d4aecc31190313c7ab (diff)
delete CxB from state
-rw-r--r--bitsandbytes/nn/modules.py9
1 files changed, 4 insertions, 5 deletions
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 360a182..03ffd3b 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -260,11 +260,10 @@ class Linear8bitLt(nn.Linear):
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
- # if not self.state.has_fp16_weights and self.state.CB is not None:
- # we converted 8-bit row major to turing/ampere format in the first inference pass
- # we no longer need the row-major weight
- # del self.state.CB
- # self.weight.data = self.state.CxB
+ if not self.state.has_fp16_weights and self.state.CxB is not None:
+ # In this version, we convert 8-bit row major to turing/ampere format at each inference pass
+ # Thus, we delete CxB from the state. TODO: do not store it in the state in the first place.
+ del self.state.CxB
return out