From 4d6174bc6336fb6fba712f1d2c903de1de677747 Mon Sep 17 00:00:00 2001 From: dbaranchuk Date: Thu, 25 Aug 2022 19:09:23 +0300 Subject: memory efficient fp16 backward --- bitsandbytes/autograd/_functions.py | 40 +++++-------------------------------- bitsandbytes/nn/modules.py | 7 +------ 2 files changed, 6 insertions(+), 41 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 7cf4999..52e56d0 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -196,7 +196,6 @@ class MatmulLtState: self.CxBt = None self.SBt = None - self.CBt = None class MatMul8bitLt(torch.autograd.Function): @@ -327,15 +326,12 @@ class MatMul8bitLt(torch.autograd.Function): #clone_func = torch.clone return clone_func(output.view(output_shape)) - @staticmethod def backward(ctx, grad_output): if ctx.is_empty: bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None req_gradA, req_gradB, req_gradBias = ctx.req_grads - CAt, subA = ctx.tensors - SCAt, idx = ctx.tensor_states - formatB = ctx.formatB + assert not req_gradB, "TODO: support weight updates as well" state = ctx.state if len(grad_output.shape) == 3: @@ -345,37 +341,11 @@ class MatMul8bitLt(torch.autograd.Function): grad_A = grad_B = grad_bias = None - Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output) - if req_gradB: - CxAt, SAt = F.transform(CAt, formatB, transpose=True) - C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) - gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) - grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) - if state.threshold > 0.0 and subA is not None: - grad_B[:, idx] += torch.matmul(grad_output.t(), subA) - if req_gradA: - C32grad, Sgrad = F.transform(Cgrad, "col32") - if state.CxBt is None: - if state.has_fp16_weights: - CBt = state.CBt - else: - # Restore CBt from CB - assert state.CBt is None, "CBt should not be stored in state" - CB = state.CB.half() - SCB = state.SCB.unsqueeze(1).half() - SCBt = state.SCBt.unsqueeze(1).half() - Bt = (CB * SCB).t().contiguous() - CBt = (Bt / SCBt).t().to(torch.int8) - - # intentionally, do not store CxBt in state - CxBt, SBt = F.transform( - CBt, to_order=formatB, transpose=True - ) - else: - CxBt = state.CxBt - gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt) - grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) + CB = state.CB.half() + SCB = state.SCB.unsqueeze(1).half() + B = (CB * SCB) / 127.0 + grad_A = torch.mm(grad_output, B).view(ctx.grad_shape) if req_gradBias: grad_bias = grad_output.sum(0) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 03ffd3b..3e32c8e 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -148,12 +148,10 @@ class Int8Params(torch.nn.Parameter): has_fp16_weights=False, CB=None, SCB=None, - SCBt=None, ): cls.has_fp16_weights = has_fp16_weights cls.CB = None cls.SCB = None - cls.SCBt = None if data is None: data = torch.empty(0) return torch.Tensor._make_subclass(cls, data, requires_grad) @@ -167,10 +165,10 @@ class Int8Params(torch.nn.Parameter): B = self.data.contiguous().half().cuda(device) CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) del CBt + del SCBt self.data = CB setattr(self, "CB", CB) setattr(self, "SCB", SCB) - setattr(self, "SCBt", SCBt) return self @@ -212,7 +210,6 @@ class Int8Params(torch.nn.Parameter): ) new_param.CB = self.CB new_param.SCB = self.SCB - new_param.SCBt = self.SCBt return new_param @@ -243,10 +240,8 @@ class Linear8bitLt(nn.Linear): def init_8bit_state(self): self.state.CB = self.weight.CB self.state.SCB = self.weight.SCB - self.state.SCBt = self.weight.SCBt self.weight.CB = None self.weight.SCB = None - self.weight.SCBt = None def forward(self, x): self.state.is_training = self.training -- cgit v1.2.3