diff options
author | justheuristic <justheuristic@gmail.com> | 2022-09-18 00:35:42 +0300 |
---|---|---|
committer | justheuristic <justheuristic@gmail.com> | 2022-09-18 00:35:42 +0300 |
commit | cbfdf0b5efe4923ba4533c274ce83072b7e502b5 (patch) | |
tree | 824dc9ea7a8d7d6e5b4b48184ecdef6da8207339 | |
parent | e35e2c665a69647d829c48e22fba0230180c11e7 (diff) |
cast edge case
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 36c392b..d0e48b7 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -212,9 +212,9 @@ class MatMul8bitLt(torch.autograd.Function): ctx.B = B ctx.bias = bias if A.shape[-1] == B.shape[0]: - return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device) + return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device) else: - return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device) + return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device) # 1. Quantize A # 2. Quantize B |