From 647c976a74249d284b31e8403dfcbcbfa3e203a3 Mon Sep 17 00:00:00 2001
From: justheuristic
Date: Sat, 17 Sep 2022 23:59:36 +0300
Subject: change order

---
 bitsandbytes/autograd/_functions.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 34b27d9..25ff1a5 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -316,10 +316,10 @@ class MatMul8bitLt(torch.autograd.Function):
 
         if bias is None or bias.dtype == torch.float16:
             output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
-            delayed_bias = None
+            output = output.to(A_dtype)
         else:  # apply bias separately
             output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
-            delayed_bias = bias
+            output = output.to(A_dtype).add_(bias)
 
         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
@@ -340,9 +340,6 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)
 
-        output = output.to(A_dtype)
-        if delayed_bias is not None:
-            output.add_(delayed_bias)
         clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
         return clone_func(output.view(output_shape))
--
cgit v1.2.3
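
Note: the patch moves the cast to the input dtype and the bias addition from the end of forward() up to the dequantization step, which removes the delayed_bias bookkeeping. Below is a minimal self-contained sketch of the resulting control flow; it is illustrative only: dequant_ref stands in for F.mm_dequant, and finish_output is a hypothetical helper, neither is part of the bitsandbytes API.

import torch

def dequant_ref(out32: torch.Tensor, scale: float, bias=None) -> torch.Tensor:
    # Stand-in for F.mm_dequant: int32 accumulator -> fp16, optionally fusing a bias.
    out = (out32.float() * scale).half()
    return out if bias is None else out + bias

def finish_output(out32, scale, bias, A_dtype):
    if bias is None or bias.dtype == torch.float16:
        # An fp16 bias can be fused into dequantization, then cast once.
        output = dequant_ref(out32, scale, bias=bias)
        output = output.to(A_dtype)
    else:  # apply bias separately, after casting back to the input dtype
        output = dequant_ref(out32, scale, bias=None)
        output = output.to(A_dtype).add_(bias)
    return output

# Usage example: a non-fp16 bias takes the "apply separately" branch.
out32 = torch.randint(-128, 128, (2, 4), dtype=torch.int32)
print(finish_output(out32, 0.01, torch.zeros(4), torch.float32))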