From 140cdbe8767247bb9b8ea510755cceaa304b6859 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sat, 17 Sep 2022 23:12:58 +0300 Subject: check dtypes first --- tests/test_autograd.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'tests/test_autograd.py') diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 636fe86..083d465 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -354,7 +354,7 @@ def test_matmullt( state.SCB, SCBt, coo_tensorB, - ) = bnb.functional.double_quant(B2.half()) + ) = bnb.functional.double_quant(B2.to(torch.float16)) B2 = state.CB if not transpose[0] and transpose[1]: @@ -367,6 +367,8 @@ def test_matmullt( if has_bias: out_torch += bias + assert out_bnb.dtype == torch.dtype + n = out_bnb.numel() err = torch.abs(out_bnb - out_torch).mean().item() # print(f'abs error {err:.4f}') -- cgit v1.2.3