Diffstat (limited to 'bitsandbytes/optim')
-rw-r--r--   bitsandbytes/optim/__init__.py  |  2
-rw-r--r--   bitsandbytes/optim/adagrad.py   | 57
-rw-r--r--   bitsandbytes/optim/adam.py      |  1
-rw-r--r--   bitsandbytes/optim/adamw.py     | 29
-rw-r--r--   bitsandbytes/optim/optimizer.py |  5
5 files changed, 91 insertions, 3 deletions
diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py
index 92c83b1..5e73414 100644
--- a/bitsandbytes/optim/__init__.py
+++ b/bitsandbytes/optim/__init__.py
@@ -3,8 +3,10 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 from .adam import Adam, Adam8bit, Adam32bit
+from .adamw import AdamW, AdamW8bit, AdamW32bit
 from .sgd import SGD, SGD8bit, SGD32bit
 from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
 from .lamb import LAMB, LAMB8bit, LAMB32bit
 from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
+from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
 from .optimizer import GlobalOptimManager
diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py
new file mode 100644
index 0000000..84ade3c
--- /dev/null
+++ b/bitsandbytes/optim/adagrad.py
@@ -0,0 +1,57 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+from bitsandbytes.optim.optimizer import Optimizer1State
+
+torch.optim.Adagrad
+
+class Adagrad(Optimizer1State):
+    def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
+            optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if initial_accumulator_value != 0.0:
+            raise ValueError('Initial accumulator value != 0.0 not supported!')
+        if lr_decay != 0.0:
+            raise ValueError('Lr Decay != 0.0 not supported!')
+        super(Adagrad, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
+                weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+
+class Adagrad8bit(Optimizer1State):
+    def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
+            optim_bits=8, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if initial_accumulator_value != 0.0:
+            raise ValueError('Initial accumulator value != 0.0 not supported!')
+        if lr_decay != 0.0:
+            raise ValueError('Lr Decay != 0.0 not supported!')
+        assert block_wise
+        super(Adagrad8bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
+                weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+
+class Adagrad32bit(Optimizer1State):
+    def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
+            optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if initial_accumulator_value != 0.0:
+            raise ValueError('Initial accumulator value != 0.0 not supported!')
+        if lr_decay != 0.0:
+            raise ValueError('Lr Decay != 0.0 not supported!')
+        super(Adagrad32bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
+                weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
index f3e5e81..ed1b9f0 100644
--- a/bitsandbytes/optim/adam.py
+++ b/bitsandbytes/optim/adam.py
@@ -33,7 +33,6 @@ class Adam32bit(Optimizer2State):
                 weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
 
 
-
 class AnalysisAdam(torch.optim.Optimizer):
     """Adam that performs 8-bit vs 32-bit error analysis.
 
diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py
new file mode 100644
index 0000000..7761f3b
--- /dev/null
+++ b/bitsandbytes/optim/adamw.py
@@ -0,0 +1,29 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+from bitsandbytes.optim.optimizer import Optimizer2State
+import bitsandbytes.functional as F
+
+class AdamW(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+            weight_decay=1e-2, amsgrad=False, optim_bits=32, args=None,
+            min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super(AdamW, self).__init__('adam', params, lr, betas, eps,
+                weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+
+class AdamW8bit(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+            weight_decay=1e-2, amsgrad=False, args=None,
+            min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super(AdamW8bit, self).__init__('adam', params, lr, betas, eps,
+                weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+
+class AdamW32bit(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+            weight_decay=1e-2, amsgrad=False, args=None,
+            min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super(AdamW32bit, self).__init__('adam', params, lr, betas, eps,
+                weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index 4b70b5c..cfbd72e 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -242,8 +242,9 @@ class Optimizer2State(Optimizer8bit):
         if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {}".format(eps))
         if isinstance(betas, str):
-            betas = eval(betas)
-            print(betas, 'parsed')
+            # format: '(beta1, beta2)'
+            betas = betas.replace('(', '').replace(')', '').strip().split(',')
+            betas = [float(b) for b in betas]
         for i in range(len(betas)):
             if not 0.0 <= betas[i] < 1.0:
                 raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
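
For context, a minimal usage sketch of the optimizer classes this commit introduces. It is an illustration, not part of the diff: the torch.nn.Linear model and CUDA device are stand-in assumptions, while the constructor defaults (lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2 for AdamW8bit; lr=1e-2, eps=1e-10 for Adagrad8bit) are taken from the code above. The classes are intended as drop-in replacements for their torch.optim counterparts:

    import torch
    import bitsandbytes as bnb

    # Stand-in model; any set of CUDA parameters works the same way.
    model = torch.nn.Linear(4096, 4096).cuda()

    # 8-bit AdamW with block-wise quantized optimizer state (block_wise=True by default).
    optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1e-3)

    # The Adagrad variants reject lr_decay != 0 and initial_accumulator_value != 0,
    # and Adagrad8bit additionally asserts block_wise=True, per the checks above.
    # optimizer = bnb.optim.Adagrad8bit(model.parameters(), lr=1e-2)

    # Standard PyTorch training step.
    loss = model(torch.randn(8, 4096, device='cuda')).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()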