author     justheuristic <justheuristic@gmail.com>    2022-09-17 23:44:28 +0300
committer  justheuristic <justheuristic@gmail.com>    2022-09-17 23:44:28 +0300
commit     d9ca0ed9051a21295e9be80ec08a6589ebd98222 (patch)
tree       979c64dc4e3c84df837f2955178d4d4ad75d2e43 /bitsandbytes
parent     7facedda38da928843e9ed0de1810d45ce1b9224 (diff)
un-fuse bias
Diffstat (limited to 'bitsandbytes')
-rw-r--r--  bitsandbytes/autograd/_functions.py | 8
1 files changed, 5 insertions, 3 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 6d9229b..540d1ec 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -234,8 +234,6 @@ class MatMul8bitLt(torch.autograd.Function):
         if A_dtype != torch.float16:
             warnings.warn(f"MatMul8bitLt: input matrix will be converted from {A_dtype} to float16")
             A = A.to(torch.float16)
-            if bias is not None:
-                bias = bias.to(torch.float16)
 
         # 1. Quantize A
         if len(A.shape) == 3:
@@ -315,7 +313,11 @@
         C32A, SA = F.transform(CA, "col32")
         out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
         # we apply the fused bias here
-        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+
+        fused_bias = bias if bias.dtype == torch.float16 else None
+        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=fused_bias)
+        if fused_bias is None and bias is not None:
+            output.add_(bias.to(output.dtype))
 
         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
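
In short, the patch stops force-casting the bias to float16 before quantization: the bias is passed into the fused dequantization step only when it is already float16, and any other bias is added to the output afterwards in the output's dtype. The sketch below illustrates that pattern outside the library; it is a minimal stand-in, not the bitsandbytes implementation. dequant_matmul is a hypothetical placeholder for the fused F.mm_dequant call, the per-tensor scales are simplified, and the extra "bias is not None" guard (not present in the diff) is an assumption covering the no-bias case.

import torch

def dequant_matmul(out32, scale_a, scale_b, bias=None):
    # Hypothetical stand-in for the fused dequantization step (F.mm_dequant in the diff):
    # rescale the int32 accumulator to float16 and, if given, add a float16 bias in the same pass.
    output = (out32.float() * scale_a * scale_b).to(torch.float16)
    if bias is not None:
        output = output + bias
    return output

def dequant_with_unfused_bias(out32, scale_a, scale_b, bias=None):
    # Fuse the bias only when it is already float16; otherwise keep it out of the
    # fused step and add it afterwards in the output dtype (the "un-fuse bias" change).
    fused_bias = bias if bias is not None and bias.dtype == torch.float16 else None
    output = dequant_matmul(out32, scale_a, scale_b, bias=fused_bias)
    if fused_bias is None and bias is not None:
        output = output + bias.to(output.dtype)
    return output

# Usage: a float32 bias is no longer cast to float16 up front; it is applied after dequantization.
out32 = torch.randint(-1000, 1000, (4, 8), dtype=torch.int32)
bias_fp32 = torch.randn(8, dtype=torch.float32)
result = dequant_with_unfused_bias(out32, 1e-3, 1e-3, bias=bias_fp32)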