author     Tim Dettmers <TimDettmers@users.noreply.github.com>    2022-09-19 21:09:25 -0700
committer  GitHub <noreply@github.com>                            2022-09-19 21:09:25 -0700
commit     439f2b0c10abd3e9aade386d92810b074c69e9ec (patch)
tree       75454081c86ba1c96c07e83defc9fc5f4de840cf /bitsandbytes/nn
parent     9b5f2eda8fbd3f042c4af7ed1b870525d4668f2a (diff)
parent     76ce9aa6da7d68d2463f0f3e99532ab5b6db58a8 (diff)
Merge pull request #33 from dbaranchuk/memory-efficient-backward
Memory efficient backward
Diffstat (limited to 'bitsandbytes/nn')
-rw-r--r--  bitsandbytes/nn/modules.py  21
1 file changed, 15 insertions(+), 6 deletions(-)
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index b222f54..9250fec 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -221,6 +221,7 @@ class Linear8bitLt(nn.Linear):
         output_features,
         bias=True,
         has_fp16_weights=True,
+        memory_efficient_backward=False,
         threshold=0.0,
         index=None,
     ):
@@ -232,10 +233,13 @@ class Linear8bitLt(nn.Linear):
         self.state.threshold = threshold
         self.state.has_fp16_weights = has_fp16_weights
+        self.state.memory_efficient_backward = memory_efficient_backward
         if threshold > 0.0 and not has_fp16_weights:
             self.state.use_pool = True
 
-        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
+        self.weight = Int8Params(
+            self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
+        )
 
     def init_8bit_state(self):
         self.state.CB = self.weight.CB
@@ -255,11 +259,16 @@ class Linear8bitLt(nn.Linear):
         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
 
-        if not self.state.has_fp16_weights and self.state.CB is not None:
-            # we converted 8-bit row major to turing/ampere format in the first inference pass
-            # we no longer need the row-major weight
-            del self.state.CB
-            self.weight.data = self.state.CxB
+        if not self.state.has_fp16_weights:
+            if not self.state.memory_efficient_backward and self.state.CB is not None:
+                # we converted 8-bit row major to turing/ampere format in the first inference pass
+                # we no longer need the row-major weight
+                del self.state.CB
+                self.weight.data = self.state.CxB
+            elif self.state.memory_efficient_backward and self.state.CxB is not None:
+                # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
+                # Thus, we delete CxB from the state.
+                del self.state.CxB
 
         return out
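
A short usage sketch of the new memory_efficient_backward flag follows. It is not part of the patch: it assumes bitsandbytes with CUDA support is installed, and the layer sizes, threshold, and input tensor are illustrative. With has_fp16_weights=False the int8 weights are frozen (requires_grad=has_fp16_weights above), and memory_efficient_backward=True keeps the layer usable in the backward pass so gradients can still reach the layer input.

    import torch
    import bitsandbytes as bnb

    # Frozen int8 linear layer that still supports backpropagation to its input.
    layer = bnb.nn.Linear8bitLt(
        64, 128,
        bias=True,
        has_fp16_weights=False,           # weights stay quantized and frozen
        memory_efficient_backward=True,   # new flag from this merge
        threshold=6.0,
    ).cuda().half()                       # weights are quantized on .cuda(); .half() casts the bias

    x = torch.randn(8, 64, dtype=torch.float16, device="cuda", requires_grad=True)
    out = layer(x)
    out.sum().backward()                  # gradient flows to x even though the weights are frozen
    print(x.grad.shape)                   # torch.Size([8, 64])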