From 56a074f6dc50ae923e7a810b7c2ca53cd2f6129e Mon Sep 17 00:00:00 2001
From: justheuristic
Date: Sat, 17 Sep 2022 23:46:37 +0300
Subject: un-fuse bias

---
 bitsandbytes/autograd/_functions.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

(limited to 'bitsandbytes')

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 540d1ec..7293637 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -314,10 +314,13 @@ class MatMul8bitLt(torch.autograd.Function):
         out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
 
         # we apply the fused bias here
-        fused_bias = bias if bias.dtype == torch.float16 else None
-        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=fused_bias)
-        if fused_bias is None and bias is not None:
-            output.add_(bias.to(output.dtype))
+        if bias is None or bias.dtype == torch.float16:
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+            output = output.to(A_dtype)
+
+        else:  # apply bias separately
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
+            output = output.to(A_dtype).add_(bias)
 
         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
@@ -338,8 +341,6 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)
 
-        # Cast fp16 output back to A.dtype
-        output = output.to(A_dtype)
-
         clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
         return clone_func(output.view(output_shape))
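A minimal standalone sketch (not part of the patch) of the bias handling this commit introduces. Here `dequant_stub` is a hypothetical stand-in for `F.mm_dequant`, and `matmul_output` mirrors the new branch: an fp16 (or absent) bias stays fused into dequantization, while any other bias is applied only after the output is cast back to the input dtype.

```python
import torch

def dequant_stub(out32: torch.Tensor, bias) -> torch.Tensor:
    # Hypothetical stand-in for F.mm_dequant: pretend-dequantize the
    # int32 accumulator to fp16, fusing an fp16 bias if one is given.
    out = out32.to(torch.float16) * 1e-4  # fake scaling, for illustration only
    if bias is not None:
        out = out + bias
    return out

def matmul_output(out32: torch.Tensor, bias, A_dtype: torch.dtype) -> torch.Tensor:
    # Mirrors the branch added by this patch.
    if bias is None or bias.dtype == torch.float16:
        # fp16 (or no) bias: fuse it into dequantization, then cast.
        output = dequant_stub(out32, bias)
        output = output.to(A_dtype)
    else:  # apply bias separately
        # non-fp16 bias: dequantize without bias, cast to the input
        # dtype first, then add the bias in that dtype ("un-fused").
        output = dequant_stub(out32, None)
        output = output.to(A_dtype).add_(bias)
    return output

# An fp32 input dtype with an fp32 bias takes the un-fused path:
out32 = torch.randint(-1000, 1000, (4, 8), dtype=torch.int32)
bias = torch.randn(8, dtype=torch.float32)
result = matmul_output(out32, bias, torch.float32)
assert result.dtype == torch.float32
```

This matches the intent of the change: the old code silently dropped non-fp16 biases out of the fused path and added them in the matmul's fp16 output dtype, whereas the new code adds them after the cast, in the input's own precision.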