summaryrefslogtreecommitdiff
path: root/tests/test_autograd.py
diff options
context:
space:
mode:
authorTitus von Koeller <titus@vonkoeller.com>2022-08-01 09:32:47 -0700
committerTitus von Koeller <titus@vonkoeller.com>2022-08-01 09:32:47 -0700
commitea7c14f8ef64924f2d0ff80df3cdabf2c7299848 (patch)
tree3b9ec443a259cf36d87627a8e2cc7d13513f6a21 /tests/test_autograd.py
parent3fd06fb6206f46b6d18fbb8a512da63832dea98b (diff)
reran black with linelength 80 for greater readability
Diffstat (limited to 'tests/test_autograd.py')
-rw-r--r--tests/test_autograd.py96
1 files changed, 74 insertions, 22 deletions
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index 9cd01a9..fc7a0e1 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -18,9 +18,13 @@ req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, False), (False, True), (True, True), (True, False)]
str_transpose = ["FF", "FT", "TT", "TF"]
dtype = [torch.float32, torch.float16]
-values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
+values = list(
+ product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)
+)
str_values = list(
- product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose)
+ product(
+ dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose
+ )
)
names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format(
@@ -31,7 +35,9 @@ names = [
@pytest.mark.parametrize(
- "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names
+ "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose",
+ values,
+ ids=names,
)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
dim2 = dim2 - (dim2 % 16)
@@ -79,7 +85,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
A.grad = None
B.grad = None
- loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch = torch.nn.functional.mse_loss(
+ out_torch, target
+ ).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
@@ -87,25 +95,35 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
B.grad = None
if req_grad[0]:
- torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ torch.testing.assert_allclose(
+ gradA1, gradA2, atol=0.015, rtol=0.1
+ )
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02
- torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
+ torch.testing.assert_allclose(
+ gradB1, gradB2, atol=0.18, rtol=0.3
+ )
# batched matrix multiply
if funcs[0] in [torch.bmm, torch.matmul]:
A = torch.randn(
- size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0]
+ size=(dim1, dim2, dim3),
+ device="cuda",
+ requires_grad=req_grad[0],
)
B = torch.randn(
- size=(dim1, dim3, dim4), device="cuda", requires_grad=req_grad[1]
+ size=(dim1, dim3, dim4),
+ device="cuda",
+ requires_grad=req_grad[1],
)
target = torch.randn(
- size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1]
+ size=(dim1, dim2, dim4),
+ device="cuda",
+ requires_grad=req_grad[1],
)
torch.nn.init.xavier_uniform_(B)
@@ -115,7 +133,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx == 0).sum().item() < n * 0.01
- torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2)
+ torch.testing.assert_allclose(
+ out_bnb, out_torch, atol=0.027, rtol=0.2
+ )
if any(req_grad):
out_bnb.data.copy_(out_torch)
@@ -127,7 +147,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
A.grad = None
B.grad = None
- loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch = torch.nn.functional.mse_loss(
+ out_torch, target
+ ).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
@@ -135,7 +157,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
B.grad = None
if req_grad[0]:
- torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ torch.testing.assert_allclose(
+ gradA1, gradA2, atol=0.015, rtol=0.1
+ )
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
@@ -146,12 +170,16 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
if funcs[0] in [torch.matmul]:
dim1 = dim1 - (dim1 % 16)
A = torch.randn(
- size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0]
+ size=(dim1, dim2, dim3),
+ device="cuda",
+ requires_grad=req_grad[0],
)
dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
target = torch.randn(
- size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1]
+ size=(dim1, dim2, dim4),
+ device="cuda",
+ requires_grad=req_grad[1],
)
torch.nn.init.xavier_uniform_(B)
@@ -178,7 +206,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
A.grad = None
B.grad = None
- loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch = torch.nn.functional.mse_loss(
+ out_torch, target
+ ).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
@@ -186,7 +216,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
B.grad = None
if req_grad[0]:
- torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ torch.testing.assert_allclose(
+ gradA1, gradA2, atol=0.015, rtol=0.1
+ )
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
@@ -258,7 +290,16 @@ names = [
ids=names,
)
def test_matmullt(
- dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights
+ dim1,
+ dim2,
+ dim3,
+ dim4,
+ funcs,
+ dtype,
+ req_grad,
+ transpose,
+ decomp,
+ has_fp16_weights,
):
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
@@ -278,7 +319,10 @@ def test_matmullt(
size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
)
target = torch.randn(
- size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype
+ size=(dim2, dim4),
+ device="cuda",
+ requires_grad=req_grad[1],
+ dtype=dtype,
)
torch.nn.init.xavier_uniform_(B)
B2 = B.clone()
@@ -317,14 +361,18 @@ def test_matmullt(
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
- loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+ loss_bnb = torch.nn.functional.mse_loss(
+ out_bnb, target
+ ).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
- loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch = torch.nn.functional.mse_loss(
+ out_torch, target
+ ).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
@@ -332,7 +380,9 @@ def test_matmullt(
B.grad = None
if req_grad[0]:
- torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ torch.testing.assert_allclose(
+ gradA1, gradA2, atol=0.015, rtol=0.1
+ )
if req_grad[1]:
n = gradB1.numel()
assert torch.abs(gradB1).sum() > 0.0
@@ -341,4 +391,6 @@ def test_matmullt(
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02
- torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
+ torch.testing.assert_allclose(
+ gradB1, gradB2, atol=0.18, rtol=0.3
+ )