diff options
author | justheuristic <justheuristic@gmail.com> | 2022-09-17 23:59:36 +0300 |
---|---|---|
committer | justheuristic <justheuristic@gmail.com> | 2022-09-17 23:59:36 +0300 |
commit | 647c976a74249d284b31e8403dfcbcbfa3e203a3 (patch) | |
tree | 082bc2344f5f2a8fa1faad8cce060c522d732d34 | |
parent | 0de1a4494bd9246e5b1b3f2c7a0e4d4181fc644a (diff) |
change order
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 7 |
1 file changed, 2 insertions(+), 5 deletions(-)
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 34b27d9..25ff1a5 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -316,10 +316,10 @@ class MatMul8bitLt(torch.autograd.Function):

         if bias is None or bias.dtype == torch.float16:
             output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
-            delayed_bias = None
+            output = output.to(A_dtype)
         else:  # apply bias separately
             output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
-            delayed_bias = bias
+            output = output.to(A_dtype).add_(bias)

         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
@@ -340,9 +340,6 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)

-        output = output.to(A_dtype)
-        if delayed_bias is not None:
-            output.add_(delayed_bias)
         clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
         return clone_func(output.view(output_shape))