diff options
author | dbaranchuk <dmitrybaranchuk@gmail.com> | 2022-08-23 23:39:54 +0300 |
---|---|---|
committer | dbaranchuk <dmitrybaranchuk@gmail.com> | 2022-08-23 23:39:54 +0300 |
commit | 8ae9bb23ad9c61a92ab1a0ac6be65cd787c4fe5b (patch) | |
tree | b0b17700aad3ac18a1265e078c0ea6b1ada8b87f /bitsandbytes | |
parent | 9d60b3c5279641ba936facd710c722ebe52fcf40 (diff) |
add memory efficient backward
Diffstat (limited to 'bitsandbytes')
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 39 | ||||
-rw-r--r-- | bitsandbytes/nn/modules.py | 13 |
2 files changed, 28 insertions, 24 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 4dbf129..63e8ad5 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -245,11 +245,10 @@ class MatMul8bitLt(torch.autograd.Function): subA = A[:, idx] state.subB = B[:, idx].t().contiguous() state.idx = idx - else: - if state.CxB is None: - # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions - # we also need to convert it to the turing/ampere format - state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + elif state.CxB is None: + # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions + # we also need to convert it to the turing/ampere format + state.CxB, state.SB = F.transform(state.CB, to_order=formatB) else: if not state.has_fp16_weights and state.CxB is None: state.CxB, state.SB = F.transform(state.CB, to_order=formatB) @@ -280,12 +279,6 @@ class MatMul8bitLt(torch.autograd.Function): outlier_idx = torch.unique(coo_tensorA.colidx) state.idx = outlier_idx - # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) - # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: - # # do not use pool for 2nd FFN layer - # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) - # else: - # state.idx = outlier_idx outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) state.subB = ( (outliers * state.SCB.view(-1, 1) / 127.0) @@ -343,12 +336,9 @@ class MatMul8bitLt(torch.autograd.Function): SCAt, idx = ctx.tensor_states formatB = ctx.formatB state = ctx.state - assert ( - state.has_fp16_weights - ), "Backprop only supported for fp16 weights." if len(grad_output.shape) == 3: - grad_output = grad_output.view( + grad_output = grad_output.reshape( -1, grad_output.shape[-1] ).contiguous() @@ -365,11 +355,20 @@ class MatMul8bitLt(torch.autograd.Function): if req_gradA: C32grad, Sgrad = F.transform(Cgrad, "col32") - if state.CxBt is None: - state.CxBt, state.SBt = F.transform( - state.CBt, to_order=formatB, transpose=True - ) - gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) + if state.CxBt is None and state.has_fp16_weights: + CBt = state.CBt + elif state.CxBt is None: + assert state.CBt is None + CB = state.CB.half() + SCB = state.SCB.unsquezee(1).half() + SCBt = state.SCBt.unsquezee(1).half() + Bt = (CB * SCB).t().contiguous() + CBt = (Bt / SCBt).t().to(torch.int8) + + CxBt, SBt = F.transform( + CBt, to_order=formatB, transpose=True + ) + gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) if req_gradBias: diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index b222f54..ef7fefc 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -148,10 +148,12 @@ class Int8Params(torch.nn.Parameter): has_fp16_weights=False, CB=None, SCB=None, + SCBt=None, ): cls.has_fp16_weights = has_fp16_weights cls.CB = None cls.SCB = None + cls.SCBt = None if data is None: data = torch.empty(0) return torch.Tensor._make_subclass(cls, data, requires_grad) @@ -165,10 +167,10 @@ class Int8Params(torch.nn.Parameter): B = self.data.contiguous().half().cuda(device) CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) del CBt - del SCBt self.data = CB setattr(self, "CB", CB) setattr(self, "SCB", SCB) + setattr(self, "SCBt", SCBt) return self @@ -210,6 +212,7 @@ class Int8Params(torch.nn.Parameter): ) new_param.CB = self.CB new_param.SCB = self.SCB + new_param.SCB = self.SCBt return new_param @@ -240,8 +243,10 @@ class Linear8bitLt(nn.Linear): def init_8bit_state(self): self.state.CB = self.weight.CB self.state.SCB = self.weight.SCB + self.state.SCBt = self.weight.SCBt self.weight.CB = None self.weight.SCB = None + self.weight.SCBt = None def forward(self, x): self.state.is_training = self.training @@ -255,11 +260,11 @@ class Linear8bitLt(nn.Linear): out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) - if not self.state.has_fp16_weights and self.state.CB is not None: + # if not self.state.has_fp16_weights and self.state.CB is not None: # we converted 8-bit row major to turing/ampere format in the first inference pass # we no longer need the row-major weight - del self.state.CB - self.weight.data = self.state.CxB + # del self.state.CB + # self.weight.data = self.state.CxB return out |