summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/test_autograd.py17
1 files changed, 10 insertions, 7 deletions
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index fc7a0e1..8ebe8c8 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -40,7 +40,8 @@ names = [
ids=names,
)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
- dim2 = dim2 - (dim2 % 16)
+ if dim2 > 0:
+ dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
dim4 = dim4 - (dim4 % 16)
for i in range(k):
@@ -234,10 +235,7 @@ dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()
-# dim1 = (17,)
-# dim2 = (7,)
-# dim3 = (37,)
-# dim4 = (23,)
+dim2.append(0)
decomp = [0.0, 6.0]
funcs = [(torch.matmul, bnb.matmul)]
@@ -385,9 +383,14 @@ def test_matmullt(
)
if req_grad[1]:
n = gradB1.numel()
- assert torch.abs(gradB1).sum() > 0.0
- assert torch.abs(gradB2).sum() > 0.0
+ if dim2 > 0:
+ assert torch.abs(gradB1).sum() > 0.0
+ assert torch.abs(gradB2).sum() > 0.0
+ else:
+ assert torch.abs(gradB1).sum() == 0.0
+ assert torch.abs(gradB2).sum() == 0.0
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
+
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02