diff options
author | dbaranchuk <dmitrybaranchuk@gmail.com> | 2022-08-25 19:09:23 +0300 |
---|---|---|
committer | dbaranchuk <dmitrybaranchuk@gmail.com> | 2022-08-25 19:09:23 +0300 |
commit | 4d6174bc6336fb6fba712f1d2c903de1de677747 (patch) | |
tree | 5afe32e6fa6ad2e66019075be6dbc45430e98f35 /bitsandbytes/nn | |
parent | ef2936a90d903d0f9a27e16ecb7f839f2c4d9ba1 (diff) |
memory efficient fp16 backward
Diffstat (limited to 'bitsandbytes/nn')
-rw-r--r-- | bitsandbytes/nn/modules.py | 7 |
1 file changed, 1 insertion, 6 deletions
```diff
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 03ffd3b..3e32c8e 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -148,12 +148,10 @@ class Int8Params(torch.nn.Parameter):
         has_fp16_weights=False,
         CB=None,
         SCB=None,
-        SCBt=None,
     ):
         cls.has_fp16_weights = has_fp16_weights
         cls.CB = None
         cls.SCB = None
-        cls.SCBt = None
         if data is None:
             data = torch.empty(0)
         return torch.Tensor._make_subclass(cls, data, requires_grad)
@@ -167,10 +165,10 @@ class Int8Params(torch.nn.Parameter):
             B = self.data.contiguous().half().cuda(device)
             CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
             del CBt
+            del SCBt
             self.data = CB
             setattr(self, "CB", CB)
             setattr(self, "SCB", SCB)
-            setattr(self, "SCBt", SCBt)
 
         return self
 
@@ -212,7 +210,6 @@ class Int8Params(torch.nn.Parameter):
             )
             new_param.CB = self.CB
             new_param.SCB = self.SCB
-            new_param.SCBt = self.SCBt
 
             return new_param
 
@@ -243,10 +240,8 @@ class Linear8bitLt(nn.Linear):
     def init_8bit_state(self):
         self.state.CB = self.weight.CB
         self.state.SCB = self.weight.SCB
-        self.state.SCBt = self.weight.SCBt
         self.weight.CB = None
         self.weight.SCB = None
-        self.weight.SCBt = None
 
     def forward(self, x):
         self.state.is_training = self.training
```