diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2021-11-29 08:21:05 -0800 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2021-11-29 08:21:05 -0800 |
commit | 108cf9fc1f8c6bc0360a49ce790699928883b3d3 (patch) | |
tree | 57ed09eaa584f244f5376894504d2eb042372316 /tests | |
parent | b3fe8a6d0f53e3e81a4a6bc7385ce86077abf690 (diff) |
Fixed unsafe use of eval. #8
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_optim.py | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/tests/test_optim.py b/tests/test_optim.py index d306511..5464043 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -392,3 +392,18 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): #assert s < 3.9 + +def test_str_betas(): + betas = (0.80, 0.95) + strbetas = '(0.80, 0.95)' + + layer = torch.nn.Linear(10, 10) + + base = bnb.optim.Adam(layer.parameters(), betas=betas) + strbase = bnb.optim.Adam(layer.parameters(), betas=strbetas) + assert base.defaults['betas'][0] == 0.8 + assert base.defaults['betas'][1] == 0.95 + assert strbase.defaults['betas'][0] == 0.8 + assert strbase.defaults['betas'][1] == 0.95 + + |