summaryrefslogtreecommitdiff
path: root/tests/test_autograd.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_autograd.py')
-rw-r--r--tests/test_autograd.py273
1 files changed, 198 insertions, 75 deletions
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index 1b6c2ab..8ebe8c8 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -1,27 +1,44 @@
-import pytest
+from itertools import product
+import pytest
import torch
-import bitsandbytes as bnb
-from itertools import product
+import bitsandbytes as bnb
n = 1
k = 25
-dim1 = torch.randint(16,64, size=(n,)).tolist()
-dim2 = torch.randint(32,96, size=(n,)).tolist()
-dim3 = torch.randint(32,96, size=(n,)).tolist()
-dim4 = torch.randint(32,96, size=(n,)).tolist()
+dim1 = torch.randint(16, 64, size=(n,)).tolist()
+dim2 = torch.randint(32, 96, size=(n,)).tolist()
+dim3 = torch.randint(32, 96, size=(n,)).tolist()
+dim4 = torch.randint(32, 96, size=(n,)).tolist()
funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
-str_funcs = ['bmm', 'matmul']
+str_funcs = ["bmm", "matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)]
-req_grad_str = ['FF', 'TF', 'TT', 'FT']
+req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, False), (False, True), (True, True), (True, False)]
-str_transpose = ['FF', 'FT', 'TT', 'TF']
+str_transpose = ["FF", "FT", "TT", "TF"]
dtype = [torch.float32, torch.float16]
-values = list(product(dim1,dim2,dim3,dim4,funcs, dtype, req_grad, transpose))
-str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, str_transpose))
-names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}'.format(*vals) for vals in str_values]
-@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
+values = list(
+ product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)
+)
+str_values = list(
+ product(
+ dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose
+ )
+)
+names = [
+ "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format(
+ *vals
+ )
+ for vals in str_values
+]
+
+
+@pytest.mark.parametrize(
+ "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose",
+ values,
+ ids=names,
+)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
if dim2 > 0:
dim2 = dim2 - (dim2 % 16)
@@ -33,9 +50,11 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
if funcs[0] in [torch.mm, torch.matmul]:
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
- A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0])
- B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1])
- target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1])
+ A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
+ B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
+ target = torch.randn(
+ size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]
+ )
torch.nn.init.xavier_uniform_(B)
if not transpose[0] and not transpose[1]:
@@ -53,9 +72,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
- assert (idx==0).sum().item() < n*0.0175
+ assert (idx == 0).sum().item() < n * 0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
- assert (idx==0).sum().item() < n*0.001
+ assert (idx == 0).sum().item() < n * 0.001
if any(req_grad):
out_bnb.data.copy_(out_torch)
@@ -67,7 +86,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
A.grad = None
B.grad = None
- loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch = torch.nn.functional.mse_loss(
+ out_torch, target
+ ).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
@@ -75,20 +96,36 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
B.grad = None
if req_grad[0]:
- torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ torch.testing.assert_allclose(
+ gradA1, gradA2, atol=0.015, rtol=0.1
+ )
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
- assert (idx==0).sum().item() < n*0.1
+ assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
- assert (idx==0).sum().item() < n*0.02
- torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
+ assert (idx == 0).sum().item() < n * 0.02
+ torch.testing.assert_allclose(
+ gradB1, gradB2, atol=0.18, rtol=0.3
+ )
# batched matrix multiply
if funcs[0] in [torch.bmm, torch.matmul]:
- A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0])
- B = torch.randn(size=(dim1, dim3, dim4), device='cuda', requires_grad=req_grad[1])
- target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1])
+ A = torch.randn(
+ size=(dim1, dim2, dim3),
+ device="cuda",
+ requires_grad=req_grad[0],
+ )
+ B = torch.randn(
+ size=(dim1, dim3, dim4),
+ device="cuda",
+ requires_grad=req_grad[1],
+ )
+ target = torch.randn(
+ size=(dim1, dim2, dim4),
+ device="cuda",
+ requires_grad=req_grad[1],
+ )
torch.nn.init.xavier_uniform_(B)
out_torch = funcs[0](A, B)
@@ -96,8 +133,10 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
- assert (idx==0).sum().item() < n*0.01
- torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2)
+ assert (idx == 0).sum().item() < n * 0.01
+ torch.testing.assert_allclose(
+ out_bnb, out_torch, atol=0.027, rtol=0.2
+ )
if any(req_grad):
out_bnb.data.copy_(out_torch)
@@ -109,7 +148,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
A.grad = None
B.grad = None
- loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch = torch.nn.functional.mse_loss(
+ out_torch, target
+ ).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
@@ -117,20 +158,30 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
B.grad = None
if req_grad[0]:
- torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ torch.testing.assert_allclose(
+ gradA1, gradA2, atol=0.015, rtol=0.1
+ )
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
- assert (idx==0).sum().item() < n*0.1
+ assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
- assert (idx==0).sum().item() < n*0.02
+ assert (idx == 0).sum().item() < n * 0.02
if funcs[0] in [torch.matmul]:
dim1 = dim1 - (dim1 % 16)
- A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0])
+ A = torch.randn(
+ size=(dim1, dim2, dim3),
+ device="cuda",
+ requires_grad=req_grad[0],
+ )
dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
- B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1])
- target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1])
+ B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
+ target = torch.randn(
+ size=(dim1, dim2, dim4),
+ device="cuda",
+ requires_grad=req_grad[1],
+ )
torch.nn.init.xavier_uniform_(B)
if transpose[1]:
@@ -142,9 +193,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
- assert (idx==0).sum().item() < n*0.0175
+ assert (idx == 0).sum().item() < n * 0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
- assert (idx==0).sum().item() < n*0.001
+ assert (idx == 0).sum().item() < n * 0.001
if any(req_grad):
out_bnb.data.copy_(out_torch)
@@ -156,7 +207,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
A.grad = None
B.grad = None
- loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch = torch.nn.functional.mse_loss(
+ out_torch, target
+ ).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
@@ -164,56 +217,111 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
B.grad = None
if req_grad[0]:
- torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ torch.testing.assert_allclose(
+ gradA1, gradA2, atol=0.015, rtol=0.1
+ )
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
- assert (idx==0).sum().item() < n*0.1
+ assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
- assert (idx==0).sum().item() < n*0.02
+ assert (idx == 0).sum().item() < n * 0.02
n = 1
k = 3
-dim1 = torch.randint(16,64, size=(n,)).tolist()
-dim2 = torch.randint(32,96, size=(n,)).tolist()
-dim3 = torch.randint(32,96, size=(n,)).tolist()
-dim4 = torch.randint(32,96, size=(n,)).tolist()
+dim1 = torch.randint(16, 64, size=(n,)).tolist()
+dim2 = torch.randint(32, 96, size=(n,)).tolist()
+dim3 = torch.randint(32, 96, size=(n,)).tolist()
+dim4 = torch.randint(32, 96, size=(n,)).tolist()
dim2.append(0)
-#dim1 = (17,)
-#dim2 = (7,)
-#dim3 = (37,)
-#dim4 = (23,)
decomp = [0.0, 6.0]
funcs = [(torch.matmul, bnb.matmul)]
-str_funcs = ['matmul']
+str_funcs = ["matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)]
-req_grad_str = ['FF', 'TF', 'TT', 'FT']
+req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, True), (False, False)]
-str_transpose = ['NT', 'NN']
+str_transpose = ["NT", "NN"]
dtype = [torch.float16]
has_fp16_weights = [True, False]
-values = list(product(dim1,dim2,dim3,dim4,funcs, dtype, req_grad, transpose, decomp, has_fp16_weights))
-str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, str_transpose, decomp, has_fp16_weights))
-names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}'.format(*vals) for vals in str_values]
-@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights", values, ids=names)
-def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights):
+values = list(
+ product(
+ dim1,
+ dim2,
+ dim3,
+ dim4,
+ funcs,
+ dtype,
+ req_grad,
+ transpose,
+ decomp,
+ has_fp16_weights,
+ )
+)
+str_values = list(
+ product(
+ dim1,
+ dim2,
+ dim3,
+ dim4,
+ str_funcs,
+ dtype,
+ req_grad_str,
+ str_transpose,
+ decomp,
+ has_fp16_weights,
+ )
+)
+names = [
+ "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}".format(
+ *vals
+ )
+ for vals in str_values
+]
+
+
+@pytest.mark.parametrize(
+ "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights",
+ values,
+ ids=names,
+)
+def test_matmullt(
+ dim1,
+ dim2,
+ dim3,
+ dim4,
+ funcs,
+ dtype,
+ req_grad,
+ transpose,
+ decomp,
+ has_fp16_weights,
+):
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
- outlier_dim = torch.randint(0, dimA[1], size=(dimA[1]//8,), device='cuda')
+ outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
for i in range(k):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
- A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0], dtype=dtype)
+ A = torch.randn(
+ size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype
+ )
if decomp == 6.0:
with torch.no_grad():
A[:, outlier_dim] = 6.0
- B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1], dtype=dtype)
- target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1], dtype=dtype)
+ B = torch.randn(
+ size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
+ )
+ target = torch.randn(
+ size=(dim2, dim4),
+ device="cuda",
+ requires_grad=req_grad[1],
+ dtype=dtype,
+ )
torch.nn.init.xavier_uniform_(B)
B2 = B.clone()
@@ -221,8 +329,15 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
state.threshold = decomp
state.has_fp16_weights = has_fp16_weights
if not has_fp16_weights:
- if not transpose[0] and not transpose[1]: B2 = B2.t().contiguous()
- state.CB, CBt, state.SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B2)
+ if not transpose[0] and not transpose[1]:
+ B2 = B2.t().contiguous()
+ (
+ state.CB,
+ CBt,
+ state.SCB,
+ SCBt,
+ coo_tensorB,
+ ) = bnb.functional.double_quant(B2)
B2 = state.CB
if not transpose[0] and transpose[1]:
@@ -233,25 +348,29 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
out_bnb = funcs[1](A, B2.t(), state=state)
n = out_bnb.numel()
- err = torch.abs(out_bnb-out_torch).mean().item()
- #print(f'abs error {err:.4f}')
+ err = torch.abs(out_bnb - out_torch).mean().item()
+ # print(f'abs error {err:.4f}')
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
- assert (idx==0).sum().item() <= n*0.0175
+ assert (idx == 0).sum().item() < n * 0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
- assert (idx==0).sum().item() <= n*0.001
+ assert (idx == 0).sum().item() < n * 0.001
if has_fp16_weights:
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
- loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+ loss_bnb = torch.nn.functional.mse_loss(
+ out_bnb, target
+ ).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
- loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch = torch.nn.functional.mse_loss(
+ out_torch, target
+ ).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
@@ -259,7 +378,9 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
B.grad = None
if req_grad[0]:
- torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ torch.testing.assert_allclose(
+ gradA1, gradA2, atol=0.015, rtol=0.1
+ )
if req_grad[1]:
n = gradB1.numel()
if dim2 > 0:
@@ -269,8 +390,10 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
assert torch.abs(gradB1).sum() == 0.0
assert torch.abs(gradB2).sum() == 0.0
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
- assert (idx==0).sum().item() <= n*0.1
- idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
- assert (idx==0).sum().item() <= n*0.02
- torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
+ assert (idx == 0).sum().item() < n * 0.1
+ idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
+ assert (idx == 0).sum().item() < n * 0.02
+ torch.testing.assert_allclose(
+ gradB1, gradB2, atol=0.18, rtol=0.3
+ )