-rw-r--r-- | bitsandbytes/autograd/_functions.py | 12 |
1 files changed, 5 insertions, 7 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 2aada07..6868b75 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -230,16 +230,14 @@ class MatMul8bitLt(torch.autograd.Function):
             state.outlier_pool = GlobalOutlierPooler.get_instance()

         # Cast A to fp16
-        A_dtype = A.dtype
-        if A_dtype != torch.float16:
-            warnings.warn(f"MatMul8bitLt: input matrix will be converted from {A_dtype} to float16")
-            A = A.to(torch.float16)
+        if A.dtype != torch.float16:
+            warnings.warn(f"MatMul8bitLt: input matrix will be cast from {A.dtype} to float16")

         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.view(-1, A.shape[-1]).contiguous()
         CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
-            A, threshold=state.threshold
+            A.to(torch.float16), threshold=state.threshold
         )

         if state.threshold > 0.0 and coo_tensorA is not None:
@@ -316,10 +314,10 @@ class MatMul8bitLt(torch.autograd.Function):

         if bias is None or bias.dtype == torch.float16:
             output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
-            output = output.to(A_dtype)
+            output = output.to(A.dtype)
         else:  # apply bias separately
             output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
-            output = output.to(A_dtype).add_(bias)
+            output = output.to(A.dtype).add_(bias)

         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
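
Context for the change: instead of overwriting A with an fp16 copy (and remembering the original dtype in A_dtype), the patch leaves A in the caller's dtype, casts only a temporary copy for double_quant, and dequantizes the int32 result back to A.dtype. A minimal smoke-test sketch of the intended behaviour, assuming a CUDA device and a bitsandbytes build containing this patch; it goes through bnb.nn.Linear8bitLt, which dispatches to MatMul8bitLt internally, and the shapes and threshold are arbitrary illustration values:

    import torch
    import bitsandbytes as bnb

    # Illustrative check, not part of the patch: an int8 linear layer whose
    # weights are quantized when the module is moved to the GPU.
    linear = bnb.nn.Linear8bitLt(64, 32, bias=False,
                                 has_fp16_weights=False, threshold=6.0).cuda()

    # Non-fp16 input: MatMul8bitLt warns about the fp16 cast, quantizes a
    # temporary fp16 copy of A, and casts the dequantized output to A.dtype.
    x = torch.randn(4, 64, device="cuda", dtype=torch.bfloat16)
    y = linear(x)

    print(y.dtype)  # expected with this patch: torch.bfloat16, matching the input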