From ea7c14f8ef64924f2d0ff80df3cdabf2c7299848 Mon Sep 17 00:00:00 2001 From: Titus von Koeller Date: Mon, 1 Aug 2022 09:32:47 -0700 Subject: reran black with linelength 80 for greater readability --- tests/test_autograd.py | 96 ++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 74 insertions(+), 22 deletions(-) (limited to 'tests/test_autograd.py') diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 9cd01a9..fc7a0e1 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -18,9 +18,13 @@ req_grad_str = ["FF", "TF", "TT", "FT"] transpose = [(False, False), (False, True), (True, True), (True, False)] str_transpose = ["FF", "FT", "TT", "TF"] dtype = [torch.float32, torch.float16] -values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)) +values = list( + product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose) +) str_values = list( - product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose) + product( + dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose + ) ) names = [ "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format( @@ -31,7 +35,9 @@ names = [ @pytest.mark.parametrize( - "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names + "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", + values, + ids=names, ) def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): dim2 = dim2 - (dim2 % 16) @@ -79,7 +85,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + loss_torch = torch.nn.functional.mse_loss( + out_torch, target + ).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -87,25 +95,35 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): B.grad = None if req_grad[0]: - torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1) + torch.testing.assert_allclose( + gradA1, gradA2, atol=0.015, rtol=0.1 + ) if req_grad[1]: n = gradB1.numel() idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) assert (idx == 0).sum().item() < n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) assert (idx == 0).sum().item() < n * 0.02 - torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3) + torch.testing.assert_allclose( + gradB1, gradB2, atol=0.18, rtol=0.3 + ) # batched matrix multiply if funcs[0] in [torch.bmm, torch.matmul]: A = torch.randn( - size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0] + size=(dim1, dim2, dim3), + device="cuda", + requires_grad=req_grad[0], ) B = torch.randn( - size=(dim1, dim3, dim4), device="cuda", requires_grad=req_grad[1] + size=(dim1, dim3, dim4), + device="cuda", + requires_grad=req_grad[1], ) target = torch.randn( - size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1] + size=(dim1, dim2, dim4), + device="cuda", + requires_grad=req_grad[1], ) torch.nn.init.xavier_uniform_(B) @@ -115,7 +133,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): n = out_bnb.numel() idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) assert (idx == 0).sum().item() < n * 0.01 - torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2) + torch.testing.assert_allclose( + out_bnb, out_torch, atol=0.027, rtol=0.2 + ) if any(req_grad): out_bnb.data.copy_(out_torch) @@ -127,7 +147,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + loss_torch = torch.nn.functional.mse_loss( + out_torch, target + ).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -135,7 +157,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): B.grad = None if req_grad[0]: - torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1) + torch.testing.assert_allclose( + gradA1, gradA2, atol=0.015, rtol=0.1 + ) if req_grad[1]: n = gradB1.numel() idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) @@ -146,12 +170,16 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): if funcs[0] in [torch.matmul]: dim1 = dim1 - (dim1 % 16) A = torch.randn( - size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0] + size=(dim1, dim2, dim3), + device="cuda", + requires_grad=req_grad[0], ) dimB = (dim4, dim3) if transpose[1] else (dim3, dim4) B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1]) target = torch.randn( - size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1] + size=(dim1, dim2, dim4), + device="cuda", + requires_grad=req_grad[1], ) torch.nn.init.xavier_uniform_(B) @@ -178,7 +206,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + loss_torch = torch.nn.functional.mse_loss( + out_torch, target + ).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -186,7 +216,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): B.grad = None if req_grad[0]: - torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1) + torch.testing.assert_allclose( + gradA1, gradA2, atol=0.015, rtol=0.1 + ) if req_grad[1]: n = gradB1.numel() idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) @@ -258,7 +290,16 @@ names = [ ids=names, ) def test_matmullt( - dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights + dim1, + dim2, + dim3, + dim4, + funcs, + dtype, + req_grad, + transpose, + decomp, + has_fp16_weights, ): dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) @@ -278,7 +319,10 @@ def test_matmullt( size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype ) target = torch.randn( - size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype + size=(dim2, dim4), + device="cuda", + requires_grad=req_grad[1], + dtype=dtype, ) torch.nn.init.xavier_uniform_(B) B2 = B.clone() @@ -317,14 +361,18 @@ def test_matmullt( if any(req_grad): out_bnb.data.copy_(out_torch) torch.cuda.synchronize() - loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() + loss_bnb = torch.nn.functional.mse_loss( + out_bnb, target + ).mean() loss_bnb.backward() gradA1 = A.grad gradB1 = B.grad A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + loss_torch = torch.nn.functional.mse_loss( + out_torch, target + ).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -332,7 +380,9 @@ def test_matmullt( B.grad = None if req_grad[0]: - torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1) + torch.testing.assert_allclose( + gradA1, gradA2, atol=0.015, rtol=0.1 + ) if req_grad[1]: n = gradB1.numel() assert torch.abs(gradB1).sum() > 0.0 @@ -341,4 +391,6 @@ def test_matmullt( assert (idx == 0).sum().item() < n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) assert (idx == 0).sum().item() < n * 0.02 - torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3) + torch.testing.assert_allclose( + gradB1, gradB2, atol=0.18, rtol=0.3 + ) -- cgit v1.2.3