summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/test_optim.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/tests/test_optim.py b/tests/test_optim.py
index d306511..5464043 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -392,3 +392,18 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
#assert s < 3.9
+
+def test_str_betas():
+ betas = (0.80, 0.95)
+ strbetas = '(0.80, 0.95)'
+
+ layer = torch.nn.Linear(10, 10)
+
+ base = bnb.optim.Adam(layer.parameters(), betas=betas)
+ strbase = bnb.optim.Adam(layer.parameters(), betas=strbetas)
+ assert base.defaults['betas'][0] == 0.8
+ assert base.defaults['betas'][1] == 0.95
+ assert strbase.defaults['betas'][0] == 0.8
+ assert strbase.defaults['betas'][1] == 0.95
+
+