Diffstat (limited to 'bitsandbytes/optim')
-rw-r--r--  bitsandbytes/optim/__init__.py |  1 +
-rw-r--r--  bitsandbytes/optim/adam.py     |  1 -
-rw-r--r--  bitsandbytes/optim/adamw.py    | 29 ++
3 files changed, 30 insertions(+), 1 deletion(-)
diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py
index af8a488..5e73414 100644
--- a/bitsandbytes/optim/__init__.py
+++ b/bitsandbytes/optim/__init__.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 from .adam import Adam, Adam8bit, Adam32bit
+from .adamw import AdamW, AdamW8bit, AdamW32bit
 from .sgd import SGD, SGD8bit, SGD32bit
 from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
 from .lamb import LAMB, LAMB8bit, LAMB32bit
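
With this export added, the three new classes are importable from the package's optim namespace alongside the existing Adam, SGD, LARS, and LAMB variants. A minimal sketch of the resulting import surface (assuming bitsandbytes is installed):

    from bitsandbytes.optim import AdamW, AdamW8bit, AdamW32bit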
diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
index eb951ee..1e93a60 100644
--- a/bitsandbytes/optim/adam.py
+++ b/bitsandbytes/optim/adam.py
@@ -28,7 +28,6 @@ class Adam32bit(Optimizer2State):
                 weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
-
 class AnalysisAdam(torch.optim.Optimizer):
     """Adam that performs 8-bit vs 32-bit error analysis.
diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py
new file mode 100644
index 0000000..7761f3b
--- /dev/null
+++ b/bitsandbytes/optim/adamw.py
@@ -0,0 +1,29 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+from bitsandbytes.optim.optimizer import Optimizer2State
+import bitsandbytes.functional as F
+
+class AdamW(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+            weight_decay=1e-2, amsgrad=False, optim_bits=32, args=None,
+            min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super(AdamW, self).__init__('adam', params, lr, betas, eps,
+                weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+
+class AdamW8bit(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+            weight_decay=1e-2, amsgrad=False, args=None,
+            min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super(AdamW8bit, self).__init__('adam', params, lr, betas, eps,
+                weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+
+class AdamW32bit(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+            weight_decay=1e-2, amsgrad=False, args=None,
+            min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super(AdamW32bit, self).__init__('adam', params, lr, betas, eps,
+                weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+
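
For context, each class above is a thin wrapper that forwards to Optimizer2State with the 'adam' kernel; the AdamW variants default weight_decay to 1e-2, the usual decoupled-weight-decay convention, and AdamW8bit/AdamW32bit simply pin optim_bits to 8 or 32. A minimal usage sketch follows; the toy model and hyperparameters are illustrative assumptions, not part of this commit:

    import torch
    import bitsandbytes as bnb

    # Toy model; any torch.nn.Module whose parameters live on a CUDA
    # device works the same way (the 8-bit kernels require CUDA).
    model = torch.nn.Linear(4096, 4096).cuda()

    # AdamW8bit keeps optimizer state in 8 bits with blockwise
    # quantization (block_wise=True in the constructor above); tensors
    # with fewer than min_8bit_size elements fall back to 32-bit state.
    optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1e-3,
                                    weight_decay=1e-2)

    loss = model(torch.randn(8, 4096, device='cuda')).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()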