summaryrefslogtreecommitdiff
path: root/tests/test_optim.py
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2021-11-29 08:21:05 -0800
committerTim Dettmers <tim.dettmers@gmail.com>2021-11-29 08:21:05 -0800
commit108cf9fc1f8c6bc0360a49ce790699928883b3d3 (patch)
tree57ed09eaa584f244f5376894504d2eb042372316 /tests/test_optim.py
parentb3fe8a6d0f53e3e81a4a6bc7385ce86077abf690 (diff)
Fixed unsafe use of eval. #8
Diffstat (limited to 'tests/test_optim.py')
-rw-r--r--tests/test_optim.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/tests/test_optim.py b/tests/test_optim.py
index d306511..5464043 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -392,3 +392,18 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
#assert s < 3.9
+
+def test_str_betas():
+ betas = (0.80, 0.95)
+ strbetas = '(0.80, 0.95)'
+
+ layer = torch.nn.Linear(10, 10)
+
+ base = bnb.optim.Adam(layer.parameters(), betas=betas)
+ strbase = bnb.optim.Adam(layer.parameters(), betas=strbetas)
+ assert base.defaults['betas'][0] == 0.8
+ assert base.defaults['betas'][1] == 0.95
+ assert strbase.defaults['betas'][0] == 0.8
+ assert strbase.defaults['betas'][1] == 0.95
+
+