From 8ae9bb23ad9c61a92ab1a0ac6be65cd787c4fe5b Mon Sep 17 00:00:00 2001 From: dbaranchuk Date: Tue, 23 Aug 2022 23:39:54 +0300 Subject: add memory efficient backward --- bitsandbytes/autograd/_functions.py | 39 ++++++++++++++++++------------------- bitsandbytes/nn/modules.py | 13 +++++++++---- 2 files changed, 28 insertions(+), 24 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 4dbf129..63e8ad5 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -245,11 +245,10 @@ class MatMul8bitLt(torch.autograd.Function): subA = A[:, idx] state.subB = B[:, idx].t().contiguous() state.idx = idx - else: - if state.CxB is None: - # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions - # we also need to convert it to the turing/ampere format - state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + elif state.CxB is None: + # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions + # we also need to convert it to the turing/ampere format + state.CxB, state.SB = F.transform(state.CB, to_order=formatB) else: if not state.has_fp16_weights and state.CxB is None: state.CxB, state.SB = F.transform(state.CB, to_order=formatB) @@ -280,12 +279,6 @@ class MatMul8bitLt(torch.autograd.Function): outlier_idx = torch.unique(coo_tensorA.colidx) state.idx = outlier_idx - # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) - # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: - # # do not use pool for 2nd FFN layer - # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) - # else: - # state.idx = outlier_idx outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) state.subB = ( (outliers * state.SCB.view(-1, 1) / 127.0) @@ -343,12 +336,9 @@ class MatMul8bitLt(torch.autograd.Function): SCAt, idx = ctx.tensor_states formatB = ctx.formatB state = ctx.state - assert ( - state.has_fp16_weights - ), "Backprop only supported for fp16 weights." if len(grad_output.shape) == 3: - grad_output = grad_output.view( + grad_output = grad_output.reshape( -1, grad_output.shape[-1] ).contiguous() @@ -365,11 +355,20 @@ class MatMul8bitLt(torch.autograd.Function): if req_gradA: C32grad, Sgrad = F.transform(Cgrad, "col32") - if state.CxBt is None: - state.CxBt, state.SBt = F.transform( - state.CBt, to_order=formatB, transpose=True - ) - gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) + if state.CxBt is None and state.has_fp16_weights: + CBt = state.CBt + elif state.CxBt is None: + assert state.CBt is None + CB = state.CB.half() + SCB = state.SCB.unsquezee(1).half() + SCBt = state.SCBt.unsquezee(1).half() + Bt = (CB * SCB).t().contiguous() + CBt = (Bt / SCBt).t().to(torch.int8) + + CxBt, SBt = F.transform( + CBt, to_order=formatB, transpose=True + ) + gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) if req_gradBias: diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index b222f54..ef7fefc 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -148,10 +148,12 @@ class Int8Params(torch.nn.Parameter): has_fp16_weights=False, CB=None, SCB=None, + SCBt=None, ): cls.has_fp16_weights = has_fp16_weights cls.CB = None cls.SCB = None + cls.SCBt = None if data is None: data = torch.empty(0) return torch.Tensor._make_subclass(cls, data, requires_grad) @@ -165,10 +167,10 @@ class Int8Params(torch.nn.Parameter): B = self.data.contiguous().half().cuda(device) CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) del CBt - del SCBt self.data = CB setattr(self, "CB", CB) setattr(self, "SCB", SCB) + setattr(self, "SCBt", SCBt) return self @@ -210,6 +212,7 @@ class Int8Params(torch.nn.Parameter): ) new_param.CB = self.CB new_param.SCB = self.SCB + new_param.SCB = self.SCBt return new_param @@ -240,8 +243,10 @@ class Linear8bitLt(nn.Linear): def init_8bit_state(self): self.state.CB = self.weight.CB self.state.SCB = self.weight.SCB + self.state.SCBt = self.weight.SCBt self.weight.CB = None self.weight.SCB = None + self.weight.SCBt = None def forward(self, x): self.state.is_training = self.training @@ -255,11 +260,11 @@ class Linear8bitLt(nn.Linear): out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) - if not self.state.has_fp16_weights and self.state.CB is not None: + # if not self.state.has_fp16_weights and self.state.CB is not None: # we converted 8-bit row major to turing/ampere format in the first inference pass # we no longer need the row-major weight - del self.state.CB - self.weight.data = self.state.CxB + # del self.state.CB + # self.weight.data = self.state.CxB return out -- cgit v1.2.3