Diffstat (limited to 'bitsandbytes')
-rw-r--r--  bitsandbytes/functional.py       |  2
-rw-r--r--  bitsandbytes/optim/__init__.py   |  2
-rw-r--r--  bitsandbytes/optim/adagrad.py    | 57
-rw-r--r--  bitsandbytes/optim/adam.py       |  1
-rw-r--r--  bitsandbytes/optim/adamw.py      | 29
-rw-r--r--  bitsandbytes/optim/optimizer.py  |  5
6 files changed, 93 insertions(+), 3 deletions(-)
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 9fe1345..44116cc 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -19,6 +19,7 @@ str2optimizer32bit = {}
str2optimizer32bit['adam'] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
str2optimizer32bit['momentum'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
str2optimizer32bit['rmsprop'] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16)
+str2optimizer32bit['adagrad'] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16)
str2optimizer32bit['lars'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
str2optimizer32bit['lamb'] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
@@ -33,6 +34,7 @@ str2optimizer8bit_blockwise = {}
str2optimizer8bit_blockwise['adam'] = (lib.cadam_8bit_blockwise_fp32, lib.cadam_8bit_blockwise_fp16)
str2optimizer8bit_blockwise['momentum'] = (lib.cmomentum_8bit_blockwise_fp32, lib.cmomentum_8bit_blockwise_fp16)
str2optimizer8bit_blockwise['rmsprop'] = (lib.crmsprop_8bit_blockwise_fp32, lib.crmsprop_8bit_blockwise_fp16)
+str2optimizer8bit_blockwise['adagrad'] = (lib.cadagrad_8bit_blockwise_fp32, lib.cadagrad_8bit_blockwise_fp16)
optimal_normal = [-0.9939730167388916, -0.8727636337280273, -0.8097418546676636, -0.7660024166107178, -0.7318882346153259, -0.6793879270553589, -0.657649040222168, -0.6385974884033203, -0.6211113333702087, -0.5901028513908386, -0.5762918591499329, -0.5630806684494019, -0.5509274005889893, -0.5394591689109802, -0.5283197164535522, -0.517780065536499, -0.5074946284294128, -0.4980469048023224, -0.48867011070251465, -0.48003149032592773, -0.47125306725502014, -0.4629971981048584, -0.4547359049320221, -0.446626216173172, -0.43902668356895447, -0.43158355355262756, -0.4244747757911682, -0.4173796474933624, -0.41038978099823, -0.4055633544921875, -0.4035947024822235, -0.39701032638549805, -0.39057496190071106, -0.38439232110977173, -0.3782760500907898, -0.3721940815448761, -0.3661896586418152, -0.3604033589363098, -0.354605108499527, -0.34892538189888, -0.34320303797721863, -0.3376772701740265, -0.3323028087615967, -0.3269782066345215, -0.32166096568107605, -0.316457599401474, -0.3112771809101105, -0.3061025142669678, -0.30106794834136963, -0.2961243987083435, -0.2912728488445282, -0.28644347190856934, -0.28165507316589355, -0.2769731283187866, -0.2722635865211487, -0.26779335737228394, -0.26314786076545715, -0.2586647868156433, -0.2541804611682892, -0.2496625930070877, -0.24527113139629364, -0.24097171425819397, -0.23659978806972504, -0.23218469321727753, -0.22799566388130188, -0.22380566596984863, -0.21965542435646057, -0.2154538631439209, -0.2113603949546814, -0.20735277235507965, -0.20334717631340027, -0.19932441413402557, -0.19530178606510162, -0.19136647880077362, -0.18736697733402252, -0.18337111175060272, -0.17951400578022003, -0.1757056713104248, -0.17182783782482147, -0.1680615097284317, -0.16431649029254913, -0.16053077578544617, -0.15685945749282837, -0.15298527479171753, -0.1493264138698578, -0.14566898345947266, -0.14188314974308014, -0.13819937407970428, -0.1344561129808426, -0.1306886374950409, -0.1271020770072937, -0.12346585839986801, -0.11981867253780365, -0.11614970862865448, -0.11256207525730133, -0.10889036953449249, -0.10525048524141312, -0.1016591489315033, -0.09824034571647644, -0.09469068050384521, -0.0911419615149498, -0.08773849159479141, -0.08416644483804703, -0.08071305602788925, -0.07720902562141418, -0.07371306419372559, -0.07019119709730148, -0.06673648208379745, -0.06329209357500076, -0.059800852090120316, -0.0564190037548542, -0.05296570807695389, -0.049522045999765396, -0.04609023034572601, -0.04262964054942131, -0.039246633648872375, -0.03577171266078949, -0.03236335143446922, -0.028855687007308006, -0.02542758360505104, -0.022069433704018593, -0.018754752352833748, -0.015386369079351425, -0.01194947212934494, -0.008439815603196621, -0.004995611496269703, -0.0016682245768606663, 0.0, 0.0015510577941313386, 0.005062474869191647, 0.008417150937020779, 0.011741090565919876, 0.015184164978563786, 0.018582714721560478, 0.02204744517803192, 0.025471193715929985, 0.02889077737927437, 0.0323684960603714, 0.03579240292310715, 0.039281025528907776, 0.0427563451230526, 0.04619763046503067, 0.04968220740556717, 0.05326594039797783, 0.05679265409708023, 0.060245808213949203, 0.06372645497322083, 0.06721872836351395, 0.0706876739859581, 0.0742349922657013, 0.07774098962545395, 0.08123527467250824, 0.08468879014253616, 0.08810535818338394, 0.09155989438295364, 0.09498448669910431, 0.0985206812620163, 0.10206405073404312, 0.10563778132200241, 0.10921968519687653, 0.11284469068050385, 0.11653254181146622, 0.12008969485759735, 0.12368203699588776, 0.1272617131471634, 
0.13089501857757568, 0.134552001953125, 0.1382799744606018, 0.14194637537002563, 0.14563234150409698, 0.14930322766304016, 0.15303383767604828, 0.1567956507205963, 0.16050070524215698, 0.16431072354316711, 0.16813558340072632, 0.17204202711582184, 0.1758781224489212, 0.17973239719867706, 0.1836014688014984, 0.18753431737422943, 0.19138391315937042, 0.19535475969314575, 0.19931404292583466, 0.20333819091320038, 0.20738255977630615, 0.21152682602405548, 0.21568812429904938, 0.21978361904621124, 0.22393859922885895, 0.22814159095287323, 0.23241068422794342, 0.23675410449504852, 0.24123944342136383, 0.24569889903068542, 0.2500703036785126, 0.25904011726379395, 0.26349544525146484, 0.2682226300239563, 0.272907555103302, 0.2774306833744049, 0.28220856189727783, 0.2869136929512024, 0.2916390895843506, 0.29649388790130615, 0.30142995715141296, 0.3065022826194763, 0.3114383816719055, 0.31648796796798706, 0.3216581642627716, 0.32700115442276, 0.3322487473487854, 0.33778008818626404, 0.3431521952152252, 0.3487405776977539, 0.3543166518211365, 0.3601346015930176, 0.36605337262153625, 0.37217751145362854, 0.378179669380188, 0.3843980133533478, 0.3906566798686981, 0.39714935421943665, 0.40357843041419983, 0.4104187488555908, 0.4171563684940338, 0.42418959736824036, 0.43136918544769287, 0.4389212429523468, 0.44673123955726624, 0.45457619428634644, 0.4627031683921814, 0.47130417823791504, 0.4798591434955597, 0.48897242546081543, 0.4979848861694336, 0.5, 0.5076631307601929, 0.5177803635597229, 0.5282770991325378, 0.5392990112304688, 0.5506287813186646, 0.5632893443107605, 0.5764452815055847, 0.5903191566467285, 0.6051878333091736, 0.6209936141967773, 0.6382884979248047, 0.6573970913887024, 0.6795773506164551, 0.7037051916122437, 0.7327037453651428, 0.7677436470985413, 0.8111193776130676, 0.875165581703186, 1.0]
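Both registries touched above map an optimizer name to a pair of compiled kernels, one for fp32 and one for fp16 gradients, so the update routine can pick a kernel based on the gradient dtype. A minimal sketch of that lookup pattern (the functions below are stand-ins, not the real lib.c* entry points):

import torch

def adagrad_blockwise_fp32(g):  # stand-in for lib.cadagrad_8bit_blockwise_fp32
    return 'fp32 kernel'

def adagrad_blockwise_fp16(g):  # stand-in for lib.cadagrad_8bit_blockwise_fp16
    return 'fp16 kernel'

registry = {'adagrad': (adagrad_blockwise_fp32, adagrad_blockwise_fp16)}

grad = torch.zeros(8, dtype=torch.float16)
fp32_fn, fp16_fn = registry['adagrad']
kernel = fp32_fn if grad.dtype == torch.float32 else fp16_fn
print(kernel(grad))  # -> 'fp16 kernel'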
diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py
index 92c83b1..5e73414 100644
--- a/bitsandbytes/optim/__init__.py
+++ b/bitsandbytes/optim/__init__.py
@@ -3,8 +3,10 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .adam import Adam, Adam8bit, Adam32bit
+from .adamw import AdamW, AdamW8bit, AdamW32bit
from .sgd import SGD, SGD8bit, SGD32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .lamb import LAMB, LAMB8bit, LAMB32bit
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
+from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
from .optimizer import GlobalOptimManager
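With these two imports added, the AdamW and Adagrad families are re-exported from bitsandbytes.optim next to the existing optimizers, e.g. (assuming the package is installed):

# import check for the names exported by this change
from bitsandbytes.optim import AdamW, AdamW8bit, AdamW32bit
from bitsandbytes.optim import Adagrad, Adagrad8bit, Adagrad32bit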
diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py
new file mode 100644
index 0000000..84ade3c
--- /dev/null
+++ b/bitsandbytes/optim/adagrad.py
@@ -0,0 +1,57 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+from bitsandbytes.optim.optimizer import Optimizer1State
+
+torch.optim.Adagrad
+
+class Adagrad(Optimizer1State):
+    def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
+            optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if initial_accumulator_value != 0.0:
+            raise ValueError('Initial accumulator value != 0.0 not supported!')
+        if lr_decay != 0.0:
+            raise ValueError('Lr Decay != 0.0 not supported!')
+        super(Adagrad, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
+                weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+
+class Adagrad8bit(Optimizer1State):
+    def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
+            optim_bits=8, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if initial_accumulator_value != 0.0:
+            raise ValueError('Initial accumulator value != 0.0 not supported!')
+        if lr_decay != 0.0:
+            raise ValueError('Lr Decay != 0.0 not supported!')
+        assert block_wise
+        super(Adagrad8bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
+                weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+
+class Adagrad32bit(Optimizer1State):
+    def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
+            optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if initial_accumulator_value != 0.0:
+            raise ValueError('Initial accumulator value != 0.0 not supported!')
+        if lr_decay != 0.0:
+            raise ValueError('Lr Decay != 0.0 not supported!')
+        super(Adagrad32bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
+                weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
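All three classes forward to Optimizer1State and differ mainly in the optim_bits value they pass (Adagrad8bit additionally requires block_wise=True); lr_decay and initial_accumulator_value are accepted for signature compatibility with torch.optim.Adagrad but must stay at their defaults. A small illustration of that guard (a sketch, assuming bitsandbytes imports cleanly on the machine):

import torch
from bitsandbytes.optim import Adagrad8bit

params = [torch.nn.Parameter(torch.zeros(16))]
try:
    # any non-default lr_decay is rejected before optimizer state is set up
    Adagrad8bit(params, lr=1e-2, lr_decay=0.1)
except ValueError as err:
    print(err)  # Lr Decay != 0.0 not supported!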
diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
index f3e5e81..ed1b9f0 100644
--- a/bitsandbytes/optim/adam.py
+++ b/bitsandbytes/optim/adam.py
@@ -33,7 +33,6 @@ class Adam32bit(Optimizer2State):
                weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
-
class AnalysisAdam(torch.optim.Optimizer):
    """Adam that performs 8-bit vs 32-bit error analysis.
diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py
new file mode 100644
index 0000000..7761f3b
--- /dev/null
+++ b/bitsandbytes/optim/adamw.py
@@ -0,0 +1,29 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+from bitsandbytes.optim.optimizer import Optimizer2State
+import bitsandbytes.functional as F
+
+class AdamW(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+            weight_decay=1e-2, amsgrad=False, optim_bits=32, args=None,
+            min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super(AdamW, self).__init__('adam', params, lr, betas, eps,
+                weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+
+class AdamW8bit(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+            weight_decay=1e-2, amsgrad=False, args=None,
+            min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super(AdamW8bit, self).__init__('adam', params, lr, betas, eps,
+                weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+
+class AdamW32bit(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+            weight_decay=1e-2, amsgrad=False, args=None,
+            min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super(AdamW32bit, self).__init__('adam', params, lr, betas, eps,
+                weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+
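Each AdamW variant forwards to the shared 'adam' update path of Optimizer2State, with weight_decay defaulting to 1e-2 as in torch.optim.AdamW. A training-step sketch with the 8-bit variant (the CUDA device and the placeholder layer/data are assumptions, not part of the diff):

import torch
import bitsandbytes as bnb

layer = torch.nn.Linear(256, 256).cuda()
opt = bnb.optim.AdamW8bit(layer.parameters(), lr=1e-3, weight_decay=1e-2)

x = torch.randn(16, 256, device='cuda')
loss = layer(x).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()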
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index 4b70b5c..cfbd72e 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -242,8 +242,9 @@ class Optimizer2State(Optimizer8bit):
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if isinstance(betas, str):
-            betas = eval(betas)
-            print(betas, 'parsed')
+            # format: '(beta1, beta2)'
+            betas = betas.replace('(', '').replace(')', '').strip().split(',')
+            betas = [float(b) for b in betas]
        for i in range(len(betas)):
            if not 0.0 <= betas[i] < 1.0:
                raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
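The eval-based parsing of a betas string is replaced with explicit splitting, so a value such as '(0.9, 0.995)' is still accepted while arbitrary expressions are no longer evaluated. The new lines are runnable in isolation:

# same parsing logic as the added lines above, outside the optimizer class
betas = '(0.9, 0.995)'
betas = betas.replace('(', '').replace(')', '').strip().split(',')
betas = [float(b) for b in betas]
print(betas)  # [0.9, 0.995]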