Diffstat (limited to 'tests/test_optim.py')
 tests/test_optim.py | 71 +++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------
 1 file changed, 55 insertions(+), 16 deletions(-)
diff --git a/tests/test_optim.py b/tests/test_optim.py
index b84425e..8e12761 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -103,20 +103,26 @@ str2statenames["adam8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
-str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
+str2statenames["momentum8bit"] = [
+ ("momentum_buffer", "state1", "qmap1", "max1")
+]
str2statenames["momentum8bit_blockwise"] = [
("momentum_buffer", "state1", "qmap1", "absmax1")
]
str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
-str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
+str2statenames["rmsprop8bit_blockwise"] = [
+ ("square_avg", "state1", "qmap1", "absmax1")
+]
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lamb"]
values = list(product(dim1, dim2, gtype, optimizer_names))
-names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
+names = [
+ "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -203,9 +209,13 @@ def test_global_config(dim1, dim2, gtype):
     eps = 1e-8
 
     bnb.optim.GlobalOptimManager.get_instance().initialize()
-    bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8)
+    bnb.optim.GlobalOptimManager.get_instance().override_config(
+        p3, "optim_bits", 8
+    )
 
-    bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
+    bnb.optim.GlobalOptimManager.get_instance().register_parameters(
+        [p1, p2, p3]
+    )
     p1 = p1.cuda()
     p2 = p2.cuda()
     p3 = p3.cuda()
@@ -245,7 +255,9 @@ optimizer_names = [
"rmsprop8bit_blockwise",
]
values = list(product(dim1, dim2, gtype, optimizer_names))
-names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
+names = [
+ "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -329,8 +341,12 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
                 bnb_optimizer = str2optimizers[optim_name][1]([p2])
                 bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
                 rm_path(path)
-                torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2])
-                torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap])
+                torch.testing.assert_allclose(
+                    raws1cpy, bnb_optimizer.state[p2][name2]
+                )
+                torch.testing.assert_allclose(
+                    qmap1, bnb_optimizer.state[p2][qmap]
+                )
 
                 if "blockwise" in optim_name:
                     s1 = F.dequantize_blockwise(
@@ -349,12 +365,17 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
 
                 num_not_close = (
                     torch.isclose(
-                        torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
+                        torch_optimizer.state[p1][name1],
+                        s1,
+                        atol=atol,
+                        rtol=rtol,
                     )
                     == 0
                 )
                 assert num_not_close.sum().item() < 20
-            torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
+            torch.testing.assert_allclose(
+                p1, p2.float(), atol=patol, rtol=prtol
+            )
 
         # the parameters diverge quickly. Here we keep them close
         # together so we can test against the Adam error
@@ -375,7 +396,10 @@ dim2 = [32, 1024, 4097]
 gtype = [torch.float32]
 optim_bits = [32, 8]
 values = list(product(dim1, dim2, gtype, optim_bits))
-names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals) for vals in values]
+names = [
+    "dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals)
+    for vals in values
+]
 
 
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
@@ -391,7 +415,12 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
     p2 = p1.clone()
     adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
     adam2 = bnb.optim.Adam(
-        [p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5
+        [p2],
+        lr,
+        (beta1, beta2),
+        eps,
+        optim_bits=optim_bits,
+        percentile_clipping=5,
     )
 
     gnorm_vec = torch.zeros(100).cuda()
@@ -399,7 +428,9 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
 
     for i in range(50):
         step += 1
-        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i)
+        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (
+            0.01 * i
+        )
         g2 = g1.clone()
         p2.grad = g2
 
@@ -430,10 +461,16 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
         elif optim_bits == 8:
             torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
             torch.testing.assert_allclose(
-                adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=2, rtol=1e-3
+                adam1.state[p1]["state1"],
+                adam2.state[p2]["state1"],
+                atol=2,
+                rtol=1e-3,
             )
             torch.testing.assert_allclose(
-                adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=2, rtol=1e-3
+                adam1.state[p1]["state2"],
+                adam2.state[p2]["state2"],
+                atol=2,
+                rtol=1e-3,
             )
             adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
             adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])
@@ -463,7 +500,9 @@ gtype = [torch.float32, torch.float16]
 # optimizer_names = ['lars_apex', 'lars8bit']
 optimizer_names = ["adam8bit_blockwise"]
 values = list(product(dim1, dim2, gtype, optimizer_names))
-names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
+names = [
+    "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+]
 
 
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)