From e35e2c665a69647d829c48e22fba0230180c11e7 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:35:03 +0300 Subject: cast properly --- bitsandbytes/autograd/_functions.py | 2 +- tests/test_autograd.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 5a83dfd..36c392b 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -231,7 +231,7 @@ class MatMul8bitLt(torch.autograd.Function): # Cast A to fp16 if A.dtype != torch.float16: - warnings.warn(f"MatMul8bitLt: input matrix will be cast from {A.dtype} to float16") + warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization") # 1. Quantize A if len(A.shape) == 3: diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 5171c4f..4e4282a 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -372,8 +372,10 @@ def test_matmullt( n = out_bnb.numel() err = torch.abs(out_bnb - out_torch).mean().item() # print(f'abs error {err:.4f}') + out_error_rate = 0.0175 if dtype == torch.float16 else 0.02 + idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) - assert (idx == 0).sum().item() <= n * 0.0175 + assert (idx == 0).sum().item() <= n * out_error_rate idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) assert (idx == 0).sum().item() <= n * 0.001 -- cgit v1.2.3