summaryrefslogtreecommitdiff
path: root/bitsandbytes/autograd
diff options
context:
space:
mode:
authorjustheuristic <justheuristic@gmail.com>2022-09-18 00:35:42 +0300
committerjustheuristic <justheuristic@gmail.com>2022-09-18 00:35:42 +0300
commitcbfdf0b5efe4923ba4533c274ce83072b7e502b5 (patch)
tree824dc9ea7a8d7d6e5b4b48184ecdef6da8207339 /bitsandbytes/autograd
parente35e2c665a69647d829c48e22fba0230180c11e7 (diff)
cast edge case
Diffstat (limited to 'bitsandbytes/autograd')
-rw-r--r--bitsandbytes/autograd/_functions.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 36c392b..d0e48b7 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -212,9 +212,9 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.B = B
ctx.bias = bias
if A.shape[-1] == B.shape[0]:
- return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
+ return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device)
else:
- return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
+ return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device)
# 1. Quantize A
# 2. Quantize B