From 8ae9bb23ad9c61a92ab1a0ac6be65cd787c4fe5b Mon Sep 17 00:00:00 2001 From: dbaranchuk Date: Tue, 23 Aug 2022 23:39:54 +0300 Subject: add memory efficient backward --- bitsandbytes/autograd/_functions.py | 39 ++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 20 deletions(-) (limited to 'bitsandbytes/autograd') diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 4dbf129..63e8ad5 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -245,11 +245,10 @@ class MatMul8bitLt(torch.autograd.Function): subA = A[:, idx] state.subB = B[:, idx].t().contiguous() state.idx = idx - else: - if state.CxB is None: - # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions - # we also need to convert it to the turing/ampere format - state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + elif state.CxB is None: + # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions + # we also need to convert it to the turing/ampere format + state.CxB, state.SB = F.transform(state.CB, to_order=formatB) else: if not state.has_fp16_weights and state.CxB is None: state.CxB, state.SB = F.transform(state.CB, to_order=formatB) @@ -280,12 +279,6 @@ class MatMul8bitLt(torch.autograd.Function): outlier_idx = torch.unique(coo_tensorA.colidx) state.idx = outlier_idx - # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) - # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: - # # do not use pool for 2nd FFN layer - # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) - # else: - # state.idx = outlier_idx outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) state.subB = ( (outliers * state.SCB.view(-1, 1) / 127.0) @@ -343,12 +336,9 @@ class MatMul8bitLt(torch.autograd.Function): SCAt, idx = ctx.tensor_states formatB = ctx.formatB state = ctx.state - assert ( - state.has_fp16_weights - ), "Backprop only supported for fp16 weights." if len(grad_output.shape) == 3: - grad_output = grad_output.view( + grad_output = grad_output.reshape( -1, grad_output.shape[-1] ).contiguous() @@ -365,11 +355,20 @@ class MatMul8bitLt(torch.autograd.Function): if req_gradA: C32grad, Sgrad = F.transform(Cgrad, "col32") - if state.CxBt is None: - state.CxBt, state.SBt = F.transform( - state.CBt, to_order=formatB, transpose=True - ) - gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) + if state.CxBt is None and state.has_fp16_weights: + CBt = state.CBt + elif state.CxBt is None: + assert state.CBt is None + CB = state.CB.half() + SCB = state.SCB.unsquezee(1).half() + SCBt = state.SCBt.unsquezee(1).half() + Bt = (CB * SCB).t().contiguous() + CBt = (Bt / SCBt).t().to(torch.int8) + + CxBt, SBt = F.transform( + CBt, to_order=formatB, transpose=True + ) + gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) if req_gradBias: -- cgit v1.2.3