diff options
-rw-r--r-- | tests/test_autograd.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 636fe86..083d465 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -354,7 +354,7 @@ def test_matmullt( state.SCB, SCBt, coo_tensorB, - ) = bnb.functional.double_quant(B2.half()) + ) = bnb.functional.double_quant(B2.to(torch.float16)) B2 = state.CB if not transpose[0] and transpose[1]: @@ -367,6 +367,8 @@ def test_matmullt( if has_bias: out_torch += bias + assert out_bnb.dtype == torch.dtype + n = out_bnb.numel() err = torch.abs(out_bnb - out_torch).mean().item() # print(f'abs error {err:.4f}') |