summaryrefslogtreecommitdiff
path: root/bitsandbytes/optim
diff options
context:
space:
mode:
authorTim Dettmers <dettmers@g3036.hyak.local>2021-11-10 15:10:02 -0800
committerTim Dettmers <dettmers@g3036.hyak.local>2021-11-10 15:10:02 -0800
commit8b3c0f355c779170d55a1975df981df9e53b59fa (patch)
tree0ebc5f8e869fb02e7dec90f809fbf07d778f9aca /bitsandbytes/optim
parent22b2877c7f8277317a073ea7cf49231d33fe79fd (diff)
Added adagrad with tests (no clipping).
Diffstat (limited to 'bitsandbytes/optim')
-rw-r--r--bitsandbytes/optim/__init__.py1
-rw-r--r--bitsandbytes/optim/adagrad.py57
2 files changed, 58 insertions, 0 deletions
diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py
index 92c83b1..af8a488 100644
--- a/bitsandbytes/optim/__init__.py
+++ b/bitsandbytes/optim/__init__.py
@@ -7,4 +7,5 @@ from .sgd import SGD, SGD8bit, SGD32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .lamb import LAMB, LAMB8bit, LAMB32bit
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
+from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
from .optimizer import GlobalOptimManager
diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py
new file mode 100644
index 0000000..84ade3c
--- /dev/null
+++ b/bitsandbytes/optim/adagrad.py
@@ -0,0 +1,57 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+from bitsandbytes.optim.optimizer import Optimizer1State
+
+torch.optim.Adagrad
+
+class Adagrad(Optimizer1State):
+ def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
+ optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if initial_accumulator_value != 0.0:
+ raise ValueError('Initial accumulator value != 0.0 not supported!')
+ if lr_decay != 0.0:
+ raise ValueError('Lr Decay != 0.0 not supported!')
+ super(Adagrad, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
+ weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+
+class Adagrad8bit(Optimizer1State):
+ def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
+ optim_bits=8, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if initial_accumulator_value != 0.0:
+ raise ValueError('Initial accumulator value != 0.0 not supported!')
+ if lr_decay != 0.0:
+ raise ValueError('Lr Decay != 0.0 not supported!')
+ assert block_wise
+ super(Adagrad8bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
+ weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+
+class Adagrad32bit(Optimizer1State):
+ def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
+ optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if initial_accumulator_value != 0.0:
+ raise ValueError('Initial accumulator value != 0.0 not supported!')
+ if lr_decay != 0.0:
+ raise ValueError('Lr Decay != 0.0 not supported!')
+ super(Adagrad32bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
+ weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)