summaryrefslogtreecommitdiff
path: root/bitsandbytes/autograd
diff options
context:
space:
mode:
authorjustheuristic <justheuristic@gmail.com>2022-09-17 23:38:09 +0300
committerjustheuristic <justheuristic@gmail.com>2022-09-17 23:38:09 +0300
commiteac9aca460ee7afb6d0cbc61ae43a95120d34f29 (patch)
treeec0dd8c9865adbf14769b50fc85e2cd8bf26fc16 /bitsandbytes/autograd
parenta9fe0ff98c3293d972eb7a638b9887df0bc0d30d (diff)
cast bias too
Diffstat (limited to 'bitsandbytes/autograd')
-rw-r--r--bitsandbytes/autograd/_functions.py2
1 file changed, 2 insertions, 0 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index dc79bb1..6d9229b 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -234,6 +234,8 @@ class MatMul8bitLt(torch.autograd.Function):
if A_dtype != torch.float16:
warnings.warn(f"MatMul8bitLt: input matrix will be converted from {A_dtype} to float16")
A = A.to(torch.float16)
+ if bias is not None:
+ bias = bias.to(torch.float16)
# 1. Quantize A
if len(A.shape) == 3: