summaryrefslogtreecommitdiff
path: root/bitsandbytes/autograd
diff options
context:
space:
mode:
authorjustheuristic <justheuristic@gmail.com>2022-09-17 23:51:28 +0300
committerjustheuristic <justheuristic@gmail.com>2022-09-17 23:51:28 +0300
commite9b87112eeaabe3dfb51bdf553abbb94d9093870 (patch)
tree00b0c90ddeeb39d52028f571952188197a963202 /bitsandbytes/autograd
parent56a074f6dc50ae923e7a810b7c2ca53cd2f6129e (diff)
un-fuse bias
Diffstat (limited to 'bitsandbytes/autograd')
-rw-r--r--bitsandbytes/autograd/_functions.py10
1 files changed, 6 insertions, 4 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 7293637..538267b 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -316,15 +316,14 @@ class MatMul8bitLt(torch.autograd.Function):
if bias is None or bias.dtype == torch.float16:
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
- output = output.to(A_dtype)
-
+ delayed_bias = None
else: # apply bias separately
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
- output = output.to(A_dtype).add_(bias)
+ delayed_bias = bias
# 4. Mixed-precision decomposition matmul
if coo_tensorA is not None and subA is not None:
- output += torch.matmul(subA, state.subB)
+ output.addmm_(subA, state.subB)
# 5. Save state
ctx.state = state
@@ -341,6 +340,9 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None)
+ output = output.to(A_dtype)
+ if delayed_bias is not None:
+ output.add_(delayed_bias)
clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
return clone_func(output.view(output_shape))