summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjustheuristic <justheuristic@gmail.com>2022-09-17 20:46:04 +0300
committerjustheuristic <justheuristic@gmail.com>2022-09-17 20:46:04 +0300
commitcc4858c2fd48ef17a888b9d45bb35bb00e373eb8 (patch)
treee452b1e1ccbebd58e77cd4b0f6429366c670cd44
parent3634fc738bc20e4041c75544d3f678f61ce2348c (diff)
some kind of warning or something when this is first executed to make people aware that a cast happens and the operation quantization is performed in fp16.
-rw-r--r--bitsandbytes/autograd/_functions.py4
1 files changed, 4 insertions, 0 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index bdcbec5..6d473e9 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -1,4 +1,6 @@
import operator
+import warnings
+
import torch
import bitsandbytes.functional as F
@@ -229,6 +231,8 @@ class MatMul8bitLt(torch.autograd.Function):
# Cast A to fp16
A_dtype = A.dtype
+ if A_dtype != torch.float16:
+ warnings.warn(f"MatMul8bitLt: temporarily casting input matrix from {A_dtype} to float16")
A = A.to(torch.float16)
# 1. Quantize A