summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2021-11-28 21:18:11 -0800
committerTim Dettmers <tim.dettmers@gmail.com>2021-11-28 21:18:11 -0800
commit2f8083bd8b084290f888fe59b329d98ebd6dd468 (patch)
treeda534579bd762e93cd42b69a5e14c36f4b643979
parentca2078a697ae3adfb84255ae398f79623dc4ea2a (diff)
Added AdamW. #10 #13
-rw-r--r--CHANGELOG.md4
-rw-r--r--Makefile19
-rw-r--r--bitsandbytes/optim/__init__.py1
-rw-r--r--bitsandbytes/optim/adam.py1
-rw-r--r--bitsandbytes/optim/adamw.py29
-rw-r--r--csrc/kernels.cu3
-rw-r--r--tests/test_optim.py10
7 files changed, 54 insertions, 13 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index beaa256..e943fa2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -42,3 +42,7 @@ Docs:
Features:
- Added Adagrad (without grad clipping) as 32-bit and 8-bit block-wise optimizer
+ - Added AdamW (copy of Adam with weight decay init 1e-2)
+
+Bug fixes:
+ - Fixed a bug where weight decay was incorrectly applied to 32-bit Adam
diff --git a/Makefile b/Makefile
index 6055541..5093410 100644
--- a/Makefile
+++ b/Makefile
@@ -19,15 +19,16 @@ INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/inclu
LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcuda -lcublas -lcurand -lcusparse -L $(CONDA_PREFIX)/lib
# NVIDIA NVCC compilation flags
-COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler
-COMPUTE_CAPABILITY += -gencode arch=compute_37,code=sm_37 # Kepler
-COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
-COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
-COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
-COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
-COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
-COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
-COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
+#COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler
+#COMPUTE_CAPABILITY += -gencode arch=compute_37,code=sm_37 # Kepler
+#COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
+#COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
+#COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
+#COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
+#COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
+#COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
+#COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
+COMPUTE_CAPABILITY := -gencode arch=compute_75,code=sm_75 # Volta
# CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
CC_CUDA92 := -gencode arch=compute_30,code=sm_30
diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py
index af8a488..5e73414 100644
--- a/bitsandbytes/optim/__init__.py
+++ b/bitsandbytes/optim/__init__.py
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .adam import Adam, Adam8bit, Adam32bit
+from .adamw import AdamW, AdamW8bit, AdamW32bit
from .sgd import SGD, SGD8bit, SGD32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .lamb import LAMB, LAMB8bit, LAMB32bit
diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
index eb951ee..1e93a60 100644
--- a/bitsandbytes/optim/adam.py
+++ b/bitsandbytes/optim/adam.py
@@ -28,7 +28,6 @@ class Adam32bit(Optimizer2State):
weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
-
class AnalysisAdam(torch.optim.Optimizer):
"""Adam that performs 8-bit vs 32-bit error analysis.
diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py
new file mode 100644
index 0000000..7761f3b
--- /dev/null
+++ b/bitsandbytes/optim/adamw.py
@@ -0,0 +1,29 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+from bitsandbytes.optim.optimizer import Optimizer2State
+import bitsandbytes.functional as F
+
+class AdamW(Optimizer2State):
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=1e-2, amsgrad=False, optim_bits=32, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ super(AdamW, self).__init__('adam', params, lr, betas, eps,
+ weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+
+class AdamW8bit(Optimizer2State):
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=1e-2, amsgrad=False, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ super(AdamW8bit, self).__init__('adam', params, lr, betas, eps,
+ weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+
+class AdamW32bit(Optimizer2State):
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=1e-2, amsgrad=False, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ super(AdamW32bit, self).__init__('adam', params, lr, betas, eps,
+ weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 56f6a76..d0aabff 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -720,6 +720,9 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2))));
+
+ if(weight_decay > 0.0f)
+ p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
}
break;
}
diff --git a/tests/test_optim.py b/tests/test_optim.py
index ff0734b..d306511 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -34,6 +34,7 @@ str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx,
str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam)
+str2optimizers['adamw'] = (torch.optim.AdamW, bnb.optim.AdamW)
str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9))
@@ -47,12 +48,14 @@ str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_
str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9))
str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
+str2optimizers['adamw8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.AdamW8bit(pxx, block_wise=True))
str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True))
str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True))
str2optimizers['adagrad8bit_blockwise'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad8bit(pxx, 0.01, block_wise=True))
str2statenames = {}
str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
+str2statenames['adamw'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
str2statenames['momentum'] = [('momentum_buffer', 'state1')]
str2statenames['lars'] = [('momentum_buffer', 'state1')]
str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
@@ -61,6 +64,7 @@ str2statenames['adagrad'] = [('sum', 'state1')]
str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
+str2statenames['adamw8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')]
str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
@@ -71,7 +75,7 @@ str2statenames['adagrad8bit_blockwise'] = [('sum', 'state1', 'qmap1', 'absmax1')
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
-optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb', 'adagrad']
+optimizer_names = ['adam', 'adamw', 'momentum', 'rmsprop', 'lars', 'lamb', 'adagrad']
values = list(product(dim1,dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -86,7 +90,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
bnb_optimizer = str2optimizers[optim_name][1]([p2])
if gtype == torch.float32:
- atol, rtol = 1e-6, 1e-5
+ atol, rtol = 2e-6, 1e-5
else:
atol, rtol = 1e-4, 1e-3
@@ -201,7 +205,7 @@ def test_global_config(dim1, dim2, gtype):
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
-optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise', 'adagrad8bit_blockwise']
+optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'adamw8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise', 'adagrad8bit_blockwise']
values = list(product(dim1,dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)