From ee325f02157cd23b37059e3dce5fb17cb1c1b137 Mon Sep 17 00:00:00 2001 From: dbaranchuk Date: Sun, 11 Sep 2022 06:18:44 +0300 Subject: clarified an exception message --- bitsandbytes/autograd/_functions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'bitsandbytes') diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 271c690..008655d 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -373,7 +373,7 @@ class MatMul8bitLt(torch.autograd.Function): grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradA: - if state.CBt: + if state.CBt is not None: C32grad, Sgrad = F.transform(Cgrad, "col32") if state.CxBt is None: state.CxBt, state.SBt = F.transform( @@ -381,13 +381,13 @@ class MatMul8bitLt(torch.autograd.Function): ) gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) - elif state.CB: + elif state.CB is not None: CB = state.CB.half() SCB = (state.SCB.unsqueeze(1) / 127.0).half() CB *= SCB grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape) else: - raise Exception('State must contain either CBt or CB matrix') + raise Exception('State must contain either CBt or CB matrix for backward') if req_gradBias: grad_bias = grad_output.sum(0) -- cgit v1.2.3