summaryrefslogtreecommitdiff
path: root/bitsandbytes/nn/modules.py
diff options
context:
space:
mode:
authordbaranchuk <dmitrybaranchuk@gmail.com>2022-09-11 05:51:29 +0300
committerdbaranchuk <dmitrybaranchuk@gmail.com>2022-09-11 05:51:29 +0300
commit42b5fc9acc4b59a6d90c662eb26099ac25907c7f (patch)
treedf0f65f65e2f1aae25462da1be9c65ca3fe45580 /bitsandbytes/nn/modules.py
parent843ad0631c65eabc7f64e80906ecf5482cc1a036 (diff)
add memory effcient backward option
Diffstat (limited to 'bitsandbytes/nn/modules.py')
-rw-r--r--bitsandbytes/nn/modules.py16
1 files changed, 12 insertions, 4 deletions
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 3e32c8e..00d0c61 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -223,6 +223,7 @@ class Linear8bitLt(nn.Linear):
has_fp16_weights=True,
threshold=0.0,
index=None,
+ memory_efficient_backward=False
):
super(Linear8bitLt, self).__init__(
input_features, output_features, bias
@@ -232,6 +233,7 @@ class Linear8bitLt(nn.Linear):
self.state.threshold = threshold
self.state.has_fp16_weights = has_fp16_weights
+ self.state.memory_efficient_backward = memory_efficient_backward
if threshold > 0.0 and not has_fp16_weights:
self.state.use_pool = True
@@ -255,10 +257,16 @@ class Linear8bitLt(nn.Linear):
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
- if not self.state.has_fp16_weights and self.state.CxB is not None:
- # In this version, we convert 8-bit row major to turing/ampere format at each inference pass
- # Thus, we delete CxB from the state. TODO: do not store it in the state in the first place.
- del self.state.CxB
+ if not self.state.has_fp16_weights:
+ if not self.state.memory_efficient_backward and self.state.CB is not None:
+ # we converted 8-bit row major to turing/ampere format in the first inference pass
+ # we no longer need the row-major weight
+ del self.state.CB
+ self.weight.data = self.state.CxB
+ elif self.state.memory_efficient_backward and self.state.CxB is not None:
+ # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
+ # Thus, we delete CxB from the state.
+ del self.state.CxB
return out