From 451fd9506e215aa25643e9782cb7d8aed2a266cc Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 3 Aug 2022 11:54:01 -0700 Subject: Added fixes for the case that matmullt dim A is zero, e.g. [0, 768]. --- tests/test_autograd.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) (limited to 'tests/test_autograd.py') diff --git a/tests/test_autograd.py b/tests/test_autograd.py index d2b5d59..1b6c2ab 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -23,7 +23,8 @@ str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, st names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}'.format(*vals) for vals in str_values] @pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names) def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): - dim2 = dim2 - (dim2 % 16) + if dim2 > 0: + dim2 = dim2 - (dim2 % 16) dim3 = dim3 - (dim3 % 16) dim4 = dim4 - (dim4 % 16) for i in range(k): @@ -179,6 +180,7 @@ dim2 = torch.randint(32,96, size=(n,)).tolist() dim3 = torch.randint(32,96, size=(n,)).tolist() dim4 = torch.randint(32,96, size=(n,)).tolist() +dim2.append(0) #dim1 = (17,) #dim2 = (7,) #dim3 = (37,) @@ -234,9 +236,9 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec err = torch.abs(out_bnb-out_torch).mean().item() #print(f'abs error {err:.4f}') idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) - assert (idx==0).sum().item() < n*0.0175 + assert (idx==0).sum().item() <= n*0.0175 idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) - assert (idx==0).sum().item() < n*0.001 + assert (idx==0).sum().item() <= n*0.001 if has_fp16_weights: if any(req_grad): @@ -260,11 +262,15 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1) if req_grad[1]: n = gradB1.numel() - assert torch.abs(gradB1).sum() > 0.0 - assert torch.abs(gradB2).sum() > 0.0 + if dim2 > 0: + assert torch.abs(gradB1).sum() > 0.0 + assert torch.abs(gradB2).sum() > 0.0 + else: + assert torch.abs(gradB1).sum() == 0.0 + assert torch.abs(gradB2).sum() == 0.0 idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) - assert (idx==0).sum().item() < n*0.1 + assert (idx==0).sum().item() <= n*0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) - assert (idx==0).sum().item() < n*0.02 + assert (idx==0).sum().item() <= n*0.02 torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3) -- cgit v1.2.3