From e9b87112eeaabe3dfb51bdf553abbb94d9093870 Mon Sep 17 00:00:00 2001
From: justheuristic
Date: Sat, 17 Sep 2022 23:51:28 +0300
Subject: un-fuse bias

---
 bitsandbytes/autograd/_functions.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 7293637..538267b 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -316,15 +316,14 @@ class MatMul8bitLt(torch.autograd.Function):
 
         if bias is None or bias.dtype == torch.float16:
             output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
-            output = output.to(A_dtype)
-
+            delayed_bias = None
         else:  # apply bias separately
             output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
-            output = output.to(A_dtype).add_(bias)
+            delayed_bias = bias
 
         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
-            output += torch.matmul(subA, state.subB)
+            output.addmm_(subA, state.subB)
 
         # 5. Save state
         ctx.state = state
@@ -341,6 +340,9 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)
 
+        output = output.to(A_dtype)
+        if delayed_bias is not None:
+            output.add_(delayed_bias)
         clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
         return clone_func(output.view(output_shape))
 
--
cgit v1.2.3
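
For orientation, below is a minimal, self-contained sketch of the "un-fused bias" flow this commit switches to: dequantize the int32 matmul result without the bias, cast to the activation dtype, and only then add the bias in place. The helper name, the scale arguments, and the rescaling formula are illustrative stand-ins for F.mm_dequant, not the bitsandbytes API.

    import torch

    def dequant_then_bias(out32, row_scale, col_scale, A_dtype, bias=None):
        # Stand-in for F.mm_dequant(..., bias=None): rescale the int32
        # accumulator by hypothetical per-row / per-column quantization scales.
        output = out32.float() * (row_scale[:, None] * col_scale[None, :]) / (127.0 * 127.0)
        # "Un-fused" bias: cast to the activation dtype first, then add the
        # bias in place (works even when bias.dtype != torch.float16).
        output = output.to(A_dtype)
        if bias is not None:
            output.add_(bias)
        return output

    # Example: fp16 activations with an fp32 bias, the case the else-branch handles.
    out32 = torch.randint(-(2**15), 2**15, (4, 8), dtype=torch.int32)
    y = dequant_then_bias(out32, torch.rand(4), torch.rand(8), torch.float16, bias=torch.randn(8))

The point of delaying the add is visible in the second hunk: the bias is now applied after both the cast to A_dtype and the mixed-precision decomposition addmm_, so both float16 and non-float16 biases go through the same single code path at the end of forward().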