 bitsandbytes/autograd/_functions.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 271c690..008655d 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -373,7 +373,7 @@ class MatMul8bitLt(torch.autograd.Function):
                 grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
 
         if req_gradA:
-            if state.CBt:
+            if state.CBt is not None:
                 C32grad, Sgrad = F.transform(Cgrad, "col32")
                 if state.CxBt is None:
                     state.CxBt, state.SBt = F.transform(
@@ -381,13 +381,13 @@ class MatMul8bitLt(torch.autograd.Function):
                     )
                 gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
                 grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
-            elif state.CB:
+            elif state.CB is not None:
                 CB = state.CB.half()
                 SCB = (state.SCB.unsqueeze(1) / 127.0).half()
                 CB *= SCB
                 grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape)
             else:
-                raise Exception('State must contain either CBt or CB matrix')
+                raise Exception('State must contain either CBt or CB matrix for backward')
 
         if req_gradBias:
             grad_bias = grad_output.sum(0)
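Why "is not None" matters here (a minimal sketch, not part of the commit): state.CBt and state.CB are torch tensors, and truth-testing a tensor with more than one element does not check for presence but raises a RuntimeError, so the old "if state.CBt:" guard could crash during backward. The variable name below is a hypothetical stand-in for the cached state matrix:

    import torch

    state_CBt = torch.zeros(4, 4)  # hypothetical stand-in for state.CBt

    # Correct: an explicit presence check, well-defined for any tensor.
    if state_CBt is not None:
        print("CBt is present")

    # Incorrect: bool() of a multi-element tensor is ambiguous and raises.
    try:
        if state_CBt:
            pass
    except RuntimeError as err:
        print(err)  # "Boolean value of Tensor with more than one element is ambiguous"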