summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjustheuristic <justheuristic@gmail.com>2022-09-18 00:07:05 +0300
committerjustheuristic <justheuristic@gmail.com>2022-09-18 00:07:05 +0300
commite2b523d071c1dfe70c274a7ff945e859bc8f9e02 (patch)
tree26253b43427c54b0d1d76ec734026a29a17fe8e0
parent85bf5294a60ceba84b85f0634b349bc486cec635 (diff)
change typecast behavior
-rw-r--r--bitsandbytes/autograd/_functions.py12
1 files changed, 5 insertions, 7 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 2aada07..6868b75 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -230,16 +230,14 @@ class MatMul8bitLt(torch.autograd.Function):
state.outlier_pool = GlobalOutlierPooler.get_instance()
# Cast A to fp16
- A_dtype = A.dtype
- if A_dtype != torch.float16:
- warnings.warn(f"MatMul8bitLt: input matrix will be converted from {A_dtype} to float16")
- A = A.to(torch.float16)
+ if A.dtype != torch.float16:
+ warnings.warn(f"MatMul8bitLt: input matrix will be cast from {A.dtype} to float16")
# 1. Quantize A
if len(A.shape) == 3:
A = A.view(-1, A.shape[-1]).contiguous()
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
- A, threshold=state.threshold
+ A.to(torch.float16), threshold=state.threshold
)
if state.threshold > 0.0 and coo_tensorA is not None:
@@ -316,10 +314,10 @@ class MatMul8bitLt(torch.autograd.Function):
if bias is None or bias.dtype == torch.float16:
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
- output = output.to(A_dtype)
+ output = output.to(A.dtype)
else: # apply bias separately
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
- output = output.to(A_dtype).add_(bias)
+ output = output.to(A.dtype).add_(bias)
# 4. Mixed-precision decomposition matmul
if coo_tensorA is not None and subA is not None: