author    | dbaranchuk <dmitrybaranchuk@gmail.com> | 2022-08-23 23:51:00 +0300
committer | dbaranchuk <dmitrybaranchuk@gmail.com> | 2022-08-23 23:51:00 +0300
commit    | 1753aa04185b10a3bb52f7289ed4af15cf2502a7 (patch)
tree      | 440cb8d367ffc5fa5c43e5c94cf038374d87907d /bitsandbytes/autograd
parent    | 8ae9bb23ad9c61a92ab1a0ac6be65cd787c4fe5b (diff)
refactoring
Diffstat (limited to 'bitsandbytes/autograd')
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 40
1 file changed, 23 insertions, 17 deletions
```diff
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 63e8ad5..8ce1e60 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -245,10 +245,11 @@ class MatMul8bitLt(torch.autograd.Function):
                 subA = A[:, idx]
                 state.subB = B[:, idx].t().contiguous()
                 state.idx = idx
-            elif state.CxB is None:
-                # B is in 8-bit row-major; we can transform it back to 16-bit to extract outlier dimensions
-                # we also need to convert it to the turing/ampere format
-                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+            else:
+                if state.CxB is None:
+                    # B is in 8-bit row-major; we can transform it back to 16-bit to extract outlier dimensions
+                    # we also need to convert it to the turing/ampere format
+                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
         else:
             if not state.has_fp16_weights and state.CxB is None:
                 state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
@@ -355,19 +356,24 @@ class MatMul8bitLt(torch.autograd.Function):
 
         if req_gradA:
             C32grad, Sgrad = F.transform(Cgrad, "col32")
-            if state.CxBt is None and state.has_fp16_weights:
-                CBt = state.CBt
-            elif state.CxBt is None:
-                assert state.CBt is None
-                CB = state.CB.half()
-                SCB = state.SCB.unsqueeze(1).half()
-                SCBt = state.SCBt.unsqueeze(1).half()
-                Bt = (CB * SCB).t().contiguous()
-                CBt = (Bt / SCBt).t().to(torch.int8)
-
-            CxBt, SBt = F.transform(
-                CBt, to_order=formatB, transpose=True
-            )
+            if state.CxBt is None:
+                if state.has_fp16_weights:
+                    CBt = state.CBt
+                else:
+                    # Restore CBt from CB
+                    assert state.CBt is None, "CBt should not be stored in state"
+                    CB = state.CB.half()
+                    SCB = state.SCB.unsqueeze(1).half()
+                    SCBt = state.SCBt.unsqueeze(1).half()
+                    Bt = (CB * SCB).t().contiguous()
+                    CBt = (Bt / SCBt).t().to(torch.int8)
+
+                # intentionally, do not store CxBt into state
+                CxBt, SBt = F.transform(
+                    CBt, to_order=formatB, transpose=True
+                )
+            else:
+                CxBt, SBt = state.CxBt, state.SBt
             gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt)
             grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
```
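The interesting arithmetic here is the restoration of `CBt` (B quantized with per-column statistics) from `CB` (B quantized with per-row statistics) in the backward hunk. Below is a minimal standalone sketch of that round trip in plain PyTorch; it mirrors the math in the diff but is not the bitsandbytes API, and the tensor shapes and absmax quantization scheme in the demo are assumptions for illustration, not taken from the commit.

```python
import torch

def reconstruct_CBt(CB: torch.Tensor, SCB: torch.Tensor, SCBt: torch.Tensor) -> torch.Tensor:
    """Sketch of the backward-pass reconstruction shown in the diff above.

    Assumed shapes (not stated in the commit):
      CB   -- int8 weight quantized row-wise, shape (rows, cols)
      SCB  -- per-row absmax of B,            shape (rows,)
      SCBt -- per-row absmax of B^T
              (i.e. per-column absmax of B),  shape (cols,)
    Returns an int8 tensor in the same (rows, cols) layout as CB but
    quantized with the B^T statistics; in the library code,
    F.transform(..., transpose=True) performs the actual transpose.
    """
    CB = CB.half()
    SCB = SCB.unsqueeze(1).half()    # (rows, 1): broadcasts over rows of CB
    SCBt = SCBt.unsqueeze(1).half()  # (cols, 1): broadcasts over rows of B^T
    # Dequantize row-wise, transpose, requantize column-wise. The 1/127
    # (dequantize) and *127 (requantize) factors cancel, which is
    # presumably why neither appears here or in the diff.
    Bt = (CB * SCB).t().contiguous()       # ~ 127 * B^T, shape (cols, rows)
    return (Bt / SCBt).t().to(torch.int8)  # truncates; round() would be closer

if __name__ == "__main__":
    torch.manual_seed(0)
    B = torch.randn(4, 8)
    SCB = B.abs().amax(dim=1)   # row-wise absmax
    SCBt = B.abs().amax(dim=0)  # column-wise absmax
    CB = torch.round(B * 127.0 / SCB.unsqueeze(1)).to(torch.int8)
    CBt_direct = torch.round(B * 127.0 / SCBt.unsqueeze(0)).to(torch.int8)
    CBt_rebuilt = reconstruct_CBt(CB, SCB, SCBt)
    # The rebuilt tensor differs from direct column-wise quantization only
    # by the noise of the first quantization step (a few int8 units here).
    print((CBt_rebuilt.float() - CBt_direct.float()).abs().max())
```

Note also the added comment in the diff: `CxBt` is deliberately not cached back into `state`, so the transposed copy is rebuilt on each backward pass, trading recomputation for a smaller persistent memory footprint.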