From e2b523d071c1dfe70c274a7ff945e859bc8f9e02 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:07:05 +0300 Subject: change typecast behavior --- bitsandbytes/autograd/_functions.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) (limited to 'bitsandbytes/autograd') diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 2aada07..6868b75 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -230,16 +230,14 @@ class MatMul8bitLt(torch.autograd.Function): state.outlier_pool = GlobalOutlierPooler.get_instance() # Cast A to fp16 - A_dtype = A.dtype - if A_dtype != torch.float16: - warnings.warn(f"MatMul8bitLt: input matrix will be converted from {A_dtype} to float16") - A = A.to(torch.float16) + if A.dtype != torch.float16: + warnings.warn(f"MatMul8bitLt: input matrix will be cast from {A.dtype} to float16") # 1. Quantize A if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous() CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant( - A, threshold=state.threshold + A.to(torch.float16), threshold=state.threshold ) if state.threshold > 0.0 and coo_tensorA is not None: @@ -316,10 +314,10 @@ class MatMul8bitLt(torch.autograd.Function): if bias is None or bias.dtype == torch.float16: output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) - output = output.to(A_dtype) + output = output.to(A.dtype) else: # apply bias separately output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) - output = output.to(A_dtype).add_(bias) + output = output.to(A.dtype).add_(bias) # 4. Mixed-precision decomposition matmul if coo_tensorA is not None and subA is not None: -- cgit v1.2.3