summaryrefslogtreecommitdiff
path: root/bitsandbytes/nn
diff options
context:
space:
mode:
author: dbaranchuk <dmitrybaranchuk@gmail.com> 2022-08-23 23:39:54 +0300
committer: dbaranchuk <dmitrybaranchuk@gmail.com> 2022-08-23 23:39:54 +0300
commit: 8ae9bb23ad9c61a92ab1a0ac6be65cd787c4fe5b (patch)
tree: b0b17700aad3ac18a1265e078c0ea6b1ada8b87f /bitsandbytes/nn
parent: 9d60b3c5279641ba936facd710c722ebe52fcf40 (diff)
add memory efficient backward
Diffstat (limited to 'bitsandbytes/nn')
-rw-r--r-- bitsandbytes/nn/modules.py | 13
1 file changed, 9 insertions, 4 deletions
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index b222f54..ef7fefc 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -148,10 +148,12 @@ class Int8Params(torch.nn.Parameter):
has_fp16_weights=False,
CB=None,
SCB=None,
+ SCBt=None,
):
cls.has_fp16_weights = has_fp16_weights
cls.CB = None
cls.SCB = None
+ cls.SCBt = None
if data is None:
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, requires_grad)
@@ -165,10 +167,10 @@ class Int8Params(torch.nn.Parameter):
B = self.data.contiguous().half().cuda(device)
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
del CBt
- del SCBt
self.data = CB
setattr(self, "CB", CB)
setattr(self, "SCB", SCB)
+ setattr(self, "SCBt", SCBt)
return self
@@ -210,6 +212,7 @@ class Int8Params(torch.nn.Parameter):
)
new_param.CB = self.CB
new_param.SCB = self.SCB
+ new_param.SCBt = self.SCBt
return new_param
@@ -240,8 +243,10 @@ class Linear8bitLt(nn.Linear):
def init_8bit_state(self):
self.state.CB = self.weight.CB
self.state.SCB = self.weight.SCB
+ self.state.SCBt = self.weight.SCBt
self.weight.CB = None
self.weight.SCB = None
+ self.weight.SCBt = None
def forward(self, x):
self.state.is_training = self.training
@@ -255,11 +260,11 @@ class Linear8bitLt(nn.Linear):
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
- if not self.state.has_fp16_weights and self.state.CB is not None:
+ # if not self.state.has_fp16_weights and self.state.CB is not None:
# we converted 8-bit row major to turing/ampere format in the first inference pass
# we no longer need the row-major weight
- del self.state.CB
- self.weight.data = self.state.CxB
+ # del self.state.CB
+ # self.weight.data = self.state.CxB
return out