summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bitsandbytes/autograd/_functions.py8
-rw-r--r--tests/test_autograd.py2
2 files changed, 6 insertions, 4 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 6d9229b..540d1ec 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -234,8 +234,6 @@ class MatMul8bitLt(torch.autograd.Function):
if A_dtype != torch.float16:
warnings.warn(f"MatMul8bitLt: input matrix will be converted from {A_dtype} to float16")
A = A.to(torch.float16)
- if bias is not None:
- bias = bias.to(torch.float16)
# 1. Quantize A
if len(A.shape) == 3:
@@ -315,7 +313,11 @@ class MatMul8bitLt(torch.autograd.Function):
C32A, SA = F.transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
# we apply the fused bias here
- output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+
+ fused_bias = bias if bias.dtype == torch.float16 else None
+ output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=fused_bias)
+ if fused_bias is None and bias is not None:
+ output.add_(bias.to(output.dtype))
# 4. Mixed-precision decomposition matmul
if coo_tensorA is not None and subA is not None:
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index 28d9259..5171c4f 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -427,4 +427,4 @@ def test_matmullt(
)
if req_grad[2]:
- torch.testing.assert_allclose(gradBias1, gradBias2, atol=0.18, rtol=0.3)
+ torch.testing.assert_allclose(gradBias1, gradBias2)