diff options
author | justheuristic <justheuristic@gmail.com> | 2022-09-17 23:12:58 +0300 |
---|---|---|
committer | justheuristic <justheuristic@gmail.com> | 2022-09-17 23:12:58 +0300 |
commit | 140cdbe8767247bb9b8ea510755cceaa304b6859 (patch) | |
tree | bb18e7d1551edc79c0c4737d065a092102413d07 | |
parent | a9c7953e0a68a934a18a9495b20deeed9665b2a6 (diff) |
check dtypes first
-rw-r--r-- | tests/test_autograd.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 636fe86..083d465 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -354,7 +354,7 @@ def test_matmullt( state.SCB, SCBt, coo_tensorB, - ) = bnb.functional.double_quant(B2.half()) + ) = bnb.functional.double_quant(B2.to(torch.float16)) B2 = state.CB if not transpose[0] and transpose[1]: @@ -367,6 +367,8 @@ def test_matmullt( if has_bias: out_torch += bias + assert out_bnb.dtype == torch.dtype + n = out_bnb.numel() err = torch.abs(out_bnb - out_torch).mean().item() # print(f'abs error {err:.4f}') |