path: root/bitsandbytes/autograd/_functions.py
author    justheuristic <justheuristic@gmail.com>  2022-09-17 23:46:37 +0300
committer justheuristic <justheuristic@gmail.com>  2022-09-17 23:46:37 +0300
commit    56a074f6dc50ae923e7a810b7c2ca53cd2f6129e (patch)
tree      3a90dc00ddf1b8ff0258ad674940ec317ae0479d /bitsandbytes/autograd/_functions.py
parent    d9ca0ed9051a21295e9be80ec08a6589ebd98222 (diff)
un-fuse bias
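This commit makes the bias handling explicit: the bias is fused into F.mm_dequant only when it is fp16 (or absent), and otherwise added in-place after the output has been cast back to A_dtype, which also makes the trailing cast at the end of forward redundant. As a side effect, the new branch tolerates bias=None, which the old fused_bias expression would have rejected with an AttributeError. A runnable sketch of the new dispatch follows the diff below.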
Diffstat (limited to 'bitsandbytes/autograd/_functions.py')
-rw-r--r--  bitsandbytes/autograd/_functions.py | 13 +++++++------
1 file changed, 7 insertions(+), 6 deletions(-)
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 540d1ec..7293637 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -314,10 +314,13 @@ class MatMul8bitLt(torch.autograd.Function):
         out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
         # we apply the fused bias here
-        fused_bias = bias if bias.dtype == torch.float16 else None
-        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=fused_bias)
-        if fused_bias is None and bias is not None:
-            output.add_(bias.to(output.dtype))
+        if bias is None or bias.dtype == torch.float16:
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+            output = output.to(A_dtype)
+
+        else: # apply bias separately
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
+            output = output.to(A_dtype).add_(bias)
         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
@@ -338,8 +341,6 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)
-        # Cast fp16 output back to A.dtype
-        output = output.to(A_dtype)
         clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
         return clone_func(output.view(output_shape))
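A minimal, self-contained sketch of the dispatch introduced above, using a stand-in for F.mm_dequant (the real kernel lives in bitsandbytes.functional and operates on int8 GPU data; the helper names and the scaling scheme here are simplified illustrations, not the library's API):

import torch

def mm_dequant_sketch(out32, row_scale, col_scale, bias=None):
    # Stand-in for bitsandbytes' F.mm_dequant: rescale the int32 accumulator
    # of an int8 matmul back to fp16, optionally fusing an fp16 bias into
    # the same pass. The per-row/per-column scaling here is simplified.
    output = (out32.float() * row_scale * col_scale).to(torch.float16)
    if bias is not None:
        output = output + bias
    return output

def dequant_with_unfused_bias(out32, row_scale, col_scale, bias, A_dtype):
    # Mirrors the dispatch from this commit: fuse the bias into the
    # dequantization only when it is already fp16 (or absent); otherwise
    # dequantize without it, cast to the input dtype, then add the bias
    # in place. As in the original forward pass, the un-fused branch
    # assumes bias.dtype matches A_dtype.
    if bias is None or bias.dtype == torch.float16:
        output = mm_dequant_sketch(out32, row_scale, col_scale, bias=bias)
        output = output.to(A_dtype)
    else:  # apply bias separately
        output = mm_dequant_sketch(out32, row_scale, col_scale, bias=None)
        output = output.to(A_dtype).add_(bias)
    return output

# Hypothetical example: a 4x8 int32 accumulator with per-row/per-column
# scales and a float32 bias, so the un-fused branch is taken.
out32 = torch.randint(-1000, 1000, (4, 8), dtype=torch.int32)
row_scale = torch.rand(4, 1)
col_scale = torch.rand(1, 8)
bias = torch.randn(8)
y = dequant_with_unfused_bias(out32, row_scale, col_scale, bias, torch.float32)
print(y.dtype)  # torch.float32

Fusing an fp16 bias saves a separate elementwise pass over the output, while for other dtypes casting the output first lets the bias be added without the extra bias.to(output.dtype) copy that the old code performed.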