summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/test_autograd.py96
-rw-r--r--tests/test_cuda_setup_evaluator.py33
-rw-r--r--tests/test_functional.py187
-rw-r--r--tests/test_modules.py50
-rw-r--r--tests/test_optim.py71
5 files changed, 332 insertions, 105 deletions
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index 9cd01a9..fc7a0e1 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -18,9 +18,13 @@ req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, False), (False, True), (True, True), (True, False)]
str_transpose = ["FF", "FT", "TT", "TF"]
dtype = [torch.float32, torch.float16]
-values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
+values = list(
+ product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)
+)
str_values = list(
- product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose)
+ product(
+ dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose
+ )
)
names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format(
@@ -31,7 +35,9 @@ names = [
@pytest.mark.parametrize(
- "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names
+ "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose",
+ values,
+ ids=names,
)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
dim2 = dim2 - (dim2 % 16)
@@ -79,7 +85,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
A.grad = None
B.grad = None
- loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch = torch.nn.functional.mse_loss(
+ out_torch, target
+ ).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
@@ -87,25 +95,35 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
B.grad = None
if req_grad[0]:
- torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ torch.testing.assert_allclose(
+ gradA1, gradA2, atol=0.015, rtol=0.1
+ )
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02
- torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
+ torch.testing.assert_allclose(
+ gradB1, gradB2, atol=0.18, rtol=0.3
+ )
# batched matrix multiply
if funcs[0] in [torch.bmm, torch.matmul]:
A = torch.randn(
- size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0]
+ size=(dim1, dim2, dim3),
+ device="cuda",
+ requires_grad=req_grad[0],
)
B = torch.randn(
- size=(dim1, dim3, dim4), device="cuda", requires_grad=req_grad[1]
+ size=(dim1, dim3, dim4),
+ device="cuda",
+ requires_grad=req_grad[1],
)
target = torch.randn(
- size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1]
+ size=(dim1, dim2, dim4),
+ device="cuda",
+ requires_grad=req_grad[1],
)
torch.nn.init.xavier_uniform_(B)
@@ -115,7 +133,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx == 0).sum().item() < n * 0.01
- torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2)
+ torch.testing.assert_allclose(
+ out_bnb, out_torch, atol=0.027, rtol=0.2
+ )
if any(req_grad):
out_bnb.data.copy_(out_torch)
@@ -127,7 +147,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
A.grad = None
B.grad = None
- loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch = torch.nn.functional.mse_loss(
+ out_torch, target
+ ).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
@@ -135,7 +157,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
B.grad = None
if req_grad[0]:
- torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ torch.testing.assert_allclose(
+ gradA1, gradA2, atol=0.015, rtol=0.1
+ )
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
@@ -146,12 +170,16 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
if funcs[0] in [torch.matmul]:
dim1 = dim1 - (dim1 % 16)
A = torch.randn(
- size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0]
+ size=(dim1, dim2, dim3),
+ device="cuda",
+ requires_grad=req_grad[0],
)
dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
target = torch.randn(
- size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1]
+ size=(dim1, dim2, dim4),
+ device="cuda",
+ requires_grad=req_grad[1],
)
torch.nn.init.xavier_uniform_(B)
@@ -178,7 +206,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
A.grad = None
B.grad = None
- loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch = torch.nn.functional.mse_loss(
+ out_torch, target
+ ).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
@@ -186,7 +216,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
B.grad = None
if req_grad[0]:
- torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ torch.testing.assert_allclose(
+ gradA1, gradA2, atol=0.015, rtol=0.1
+ )
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
@@ -258,7 +290,16 @@ names = [
ids=names,
)
def test_matmullt(
- dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights
+ dim1,
+ dim2,
+ dim3,
+ dim4,
+ funcs,
+ dtype,
+ req_grad,
+ transpose,
+ decomp,
+ has_fp16_weights,
):
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
@@ -278,7 +319,10 @@ def test_matmullt(
size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
)
target = torch.randn(
- size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype
+ size=(dim2, dim4),
+ device="cuda",
+ requires_grad=req_grad[1],
+ dtype=dtype,
)
torch.nn.init.xavier_uniform_(B)
B2 = B.clone()
@@ -317,14 +361,18 @@ def test_matmullt(
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
- loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+ loss_bnb = torch.nn.functional.mse_loss(
+ out_bnb, target
+ ).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
- loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch = torch.nn.functional.mse_loss(
+ out_torch, target
+ ).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
@@ -332,7 +380,9 @@ def test_matmullt(
B.grad = None
if req_grad[0]:
- torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ torch.testing.assert_allclose(
+ gradA1, gradA2, atol=0.015, rtol=0.1
+ )
if req_grad[1]:
n = gradB1.numel()
assert torch.abs(gradB1).sum() > 0.0
@@ -341,4 +391,6 @@ def test_matmullt(
assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() < n * 0.02
- torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
+ torch.testing.assert_allclose(
+ gradB1, gradB2, atol=0.18, rtol=0.3
+ )
diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py
index d45354f..5a58be4 100644
--- a/tests/test_cuda_setup_evaluator.py
+++ b/tests/test_cuda_setup_evaluator.py
@@ -3,8 +3,12 @@ from typing import List, NamedTuple
import pytest
-from bitsandbytes.cuda_setup import (CUDA_RUNTIME_LIB, evaluate_cuda_setup,
- get_cuda_runtime_lib_path, tokenize_paths)
+from bitsandbytes.cuda_setup import (
+ CUDA_RUNTIME_LIB,
+ evaluate_cuda_setup,
+ get_cuda_runtime_lib_path,
+ tokenize_paths,
+)
class InputAndExpectedOutput(NamedTuple):
@@ -13,11 +17,26 @@ class InputAndExpectedOutput(NamedTuple):
HAPPY_PATH__LD_LIB_TEST_PATHS: List[InputAndExpectedOutput] = [
- (f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
- (f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
- (f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:", f"dir/with/{CUDA_RUNTIME_LIB}"),
- (f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
- (f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir", f"dir/with/{CUDA_RUNTIME_LIB}"),
+ (
+ f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
+ f"dir/with/{CUDA_RUNTIME_LIB}",
+ ),
+ (
+ f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
+ f"dir/with/{CUDA_RUNTIME_LIB}",
+ ),
+ (
+ f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:",
+ f"dir/with/{CUDA_RUNTIME_LIB}",
+ ),
+ (
+ f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}",
+ f"dir/with/{CUDA_RUNTIME_LIB}",
+ ),
+ (
+ f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir",
+ f"dir/with/{CUDA_RUNTIME_LIB}",
+ ),
(
f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so",
f"dir/with/{CUDA_RUNTIME_LIB}",
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 11cd198..ab7d672 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -86,7 +86,9 @@ def teardown():
pass
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"])
+@pytest.mark.parametrize(
+ "dtype", [torch.float32, torch.float16], ids=["float", "half"]
+)
def test_estimate_quantiles(dtype):
A = torch.rand(1024, 1024, device="cuda")
A = A.to(dtype)
@@ -190,7 +192,9 @@ def test_dynamic_blockwise_stochastic_quantization():
)
-@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"])
+@pytest.mark.parametrize(
+ "gtype", [torch.float32, torch.float16], ids=["float", "half"]
+)
def test_percentile_clipping(gtype):
gnorm_vec1 = torch.zeros(100, device="cuda")
gnorm_vec2 = torch.zeros(100, device="cuda")
@@ -270,7 +274,13 @@ def mean(xx):
dim1 = [1024 * 2]
dim2 = [1024 * 16]
methods = [
- (lambda x, dim: quant(x), lambda x, dim: quant(x), dequant, dequant, mm_dequant)
+ (
+ lambda x, dim: quant(x),
+ lambda x, dim: quant(x),
+ dequant,
+ dequant,
+ mm_dequant,
+ )
]
methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant))
# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
@@ -279,11 +289,14 @@ batched = [False, True]
values = list(product(dim1, dim2, methods, batched))
values_names = list(product(dim1, dim2, method_names, batched))
names = [
- "dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals) for vals in values_names
+ "dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals)
+ for vals in values_names
]
-@pytest.mark.parametrize("dim1, dim2, quant_methods, batched", values, ids=names)
+@pytest.mark.parametrize(
+ "dim1, dim2, quant_methods, batched", values, ids=names
+)
def test_approx_igemm(dim1, dim2, quant_methods, batched):
dim1 = dim1 - (dim1 % 32)
dim2 = dim2 - (dim2 % 32)
@@ -339,14 +352,18 @@ names = [
]
-@pytest.mark.parametrize("hidden_dim, batch_dim, transpose, seq_dim", values, ids=names)
+@pytest.mark.parametrize(
+ "hidden_dim, batch_dim, transpose, seq_dim", values, ids=names
+)
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
hidden_dim = hidden_dim - (hidden_dim % 32)
batch_dim = batch_dim - (batch_dim % 16)
seq_dim = seq_dim - (seq_dim % 16)
for i in range(k):
shapeA = (
- (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
+ (batch_dim, hidden_dim)
+ if not transpose[0]
+ else (hidden_dim, batch_dim)
)
shapeB = (
(32 * random.randint(1, 4), hidden_dim)
@@ -394,7 +411,9 @@ seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
values = list(product(seq_dim, hidden_dim, batch_dim))
-names = ["seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values]
+names = [
+ "seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values
+]
@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
@@ -406,11 +425,13 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
A = torch.randint(
-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
).to(torch.int8)
- B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to(
- torch.int8
- )
+ B = torch.randint(
+ -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda"
+ ).to(torch.int8)
out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
- iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device)
+ iout = torch.empty(
+ A.shape[2], B.shape[2], dtype=torch.int32, device=A.device
+ )
out = F.igemm(A, B, out=iout)
torch.testing.assert_allclose(out.float(), out2)
@@ -428,7 +449,9 @@ names = [
]
-@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim, transpose", values, ids=names)
+@pytest.mark.parametrize(
+ "seq_dim, hidden_dim, batch_dim, transpose", values, ids=names
+)
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
def min_max(x):
maxA = torch.amax(x, dim=2, keepdim=True)
@@ -444,7 +467,9 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
errs2 = []
relerrs2 = []
for i in range(k):
- A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda")
+ A = torch.normal(
+ 0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
+ )
if transpose:
B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
else:
@@ -504,7 +529,8 @@ dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)]
values = list(product(dim1, dim2, dim3, dim4, transpose))
names = [
- "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals) for vals in values
+ "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals)
+ for vals in values
]
@@ -529,7 +555,9 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
out = F.igemm(A.permute([0, 2, 1]), B)
elif transpose[0] and transpose[1]:
- out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float())
+ out2 = torch.bmm(
+ A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
+ )
out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
torch.testing.assert_allclose(out.float(), out2.float())
@@ -563,7 +591,9 @@ a_order = ["row"]
out_order = ["col", "row", "col32"]
transpose = [False]
dims = [2, 3]
-values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
+values = list(
+ product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
+)
names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format(
@@ -574,9 +604,13 @@ names = [
@pytest.mark.parametrize(
- "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names
+ "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
+ values,
+ ids=names,
)
-def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
+def test_nvidia_transform(
+ dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose
+):
if dims == 3 and out_order != "col32":
return
if dtype == torch.int32 and out_order != "col32":
@@ -586,7 +620,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
if dims == 2:
A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
elif dims == 3:
- A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype)
+ A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
+ dtype
+ )
out, S = F.nvidia_transform(A, to_order=orderOut)
@@ -598,7 +634,11 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
if dims == 2:
n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
elif dims == 3:
- n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32)))
+ n = (
+ A.shape[0]
+ * A.shape[1]
+ * (A.shape[2] + (32 - (A.shape[2] % 32)))
+ )
assert out.numel() == n
elif orderOut == "col_turing":
# 32 col 8 row tiles
@@ -613,7 +653,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
j = col
coltile = (col // 32) + (1 if col % 32 != 0 else 0)
- rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile
+ rowtile = (
+ (row // 8) + (1 if row % 8 != 0 else 0)
+ ) * total_coltile
offset = 32 * 8 * (rowtile + coltile)
col2 = col % 32
row2 = (row % 8) * 32
@@ -624,7 +666,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
# torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
if orderOut == "col32":
- out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S)
+ out2, S = F.nvidia_transform(
+ out, from_order=orderOut, to_order="row", state=S
+ )
torch.testing.assert_allclose(A, out2)
@@ -657,10 +701,12 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
torch.int8
)
elif dims == 3:
- A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
- torch.int8
- )
- B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
+ A = torch.randint(
+ -128, 127, size=(dim1, dim2, dim3), device="cuda"
+ ).to(torch.int8)
+ B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
+ torch.int8
+ )
C1 = torch.matmul(A.float(), B.t().float())
A2, SA = F.transform(A, "col32")
@@ -670,7 +716,9 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
torch.testing.assert_allclose(C1, C3.float())
# transpose
- B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8)
+ B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
+ torch.int8
+ )
C1 = torch.matmul(A.float(), B.float())
B2t, SBt = F.transform(B, "col_turing", transpose=True)
@@ -688,7 +736,8 @@ dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims))
names = [
- "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals) for vals in values
+ "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals)
+ for vals in values
]
@@ -699,7 +748,9 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
if dims == 2:
A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
elif dims == 3:
- A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half()
+ A = torch.normal(
+ 0, 0.5, size=(dim1, dim2, dim3), device="cuda"
+ ).half()
B = torch.randn((dim4, dim3), device="cuda").half()
torch.nn.init.xavier_uniform_(B)
C1 = torch.matmul(A, B.t())
@@ -742,7 +793,9 @@ values = [
# values = list(product(batch, seq, model, hidden))
-names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values]
+names = [
+ "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
+]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
@@ -909,7 +962,9 @@ dims = (2,)
# ldb = list(range(256, 1*1024, 256))
formatB = ["col_turing", "col_ampere"]
values = list(product(dim1, dim4, dims, formatB))
-names = ["dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}".format(*vals) for vals in values]
+names = [
+ "dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}".format(*vals) for vals in values
+]
@pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names)
@@ -992,7 +1047,9 @@ def test_colrow_absmax(dim1, dim2, dims):
torch.testing.assert_allclose(row_stats1_trunc, row_stats2)
torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2)
- row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0)
+ row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
+ A, threshold=0.0
+ )
torch.testing.assert_allclose(col_stats1, col_stats2)
torch.testing.assert_allclose(row_stats1, row_stats2)
@@ -1023,8 +1080,12 @@ def test_double_quant(dim1, dim2):
torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0)
n = CAt.numel()
- num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
- num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
+ num_not_close_rows = (
+ (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
+ )
+ num_not_close_cols = (
+ (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
+ )
# allow for 1:500 error due to rounding differences
min_error = 1 / 500
@@ -1123,7 +1184,9 @@ def test_igemmlt_row_scale(dim1, dim4, inner):
c = 10.0 * inner * scale
row_scale = torch.ones_like(maxA) / c
- outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
+ outC32, SC = F.igemmlt(
+ A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
+ )
C3, S = F.nvidia_transform(outC32, "row", state=SC)
maxval = torch.abs(C3).max()
if maxval == 127:
@@ -1204,7 +1267,9 @@ def test_row_scale_bench(dim1, dim4, inner):
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
- outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
+ outC32, SC = F.igemmlt(
+ A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
+ )
torch.cuda.synchronize()
print("row-wise", time.time() - t0)
@@ -1230,7 +1295,9 @@ a_order = ["row"]
out_order = ["col32", "col_turing", "col_ampere"]
transpose = [False, True]
dims = [2]
-values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
+values = list(
+ product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
+)
names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format(
*vals
@@ -1240,14 +1307,20 @@ names = [
@pytest.mark.parametrize(
- "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names
+ "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
+ values,
+ ids=names,
)
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
for i in range(k):
if dims == 2:
- A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype)
+ A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(
+ dtype
+ )
elif dims == 3:
- A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype)
+ A = torch.randint(
+ 10, 99, size=(dim1, dim2, dim3), device="cuda"
+ ).to(dtype)
A.view(-1)[-1] = -1
if transpose:
@@ -1282,7 +1355,9 @@ names = [
]
-@pytest.mark.parametrize("dim1, dim2, dtype, orderA, orderOut", values, ids=names)
+@pytest.mark.parametrize(
+ "dim1, dim2, dtype, orderA, orderOut", values, ids=names
+)
def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut):
for i in range(1):
A = torch.randint(-127, 127, size=(dim1, dim2), device="cuda").to(dtype)
@@ -1332,17 +1407,23 @@ def test_coo_double_quant(dim1, dim2):
idx = torch.abs(A) >= threshold
CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
- CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
+ CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
+ A, threshold=threshold
+ )
if coo_tensor is not None:
A1 = A * idx
A2 = torch.zeros_like(A)
- A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values
+ A2[
+ coo_tensor.rowidx.long(), coo_tensor.colidx.long()
+ ] = coo_tensor.values
torch.testing.assert_allclose(A1, A2)
A1 = A * (idx == 0)
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
- torch.testing.assert_allclose(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
+ torch.testing.assert_allclose(
+ A * (idx == 0), A2, rtol=0.05, atol=1.5e-2
+ )
n = 2
@@ -1454,7 +1535,9 @@ def test_integrated_sparse_decomp(dim1, dim2):
out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
- CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
+ CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
+ A, threshold=threshold
+ )
C32A, SA = F.transform(CA, "col32")
out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
@@ -1494,7 +1577,9 @@ dim2 = [12288]
dtype = [torch.float16]
out_function = ["zeros", "ones"]
values = list(product(dim1, dim2, dtype, out_function))
-names = ["dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values]
+names = [
+ "dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values
+]
@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
@@ -1536,7 +1621,9 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
std = out1.std()
out1 /= std
out2 /= std
- assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)
+ assert_all_approx_close(
+ out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count
+ )
# assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)
idx_col = torch.randint(0, A2.shape[-1], size=(15,))
@@ -1734,7 +1821,9 @@ values.append((batch_size, seqdim, 768, 4 * 768))
# values.append((batch_size, seqdim, 4096, 4*4096))
# values.append((batch_size, seqdim, 5140, 4*5140))
# values.append((batch_size, seqdim, 12288, 4*12288))
-names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values]
+names = [
+ "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
+]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
diff --git a/tests/test_modules.py b/tests/test_modules.py
index 6b8d641..7faadb8 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -48,7 +48,9 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
class LinearFunction(torch.autograd.Function):
@staticmethod
def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
- round_func = LinearFunction.round_stoachastic if stochastic else torch.round
+ round_func = (
+ LinearFunction.round_stoachastic if stochastic else torch.round
+ )
norm = math.sqrt(math.pi) / math.sqrt(2.0)
# std = torch.abs(x).mean()*norm
std = torch.std(x)
@@ -116,7 +118,9 @@ class LinearFunction(torch.autograd.Function):
return x.to(dtype)
def get_8bit_linear(x, stochastic=False):
- round_func = LinearFunction.round_stoachastic if stochastic else torch.round
+ round_func = (
+ LinearFunction.round_stoachastic if stochastic else torch.round
+ )
max1 = torch.abs(x).max()
x = x / max1 * 127
x = round_func(x) / 127 * max1
@@ -125,7 +129,9 @@ class LinearFunction(torch.autograd.Function):
@staticmethod
def get_8bit_vector_wise(x, dim, stochastic=False):
- round_func = LinearFunction.round_stoachastic if stochastic else torch.round
+ round_func = (
+ LinearFunction.round_stoachastic if stochastic else torch.round
+ )
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
max1[max1 == 0] = 1.0
x = (x * 127) / max1
@@ -209,7 +215,9 @@ class LinearFunction(torch.autograd.Function):
weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
outputq = bnb.functional.igemm(x8, weight8.t())
- output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
+ output = LinearFunction.dequant(
+ outputq, S1, S2, x.dtype, args.quant_type
+ )
# if torch.rand(1) < 0.01:
# output32 = torch.matmul(x, weight.t())
# err = torch.abs(output-output32).float()
@@ -261,7 +269,9 @@ class LinearFunction(torch.autograd.Function):
grad_weight8, S1, S2, grad_output.dtype, args.quant_type
)
- grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
+ grad_output8, S1 = LinearFunction.quant(
+ grad_output, args.quant_type, dim=2
+ )
weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
grad_input8 = bnb.functional.igemm(grad_output8, weight8)
grad_input = LinearFunction.dequant(
@@ -338,8 +348,12 @@ def test_linear8bit():
loss2.backward()
loss3.backward()
- assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
- assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
+ assert_all_approx_close(
+ l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2
+ )
+ assert_all_approx_close(
+ l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2
+ )
assert_all_approx_close(
l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
)
@@ -388,7 +402,9 @@ def test_linear8bitlt_accumulated_gradient():
l1 = torch.nn.Sequential(
*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
)
- l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
+ l2 = torch.nn.Sequential(
+ *[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]
+ )
l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
@@ -462,7 +478,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
if threshold > 0:
assert mlp.fc2.state.idx is not None
- mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
+ mlp = (
+ MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+ .cuda()
+ .half()
+ )
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
@@ -475,7 +495,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
if threshold > 0:
assert mlp.fc2.state.idx is not None
- mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
+ mlp = (
+ MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+ .half()
+ .cuda()
+ )
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
@@ -488,7 +512,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
- mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to("cuda")
+ mlp = (
+ MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+ .half()
+ .to("cuda")
+ )
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
diff --git a/tests/test_optim.py b/tests/test_optim.py
index b84425e..8e12761 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -103,20 +103,26 @@ str2statenames["adam8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
-str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
+str2statenames["momentum8bit"] = [
+ ("momentum_buffer", "state1", "qmap1", "max1")
+]
str2statenames["momentum8bit_blockwise"] = [
("momentum_buffer", "state1", "qmap1", "absmax1")
]
str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
-str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
+str2statenames["rmsprop8bit_blockwise"] = [
+ ("square_avg", "state1", "qmap1", "absmax1")
+]
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lamb"]
values = list(product(dim1, dim2, gtype, optimizer_names))
-names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
+names = [
+ "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -203,9 +209,13 @@ def test_global_config(dim1, dim2, gtype):
eps = 1e-8
bnb.optim.GlobalOptimManager.get_instance().initialize()
- bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8)
+ bnb.optim.GlobalOptimManager.get_instance().override_config(
+ p3, "optim_bits", 8
+ )
- bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
+ bnb.optim.GlobalOptimManager.get_instance().register_parameters(
+ [p1, p2, p3]
+ )
p1 = p1.cuda()
p2 = p2.cuda()
p3 = p3.cuda()
@@ -245,7 +255,9 @@ optimizer_names = [
"rmsprop8bit_blockwise",
]
values = list(product(dim1, dim2, gtype, optimizer_names))
-names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
+names = [
+ "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -329,8 +341,12 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
bnb_optimizer = str2optimizers[optim_name][1]([p2])
bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
rm_path(path)
- torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2])
- torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap])
+ torch.testing.assert_allclose(
+ raws1cpy, bnb_optimizer.state[p2][name2]
+ )
+ torch.testing.assert_allclose(
+ qmap1, bnb_optimizer.state[p2][qmap]
+ )
if "blockwise" in optim_name:
s1 = F.dequantize_blockwise(
@@ -349,12 +365,17 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
num_not_close = (
torch.isclose(
- torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
+ torch_optimizer.state[p1][name1],
+ s1,
+ atol=atol,
+ rtol=rtol,
)
== 0
)
assert num_not_close.sum().item() < 20
- torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
+ torch.testing.assert_allclose(
+ p1, p2.float(), atol=patol, rtol=prtol
+ )
# the parameters diverge quickly. Here we keep them close
# together so we can test against the Adam error
@@ -375,7 +396,10 @@ dim2 = [32, 1024, 4097]
gtype = [torch.float32]
optim_bits = [32, 8]
values = list(product(dim1, dim2, gtype, optim_bits))
-names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals) for vals in values]
+names = [
+ "dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals)
+ for vals in values
+]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
@@ -391,7 +415,12 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
p2 = p1.clone()
adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
adam2 = bnb.optim.Adam(
- [p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5
+ [p2],
+ lr,
+ (beta1, beta2),
+ eps,
+ optim_bits=optim_bits,
+ percentile_clipping=5,
)
gnorm_vec = torch.zeros(100).cuda()
@@ -399,7 +428,9 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
for i in range(50):
step += 1
- g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i)
+ g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (
+ 0.01 * i
+ )
g2 = g1.clone()
p2.grad = g2
@@ -430,10 +461,16 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
elif optim_bits == 8:
torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
torch.testing.assert_allclose(
- adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=2, rtol=1e-3
+ adam1.state[p1]["state1"],
+ adam2.state[p2]["state1"],
+ atol=2,
+ rtol=1e-3,
)
torch.testing.assert_allclose(
- adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=2, rtol=1e-3
+ adam1.state[p1]["state2"],
+ adam2.state[p2]["state2"],
+ atol=2,
+ rtol=1e-3,
)
adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])
@@ -463,7 +500,9 @@ gtype = [torch.float32, torch.float16]
# optimizer_names = ['lars_apex', 'lars8bit']
optimizer_names = ["adam8bit_blockwise"]
values = list(product(dim1, dim2, gtype, optimizer_names))
-names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
+names = [
+ "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)