diff options
Diffstat (limited to 'bitsandbytes')
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 40 |
1 files changed, 23 insertions, 17 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 63e8ad5..8ce1e60 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -245,10 +245,11 @@ class MatMul8bitLt(torch.autograd.Function): subA = A[:, idx] state.subB = B[:, idx].t().contiguous() state.idx = idx - elif state.CxB is None: - # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions - # we also need to convert it to the turing/ampere format - state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + else: + if state.CxB is None: + # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions + # we also need to convert it to the turing/ampere format + state.CxB, state.SB = F.transform(state.CB, to_order=formatB) else: if not state.has_fp16_weights and state.CxB is None: state.CxB, state.SB = F.transform(state.CB, to_order=formatB) @@ -355,19 +356,24 @@ class MatMul8bitLt(torch.autograd.Function): if req_gradA: C32grad, Sgrad = F.transform(Cgrad, "col32") - if state.CxBt is None and state.has_fp16_weights: - CBt = state.CBt - elif state.CxBt is None: - assert state.CBt is None - CB = state.CB.half() - SCB = state.SCB.unsquezee(1).half() - SCBt = state.SCBt.unsquezee(1).half() - Bt = (CB * SCB).t().contiguous() - CBt = (Bt / SCBt).t().to(torch.int8) - - CxBt, SBt = F.transform( - CBt, to_order=formatB, transpose=True - ) + if state.CxBt is None: + if state.has_fp16_weights: + CBt = state.CBt + else: + # Restore CBt from CB + assert state.CBt is None, "CBt should not be stored in state" + CB = state.CB.half() + SCB = state.SCB.unsquezee(1).half() + SCBt = state.SCBt.unsquezee(1).half() + Bt = (CB * SCB).t().contiguous() + CBt = (Bt / SCBt).t().to(torch.int8) + + # intentionally, do not store CxBt into state + CxBt, SBt = F.transform( + CBt, to_order=formatB, transpose=True + ) + else: + CxBt = state.CxBt gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) |