summaryrefslogtreecommitdiff
path: root/tests/test_autograd.py
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-08-04 08:03:00 -0700
committerTim Dettmers <tim.dettmers@gmail.com>2022-08-04 08:03:00 -0700
commit758c7175a24df307c40b743b1def8b4c34f68674 (patch)
treed7046117149950c2e97a5af6bd99d87f7688a357 /tests/test_autograd.py
parent96bc209baf55f2e05e649e555c2de5fc478c24dc (diff)
parentab72a1294fda03a0fd4ec297562fdab806349752 (diff)
Merge branch 'debug' into cuda-bin-switch-and-cli
Diffstat (limited to 'tests/test_autograd.py')
-rw-r--r--tests/test_autograd.py17
1 files changed, 10 insertions, 7 deletions
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index fc7a0e1..8ebe8c8 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -40,7 +40,8 @@ names = [
ids=names,
)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
- dim2 = dim2 - (dim2 % 16)
+ if dim2 > 0:
+ dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
dim4 = dim4 - (dim4 % 16)
for i in range(k):
@@ -234,10 +235,7 @@ dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32, 96, size=(n,)).tolist()
-# dim1 = (17,)
-# dim2 = (7,)
-# dim3 = (37,)
-# dim4 = (23,)
+dim2.append(0)
decomp = [0.0, 6.0]
funcs = [(torch.matmul, bnb.matmul)]
@@ -385,9 +383,14 @@ def test_matmullt(
)
if req_grad[1]:
n = gradB1.numel()
- assert torch.abs(gradB1).sum() > 0.0
- assert torch.abs(gradB2).sum() > 0.0
+ if dim2 > 0:
+ assert torch.abs(gradB1).sum() > 0.0
+ assert torch.abs(gradB2).sum() > 0.0
+ else:
+ assert torch.abs(gradB1).sum() == 0.0
+ assert torch.abs(gradB2).sum() == 0.0
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
+
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02