diff options
author | dbaranchuk <dmitrybaranchuk@gmail.com> | 2022-09-11 06:18:44 +0300 |
---|---|---|
committer | dbaranchuk <dmitrybaranchuk@gmail.com> | 2022-09-11 06:18:44 +0300 |
commit | ee325f02157cd23b37059e3dce5fb17cb1c1b137 (patch) | |
tree | 8bd9e096a7f9f50ddd9ca02678a3a670db6cbd1f /bitsandbytes/autograd | |
parent | 42b5fc9acc4b59a6d90c662eb26099ac25907c7f (diff) |
clarified an exception message
Diffstat (limited to 'bitsandbytes/autograd')
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 271c690..008655d 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -373,7 +373,7 @@ class MatMul8bitLt(torch.autograd.Function): grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradA: - if state.CBt: + if state.CBt is not None: C32grad, Sgrad = F.transform(Cgrad, "col32") if state.CxBt is None: state.CxBt, state.SBt = F.transform( @@ -381,13 +381,13 @@ class MatMul8bitLt(torch.autograd.Function): ) gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) - elif state.CB: + elif state.CB is not None: CB = state.CB.half() SCB = (state.SCB.unsqueeze(1) / 127.0).half() CB *= SCB grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape) else: - raise Exception('State must contain either CBt or CB matrix') + raise Exception('State must contain either CBt or CB matrix for backward') if req_gradBias: grad_bias = grad_output.sum(0) |