From 4d6174bc6336fb6fba712f1d2c903de1de677747 Mon Sep 17 00:00:00 2001 From: dbaranchuk Date: Thu, 25 Aug 2022 19:09:23 +0300 Subject: memory efficient fp16 backward --- bitsandbytes/nn/modules.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) (limited to 'bitsandbytes/nn') diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 03ffd3b..3e32c8e 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -148,12 +148,10 @@ class Int8Params(torch.nn.Parameter): has_fp16_weights=False, CB=None, SCB=None, - SCBt=None, ): cls.has_fp16_weights = has_fp16_weights cls.CB = None cls.SCB = None - cls.SCBt = None if data is None: data = torch.empty(0) return torch.Tensor._make_subclass(cls, data, requires_grad) @@ -167,10 +165,10 @@ class Int8Params(torch.nn.Parameter): B = self.data.contiguous().half().cuda(device) CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) del CBt + del SCBt self.data = CB setattr(self, "CB", CB) setattr(self, "SCB", SCB) - setattr(self, "SCBt", SCBt) return self @@ -212,7 +210,6 @@ class Int8Params(torch.nn.Parameter): ) new_param.CB = self.CB new_param.SCB = self.SCB - new_param.SCBt = self.SCBt return new_param @@ -243,10 +240,8 @@ class Linear8bitLt(nn.Linear): def init_8bit_state(self): self.state.CB = self.weight.CB self.state.SCB = self.weight.SCB - self.state.SCBt = self.weight.SCBt self.weight.CB = None self.weight.SCB = None - self.weight.SCBt = None def forward(self, x): self.state.is_training = self.training -- cgit v1.2.3