summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/test_autograd.py20
1 files changed, 13 insertions, 7 deletions
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index d2b5d59..1b6c2ab 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -23,7 +23,8 @@ str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, st
names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}'.format(*vals) for vals in str_values]
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
- dim2 = dim2 - (dim2 % 16)
+ if dim2 > 0:
+ dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
dim4 = dim4 - (dim4 % 16)
for i in range(k):
@@ -179,6 +180,7 @@ dim2 = torch.randint(32,96, size=(n,)).tolist()
dim3 = torch.randint(32,96, size=(n,)).tolist()
dim4 = torch.randint(32,96, size=(n,)).tolist()
+dim2.append(0)
#dim1 = (17,)
#dim2 = (7,)
#dim3 = (37,)
@@ -234,9 +236,9 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
err = torch.abs(out_bnb-out_torch).mean().item()
#print(f'abs error {err:.4f}')
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
- assert (idx==0).sum().item() < n*0.0175
+ assert (idx==0).sum().item() <= n*0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
- assert (idx==0).sum().item() < n*0.001
+ assert (idx==0).sum().item() <= n*0.001
if has_fp16_weights:
if any(req_grad):
@@ -260,11 +262,15 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
- assert torch.abs(gradB1).sum() > 0.0
- assert torch.abs(gradB2).sum() > 0.0
+ if dim2 > 0:
+ assert torch.abs(gradB1).sum() > 0.0
+ assert torch.abs(gradB2).sum() > 0.0
+ else:
+ assert torch.abs(gradB1).sum() == 0.0
+ assert torch.abs(gradB2).sum() == 0.0
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
- assert (idx==0).sum().item() < n*0.1
+ assert (idx==0).sum().item() <= n*0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
- assert (idx==0).sum().item() < n*0.02
+ assert (idx==0).sum().item() <= n*0.02
torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)