summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bitsandbytes/autograd/_functions.py4
1 files changed, 4 insertions, 0 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index bdcbec5..6d473e9 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -1,4 +1,6 @@
import operator
+import warnings
+
import torch
import bitsandbytes.functional as F
@@ -229,6 +231,8 @@ class MatMul8bitLt(torch.autograd.Function):
# Cast A to fp16
A_dtype = A.dtype
+ if A_dtype != torch.float16:
+ warnings.warn(f"MatMul8bitLt: temporarily casting input matrix from {A_dtype} to float16")
A = A.to(torch.float16)
# 1. Quantize A