diff options
Diffstat (limited to 'bitsandbytes/optim')
-rw-r--r-- | bitsandbytes/optim/__init__.py | 19 | ||||
-rw-r--r-- | bitsandbytes/optim/rmsprop.py | 2 |
2 files changed, 13 insertions, 8 deletions
diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py index 5e73414..42b5bc0 100644 --- a/bitsandbytes/optim/__init__.py +++ b/bitsandbytes/optim/__init__.py @@ -2,11 +2,16 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .adam import Adam, Adam8bit, Adam32bit -from .adamw import AdamW, AdamW8bit, AdamW32bit -from .sgd import SGD, SGD8bit, SGD32bit -from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS -from .lamb import LAMB, LAMB8bit, LAMB32bit -from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit -from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit + +from bitsandbytes.cextension import COMPILED_WITH_CUDA + +if COMPILED_WITH_CUDA: + from .adam import Adam, Adam8bit, Adam32bit + from .adamw import AdamW, AdamW8bit, AdamW32bit + from .sgd import SGD, SGD8bit, SGD32bit + from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS + from .lamb import LAMB, LAMB8bit, LAMB32bit + from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit + from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit + from .optimizer import GlobalOptimManager diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py index 7909d5d..0f1ffaa 100644 --- a/bitsandbytes/optim/rmsprop.py +++ b/bitsandbytes/optim/rmsprop.py @@ -31,6 +31,6 @@ class RMSprop32bit(Optimizer1State): if alpha == 0: raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!') if centered: - raise NotImplementError(f'Centered RMSprop is not supported!') + raise NotImplementedError(f'Centered RMSprop is not supported!') super(RMSprop32bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise) |