author     justheuristic <justheuristic@gmail.com>  2022-09-17 23:51:28 +0300
committer  justheuristic <justheuristic@gmail.com>  2022-09-17 23:51:28 +0300
commit     e9b87112eeaabe3dfb51bdf553abbb94d9093870 (patch)
tree       00b0c90ddeeb39d52028f571952188197a963202 /bitsandbytes
parent     56a074f6dc50ae923e7a810b7c2ca53cd2f6129e (diff)
un-fuse bias
Diffstat (limited to 'bitsandbytes')
-rw-r--r--  bitsandbytes/autograd/_functions.py  10
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 7293637..538267b 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -316,15 +316,14 @@ class MatMul8bitLt(torch.autograd.Function):
 
         if bias is None or bias.dtype == torch.float16:
             output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
-            output = output.to(A_dtype)
-
+            delayed_bias = None
         else:  # apply bias separately
             output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
-            output = output.to(A_dtype).add_(bias)
+            delayed_bias = bias
 
         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
-            output += torch.matmul(subA, state.subB)
+            output.addmm_(subA, state.subB)
 
         # 5. Save state
         ctx.state = state
@@ -341,6 +340,9 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)
 
+        output = output.to(A_dtype)
+        if delayed_bias is not None:
+            output.add_(delayed_bias)
         clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
         return clone_func(output.view(output_shape))
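The commit makes two changes: the bias is no longer fused into mm_dequant (it is now applied in place after the cast back to the input dtype), and the mixed-precision correction term is accumulated with in-place addmm_ instead of an out-of-place matmul plus add. Below is a minimal sketch, not part of the commit, showing why the two matmul forms are equivalent and how the delayed-bias path behaves; all tensor names, shapes, and values are invented stand-ins for the variables in MatMul8bitLt.forward.

    import torch

    # Hypothetical stand-ins; shapes are arbitrary.
    output = torch.randn(4, 8)   # dequantized int8 matmul result
    subA = torch.randn(4, 3)     # mixed-precision decomposition inputs
    subB = torch.randn(3, 8)
    bias = torch.randn(8)
    A_dtype = torch.float32      # dtype of the original input A

    # Old path: out-of-place matmul allocates a temporary for subA @ subB.
    reference = output + torch.matmul(subA, subB)

    # New path: addmm_ accumulates the product into `output` in place
    # (output = output + subA @ subB), skipping the temporary.
    output.addmm_(subA, subB)
    torch.testing.assert_close(output, reference)

    # Un-fused bias: cast to the input dtype first, then add the bias
    # in place, mirroring the delayed_bias branch in the diff.
    output = output.to(A_dtype)
    output.add_(bias)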