author    justheuristic <justheuristic@gmail.com>    2022-09-17 23:59:36 +0300
committer justheuristic <justheuristic@gmail.com>    2022-09-17 23:59:36 +0300
commit    647c976a74249d284b31e8403dfcbcbfa3e203a3 (patch)
tree      082bc2344f5f2a8fa1faad8cce060c522d732d34
parent    0de1a4494bd9246e5b1b3f2c7a0e4d4181fc644a (diff)
change order
-rw-r--r--  bitsandbytes/autograd/_functions.py  7
1 file changed, 2 insertions(+), 5 deletions(-)
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 34b27d9..25ff1a5 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -316,10 +316,10 @@ class MatMul8bitLt(torch.autograd.Function):
         if bias is None or bias.dtype == torch.float16:
             output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
-            delayed_bias = None
+            output = output.to(A_dtype)
         else:  # apply bias separately
             output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
-            delayed_bias = bias
+            output = output.to(A_dtype).add_(bias)

         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
@@ -340,9 +340,6 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)
-        output = output.to(A_dtype)
-        if delayed_bias is not None:
-            output.add_(delayed_bias)
         clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
         return clone_func(output.view(output_shape))
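
Net effect of the patch: the delayed_bias bookkeeping is removed, and the cast to A's dtype plus the bias addition now happen immediately after dequantization, i.e. before the mixed-precision decomposition step rather than after it. Below is a minimal sketch of the reordered epilogue, reusing the names from the diff context (F.mm_dequant, out32, Sout32, SCA, state.SCB, A_dtype, coo_tensorA, subA, state.subB); the body of step 4 is an assumption based on the surrounding file and is not part of this diff:

    if bias is None or bias.dtype == torch.float16:
        # An fp16 bias can be fused into the dequantization call itself.
        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
        output = output.to(A_dtype)
    else:
        # Other bias dtypes: dequantize without bias, cast, then add in place.
        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
        output = output.to(A_dtype).add_(bias)

    # 4. Mixed-precision decomposition matmul (assumed accumulation step):
    # the outlier product is now added to an output that is already in A's
    # dtype and already carries the bias.
    if coo_tensorA is not None and subA is not None:
        output += torch.matmul(subA, state.subB)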