path: root/bitsandbytes/nn/modules.py
author    dbaranchuk <dmitrybaranchuk@gmail.com>  2022-08-25 19:09:23 +0300
committer dbaranchuk <dmitrybaranchuk@gmail.com>  2022-08-25 19:09:23 +0300
commit    4d6174bc6336fb6fba712f1d2c903de1de677747 (patch)
tree      5afe32e6fa6ad2e66019075be6dbc45430e98f35 /bitsandbytes/nn/modules.py
parent    ef2936a90d903d0f9a27e16ecb7f839f2c4d9ba1 (diff)
memory efficient fp16 backward
Diffstat (limited to 'bitsandbytes/nn/modules.py')
-rw-r--r--  bitsandbytes/nn/modules.py | 7 +------
1 file changed, 1 insertion(+), 6 deletions(-)
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 03ffd3b..3e32c8e 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -148,12 +148,10 @@ class Int8Params(torch.nn.Parameter):
         has_fp16_weights=False,
         CB=None,
         SCB=None,
-        SCBt=None,
     ):
         cls.has_fp16_weights = has_fp16_weights
         cls.CB = None
         cls.SCB = None
-        cls.SCBt = None
         if data is None:
             data = torch.empty(0)
         return torch.Tensor._make_subclass(cls, data, requires_grad)
@@ -167,10 +165,10 @@ class Int8Params(torch.nn.Parameter):
             B = self.data.contiguous().half().cuda(device)
             CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
             del CBt
+            del SCBt
             self.data = CB
             setattr(self, "CB", CB)
             setattr(self, "SCB", SCB)
-            setattr(self, "SCBt", SCBt)
 
         return self
@@ -212,7 +210,6 @@ class Int8Params(torch.nn.Parameter):
             )
             new_param.CB = self.CB
             new_param.SCB = self.SCB
-            new_param.SCBt = self.SCBt
 
             return new_param
@@ -243,10 +240,8 @@ class Linear8bitLt(nn.Linear):
     def init_8bit_state(self):
         self.state.CB = self.weight.CB
         self.state.SCB = self.weight.SCB
-        self.state.SCBt = self.weight.SCBt
         self.weight.CB = None
         self.weight.SCB = None
-        self.weight.SCBt = None
 
     def forward(self, x):
         self.state.is_training = self.training
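
The diff above drops the SCBt field (the quantization statistics for the column-transposed weight) from both Int8Params and the Linear8bitLt state: after quantization only the row-major int8 weight CB and its per-row scales SCB are kept, and the freshly computed CBt and SCBt are deleted. Below is a minimal sketch of that quantization step, assuming only the bnb.functional.double_quant call shown in the diff; the helper name quantize_int8_weight is hypothetical and not part of the library.

import torch
import bitsandbytes as bnb

def quantize_int8_weight(weight, device="cuda"):
    # Hypothetical helper: row-wise int8 quantization of an fp16 weight,
    # keeping only the state needed by the memory-efficient fp16 backward.
    B = weight.contiguous().half().to(device)
    CB, CBt, SCB, SCBt, _coo = bnb.functional.double_quant(B)
    # Discard the transposed copies instead of caching them on the parameter,
    # mirroring the `del CBt` / `del SCBt` lines in the diff above.
    del CBt, SCBt
    return CB, SCB

After this change a Linear8bitLt module only carries CB and SCB in its state; any code path that previously read state.SCBt presumably has to work from the row-major statistics instead.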