summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/test_autograd.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index 8ebe8c8..f1a15f5 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -351,9 +351,9 @@ def test_matmullt(
err = torch.abs(out_bnb - out_torch).mean().item()
# print(f'abs error {err:.4f}')
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
- assert (idx == 0).sum().item() < n * 0.0175
+ assert (idx == 0).sum().item() <= n * 0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
- assert (idx == 0).sum().item() < n * 0.001
+ assert (idx == 0).sum().item() <= n * 0.001
if has_fp16_weights:
if any(req_grad):
@@ -391,9 +391,9 @@ def test_matmullt(
assert torch.abs(gradB2).sum() == 0.0
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
- assert (idx == 0).sum().item() < n * 0.1
+ assert (idx == 0).sum().item() <= n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
- assert (idx == 0).sum().item() < n * 0.02
+ assert (idx == 0).sum().item() <= n * 0.02
torch.testing.assert_allclose(
gradB1, gradB2, atol=0.18, rtol=0.3
)