summaryrefslogtreecommitdiff
path: root/bitsandbytes/autograd/_functions.py
diff options
context:
space:
mode:
authorjustheuristic <justheuristic@gmail.com>2022-09-17 23:53:49 +0300
committerjustheuristic <justheuristic@gmail.com>2022-09-17 23:53:49 +0300
commit0de1a4494bd9246e5b1b3f2c7a0e4d4181fc644a (patch)
treed7177f5bb1b21a1a1d533bb9cddadfd528118985 /bitsandbytes/autograd/_functions.py
parente9b87112eeaabe3dfb51bdf553abbb94d9093870 (diff)
change order
Diffstat (limited to 'bitsandbytes/autograd/_functions.py')
-rw-r--r--bitsandbytes/autograd/_functions.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 538267b..34b27d9 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -357,6 +357,11 @@ class MatMul8bitLt(torch.autograd.Function):
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
state = ctx.state
+ grad_A = grad_B = grad_bias = None
+
+ if req_gradBias:
+ # compute grad_bias first before changing grad_output dtype
+ grad_bias = grad_output.sum(0)
# Cast grad_output to fp16
grad_output_dtype = grad_output.dtype
@@ -367,8 +372,6 @@ class MatMul8bitLt(torch.autograd.Function):
-1, grad_output.shape[-1]
).contiguous()
- grad_A = grad_B = grad_bias = None
-
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
if req_gradB:
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
@@ -395,9 +398,6 @@ class MatMul8bitLt(torch.autograd.Function):
else:
raise Exception('State must contain either CBt or CB matrix for backward')
- if req_gradBias:
- grad_bias = grad_output.sum(0)
-
# Cast grad_A back to grad_output_dtype
grad_output = grad_output.to(grad_output_dtype)