diff options
author | justheuristic <justheuristic@gmail.com> | 2022-09-17 23:44:28 +0300 |
---|---|---|
committer | justheuristic <justheuristic@gmail.com> | 2022-09-17 23:44:28 +0300 |
commit | d9ca0ed9051a21295e9be80ec08a6589ebd98222 (patch) | |
tree | 979c64dc4e3c84df837f2955178d4d4ad75d2e43 /bitsandbytes | |
parent | 7facedda38da928843e9ed0de1810d45ce1b9224 (diff) |
un-fuse bias
Diffstat (limited to 'bitsandbytes')
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 6d9229b..540d1ec 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -234,8 +234,6 @@ class MatMul8bitLt(torch.autograd.Function): if A_dtype != torch.float16: warnings.warn(f"MatMul8bitLt: input matrix will be converted from {A_dtype} to float16") A = A.to(torch.float16) - if bias is not None: - bias = bias.to(torch.float16) # 1. Quantize A if len(A.shape) == 3: @@ -315,7 +313,11 @@ class MatMul8bitLt(torch.autograd.Function): C32A, SA = F.transform(CA, "col32") out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) # we apply the fused bias here - output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) + + fused_bias = bias if bias.dtype == torch.float16 else None + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=fused_bias) + if fused_bias is None and bias is not None: + output.add_(bias.to(output.dtype)) # 4. Mixed-precision decomposition matmul if coo_tensorA is not None and subA is not None: |