summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r--  bitsandbytes/autograd/_functions.py  10
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 538267b..34b27d9 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -357,6 +357,11 @@ class MatMul8bitLt(torch.autograd.Function):
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
state = ctx.state
+ grad_A = grad_B = grad_bias = None
+
+ if req_gradBias:
+ # compute grad_bias first before changing grad_output dtype
+ grad_bias = grad_output.sum(0)
# Cast grad_output to fp16
grad_output_dtype = grad_output.dtype
@@ -367,8 +372,6 @@ class MatMul8bitLt(torch.autograd.Function):
-1, grad_output.shape[-1]
).contiguous()
- grad_A = grad_B = grad_bias = None
-
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
if req_gradB:
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
@@ -395,9 +398,6 @@ class MatMul8bitLt(torch.autograd.Function):
else:
raise Exception('State must contain either CBt or CB matrix for backward')
- if req_gradBias:
- grad_bias = grad_output.sum(0)
-
# Cast grad_A back to grad_output_dtype
grad_output = grad_output.to(grad_output_dtype)