author    | justheuristic <justheuristic@gmail.com> | 2022-09-17 23:46:37 +0300
committer | justheuristic <justheuristic@gmail.com> | 2022-09-17 23:46:37 +0300
commit    | 56a074f6dc50ae923e7a810b7c2ca53cd2f6129e (patch)
tree      | 3a90dc00ddf1b8ff0258ad674940ec317ae0479d /bitsandbytes
parent    | d9ca0ed9051a21295e9be80ec08a6589ebd98222 (diff)
un-fuse bias
Diffstat (limited to 'bitsandbytes')
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 13
1 file changed, 7 insertions(+), 6 deletions(-)
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 540d1ec..7293637 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -314,10 +314,13 @@ class MatMul8bitLt(torch.autograd.Function):
         out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)

         # we apply the fused bias here
-        fused_bias = bias if bias.dtype == torch.float16 else None
-        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=fused_bias)
-        if fused_bias is None and bias is not None:
-            output.add_(bias.to(output.dtype))
+        if bias is None or bias.dtype == torch.float16:
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+            output = output.to(A_dtype)
+
+        else:  # apply bias separately
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
+            output = output.to(A_dtype).add_(bias)

         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
@@ -338,8 +341,6 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)

-        # Cast fp16 output back to A.dtype
-        output = output.to(A_dtype)
-
        clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
        return clone_func(output.view(output_shape))
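
For context, below is a minimal, self-contained sketch of the control flow this commit introduces. dequant_matmul, apply_bias, and the toy scales are hypothetical stand-ins, not the library's API: the real F.mm_dequant dequantizes the int32 result of igemmlt using per-row/per-column quantization statistics, and its fused-bias path only handles float16.

    import torch

    def dequant_matmul(out32, row_scale, col_scale, bias=None):
        # Hypothetical stand-in for F.mm_dequant: rescale the int32 GEMM
        # accumulator to float16 and optionally fuse in a float16 bias.
        output = (out32.float() * (row_scale * col_scale)).to(torch.float16)
        if bias is not None:
            output = output + bias
        return output

    def apply_bias(out32, row_scale, col_scale, bias, A_dtype):
        # Mirrors the branch added by this commit.
        if bias is None or bias.dtype == torch.float16:
            # fp16 bias can be fused into the dequantization step;
            # the result is then cast back to the input dtype
            output = dequant_matmul(out32, row_scale, col_scale, bias=bias)
            output = output.to(A_dtype)
        else:
            # any other bias dtype: dequantize without bias, cast first,
            # then add the bias in its own precision
            output = dequant_matmul(out32, row_scale, col_scale, bias=None)
            output = output.to(A_dtype).add_(bias)
        return output

    # toy usage: int32 accumulator from an int8 matmul, bfloat16 bias
    out32 = torch.randint(-2**15, 2**15, (4, 8), dtype=torch.int32)
    bias = torch.randn(8, dtype=torch.bfloat16)
    y = apply_bias(out32, row_scale=1e-4, col_scale=1e-4,
                   bias=bias, A_dtype=torch.bfloat16)

Two things fall out of un-fusing the bias this way. A non-fp16 bias is now added after the cast to A_dtype, in its own precision, instead of being downcast to fp16 as in the removed output.add_(bias.to(output.dtype)) line. And since both branches already cast to A_dtype, the blanket "Cast fp16 output back to A.dtype" step removed in the second hunk had become redundant.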