Diffstat (limited to 'bitsandbytes/nn')
-rw-r--r-- | bitsandbytes/nn/modules.py | 16
1 file changed, 12 insertions, 4 deletions
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 3e32c8e..00d0c61 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -223,6 +223,7 @@ class Linear8bitLt(nn.Linear):
         has_fp16_weights=True,
         threshold=0.0,
         index=None,
+        memory_efficient_backward=False
     ):
         super(Linear8bitLt, self).__init__(
             input_features, output_features, bias
@@ -232,6 +233,7 @@ class Linear8bitLt(nn.Linear):
 
         self.state.threshold = threshold
         self.state.has_fp16_weights = has_fp16_weights
+        self.state.memory_efficient_backward = memory_efficient_backward
         if threshold > 0.0 and not has_fp16_weights:
             self.state.use_pool = True
 
@@ -255,10 +257,16 @@ class Linear8bitLt(nn.Linear):
 
         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
 
-        if not self.state.has_fp16_weights and self.state.CxB is not None:
-            # In this version, we convert 8-bit row major to turing/ampere format at each inference pass
-            # Thus, we delete CxB from the state. TODO: do not store it in the state in the first place.
-            del self.state.CxB
+        if not self.state.has_fp16_weights:
+            if not self.state.memory_efficient_backward and self.state.CB is not None:
+                # we converted 8-bit row major to turing/ampere format in the first inference pass
+                # we no longer need the row-major weight
+                del self.state.CB
+                self.weight.data = self.state.CxB
+            elif self.state.memory_efficient_backward and self.state.CxB is not None:
+                # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
+                # Thus, we delete CxB from the state.
+                del self.state.CxB
 
         return out
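
For context, a minimal usage sketch of the new flag, based only on the constructor signature shown in this diff; the layer dimensions, threshold value, device setup, and input data below are illustrative assumptions, not part of the patch. With has_fp16_weights=False, the default path deletes the row-major int8 weight (CB) after the first forward and keeps only the tiled Turing/Ampere copy (CxB); setting memory_efficient_backward=True instead keeps CB and re-derives CxB on each forward, so gradients can still flow in the backward pass at the cost of repeating the format conversion.

# Minimal sketch (assumed example; sizes and data are hypothetical).
import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear8bitLt(
    768, 3072,
    bias=True,
    has_fp16_weights=False,          # frozen int8 weights
    threshold=6.0,
    memory_efficient_backward=True,  # keep CB, rebuild CxB each forward
)
layer = layer.half().cuda()          # int8 quantization happens on the move to GPU

x = torch.randn(4, 768, dtype=torch.float16, device="cuda", requires_grad=True)
out = layer(x)
out.sum().backward()                 # with the flag set, gradients still reach x

The trade-off is deliberate: the default path is faster for pure inference because the Turing/Ampere conversion happens once, while the flag trades per-pass conversion time for keeping the state needed by backward.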