summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- CHANGELOG.md 7
-rw-r--r-- bitsandbytes/optim/optimizer.py 5
-rw-r--r-- tests/test_optim.py 15
3 files changed, 22 insertions, 5 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index e943fa2..d12af22 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -41,8 +41,9 @@ Docs:
### 0.26.0:
Features:
- - Added Adagrad (without grad clipping) as 32-bit and 8-bit block-wise optimizer
- - Added AdamW (copy of Adam with weight decay init 1e-2)
+ - Added Adagrad (without grad clipping) as 32-bit and 8-bit block-wise optimizer.
+ - Added AdamW (copy of Adam with weight decay init 1e-2). #10
Bug fixes:
- - Fixed a bug where weight decay was incorrectly applied to 32-bit Adam
+ - Fixed a bug where weight decay was incorrectly applied to 32-bit Adam. #13
+ - Fixed an unsafe use of eval. #8
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index 4b70b5c..cfbd72e 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -242,8 +242,9 @@ class Optimizer2State(Optimizer8bit):
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if isinstance(betas, str):
- betas = eval(betas)
- print(betas, 'parsed')
+ # format: '(beta1, beta2)'
+ betas = betas.replace('(', '').replace(')', '').strip().split(',')
+ betas = [float(b) for b in betas]
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
diff --git a/tests/test_optim.py b/tests/test_optim.py
index d306511..5464043 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -392,3 +392,18 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
#assert s < 3.9
+
+def test_str_betas():
+ betas = (0.80, 0.95)
+ strbetas = '(0.80, 0.95)'
+
+ layer = torch.nn.Linear(10, 10)
+
+ base = bnb.optim.Adam(layer.parameters(), betas=betas)
+ strbase = bnb.optim.Adam(layer.parameters(), betas=strbetas)
+ assert base.defaults['betas'][0] == 0.8
+ assert base.defaults['betas'][1] == 0.95
+ assert strbase.defaults['betas'][0] == 0.8
+ assert strbase.defaults['betas'][1] == 0.95
+
+