diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_autograd.py | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/tests/test_autograd.py b/tests/test_autograd.py index fc7a0e1..8ebe8c8 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -40,7 +40,8 @@ names = [ ids=names, ) def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): - dim2 = dim2 - (dim2 % 16) + if dim2 > 0: + dim2 = dim2 - (dim2 % 16) dim3 = dim3 - (dim3 % 16) dim4 = dim4 - (dim4 % 16) for i in range(k): @@ -234,10 +235,7 @@ dim2 = torch.randint(32, 96, size=(n,)).tolist() dim3 = torch.randint(32, 96, size=(n,)).tolist() dim4 = torch.randint(32, 96, size=(n,)).tolist() -# dim1 = (17,) -# dim2 = (7,) -# dim3 = (37,) -# dim4 = (23,) +dim2.append(0) decomp = [0.0, 6.0] funcs = [(torch.matmul, bnb.matmul)] @@ -385,9 +383,14 @@ def test_matmullt( ) if req_grad[1]: n = gradB1.numel() - assert torch.abs(gradB1).sum() > 0.0 - assert torch.abs(gradB2).sum() > 0.0 + if dim2 > 0: + assert torch.abs(gradB1).sum() > 0.0 + assert torch.abs(gradB2).sum() > 0.0 + else: + assert torch.abs(gradB1).sum() == 0.0 + assert torch.abs(gradB2).sum() == 0.0 idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) + assert (idx == 0).sum().item() < n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) assert (idx == 0).sum().item() < n * 0.02 |