diff options
Diffstat (limited to 'bitsandbytes')
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 7 |
1 files changed, 2 insertions, 5 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 34b27d9..25ff1a5 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -316,10 +316,10 @@ class MatMul8bitLt(torch.autograd.Function): if bias is None or bias.dtype == torch.float16: output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) - delayed_bias = None + output = output.to(A_dtype) else: # apply bias separately output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) - delayed_bias = bias + output = output.to(A_dtype).add_(bias) # 4. Mixed-precision decomposition matmul if coo_tensorA is not None and subA is not None: @@ -340,9 +340,6 @@ class MatMul8bitLt(torch.autograd.Function): ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) - output = output.to(A_dtype) - if delayed_bias is not None: - output.add_(delayed_bias) clone_func = torch.clone if len(output_shape) == 3 else lambda x : x return clone_func(output.view(output_shape)) |