summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bitsandbytes/autograd/_functions.py40
1 files changed, 23 insertions, 17 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 63e8ad5..8ce1e60 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -245,10 +245,11 @@ class MatMul8bitLt(torch.autograd.Function):
subA = A[:, idx]
state.subB = B[:, idx].t().contiguous()
state.idx = idx
- elif state.CxB is None:
- # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
- # we also need to convert it to the turing/ampere format
- state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+ else:
+ if state.CxB is None:
+ # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
+ # we also need to convert it to the turing/ampere format
+ state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
else:
if not state.has_fp16_weights and state.CxB is None:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
@@ -355,19 +356,24 @@ class MatMul8bitLt(torch.autograd.Function):
if req_gradA:
C32grad, Sgrad = F.transform(Cgrad, "col32")
- if state.CxBt is None and state.has_fp16_weights:
- CBt = state.CBt
- elif state.CxBt is None:
- assert state.CBt is None
- CB = state.CB.half()
- SCB = state.SCB.unsquezee(1).half()
- SCBt = state.SCBt.unsquezee(1).half()
- Bt = (CB * SCB).t().contiguous()
- CBt = (Bt / SCBt).t().to(torch.int8)
-
- CxBt, SBt = F.transform(
- CBt, to_order=formatB, transpose=True
- )
+ if state.CxBt is None:
+ if state.has_fp16_weights:
+ CBt = state.CBt
+ else:
+ # Restore CBt from CB
+ assert state.CBt is None, "CBt should not be stored in state"
+ CB = state.CB.half()
+ SCB = state.SCB.unsquezee(1).half()
+ SCBt = state.SCBt.unsquezee(1).half()
+ Bt = (CB * SCB).t().contiguous()
+ CBt = (Bt / SCBt).t().to(torch.int8)
+
+ # intentionally, do not store CxBt into state
+ CxBt, SBt = F.transform(
+ CBt, to_order=formatB, transpose=True
+ )
+ else:
+ CxBt = state.CxBt
gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)