summaryrefslogtreecommitdiff
path: root/bitsandbytes/optim
diff options
context:
space:
mode:
authorMax Ryabinin <mryabinin0@gmail.com>2022-07-01 17:16:10 +0300
committerMax Ryabinin <mryabinin0@gmail.com>2022-07-01 17:16:10 +0300
commit8258b4364a21a4da2572cb644d0926080c3268da (patch)
tree571e95bc327116fbaba08d14871fb0b224b8a65b /bitsandbytes/optim
parent33efe4a09f459832e8beceba70add0695cc485e4 (diff)
Add a CPU-only build option
Diffstat (limited to 'bitsandbytes/optim')
-rw-r--r--bitsandbytes/optim/__init__.py20
-rw-r--r--bitsandbytes/optim/rmsprop.py2
2 files changed, 13 insertions, 9 deletions
diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py
index 5e73414..e833ecc 100644
--- a/bitsandbytes/optim/__init__.py
+++ b/bitsandbytes/optim/__init__.py
@@ -2,11 +2,15 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from .adam import Adam, Adam8bit, Adam32bit
-from .adamw import AdamW, AdamW8bit, AdamW32bit
-from .sgd import SGD, SGD8bit, SGD32bit
-from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
-from .lamb import LAMB, LAMB8bit, LAMB32bit
-from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
-from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
-from .optimizer import GlobalOptimManager
+
+from bitsandbytes.cextension import COMPILED_WITH_CUDA
+
+if COMPILED_WITH_CUDA:
+ from .adam import Adam, Adam8bit, Adam32bit
+ from .adamw import AdamW, AdamW8bit, AdamW32bit
+ from .sgd import SGD, SGD8bit, SGD32bit
+ from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
+ from .lamb import LAMB, LAMB8bit, LAMB32bit
+ from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
+ from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
+ from .optimizer import GlobalOptimManager
diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py
index 7909d5d..0f1ffaa 100644
--- a/bitsandbytes/optim/rmsprop.py
+++ b/bitsandbytes/optim/rmsprop.py
@@ -31,6 +31,6 @@ class RMSprop32bit(Optimizer1State):
if alpha == 0:
raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
if centered:
- raise NotImplementError(f'Centered RMSprop is not supported!')
+ raise NotImplementedError(f'Centered RMSprop is not supported!')
super(RMSprop32bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)