diff options
author | justheuristic <justheuristic@gmail.com> | 2022-09-17 23:53:49 +0300 |
---|---|---|
committer | justheuristic <justheuristic@gmail.com> | 2022-09-17 23:53:49 +0300 |
commit | 0de1a4494bd9246e5b1b3f2c7a0e4d4181fc644a (patch) | |
tree | d7177f5bb1b21a1a1d533bb9cddadfd528118985 /bitsandbytes/autograd | |
parent | e9b87112eeaabe3dfb51bdf553abbb94d9093870 (diff) |
change order
Diffstat (limited to 'bitsandbytes/autograd')
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 538267b..34b27d9 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -357,6 +357,11 @@ class MatMul8bitLt(torch.autograd.Function): SCAt, idx = ctx.tensor_states formatB = ctx.formatB state = ctx.state + grad_A = grad_B = grad_bias = None + + if req_gradBias: + # compute grad_bias first before changing grad_output dtype + grad_bias = grad_output.sum(0) # Cast grad_output to fp16 grad_output_dtype = grad_output.dtype @@ -367,8 +372,6 @@ class MatMul8bitLt(torch.autograd.Function): -1, grad_output.shape[-1] ).contiguous() - grad_A = grad_B = grad_bias = None - Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output) if req_gradB: CxAt, SAt = F.transform(CAt, formatB, transpose=True) @@ -395,9 +398,6 @@ class MatMul8bitLt(torch.autograd.Function): else: raise Exception('State must contain either CBt or CB matrix for backward') - if req_gradBias: - grad_bias = grad_output.sum(0) - # Cast grad_A back to grad_output_dtype grad_output = grad_output.to(grad_output_dtype) |