 CHANGELOG.md                    |  9
 Makefile                        | 31
 bitsandbytes/functional.py      |  2
 bitsandbytes/optim/__init__.py  |  2
 bitsandbytes/optim/adagrad.py   | 57
 bitsandbytes/optim/adam.py      |  1
 bitsandbytes/optim/adamw.py     | 29
 bitsandbytes/optim/optimizer.py |  5
 csrc/kernels.cu                 | 25
 csrc/ops.cu                     |  8
 csrc/ops.cuh                    |  1
 csrc/pythonInterface.c          |  8
 tests/test_optim.py             | 29
 13 files changed, 186 insertions(+), 21 deletions(-)
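Summary: this commit adds Adagrad (32-bit and 8-bit block-wise) and AdamW optimizers, fixes weight decay for 32-bit Adam, and replaces an unsafe eval() used to parse betas passed as a string. A minimal usage sketch of the new optimizer classes follows; the model and data here are placeholders, not part of the change, and a CUDA device is assumed:

    import torch
    import bitsandbytes as bnb

    model = torch.nn.Linear(64, 64).cuda()
    # AdamW added by this commit: Adam with weight_decay defaulting to 1e-2
    optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1e-3)
    # the new Adagrad is also available, e.g. bnb.optim.Adagrad(model.parameters(), lr=1e-2)

    x = torch.randn(8, 64).cuda()
    loss = model(x).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()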
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 683f437..d12af22 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -38,3 +38,12 @@ Docs:
     - Added docs with instructions to compile from source.
+### 0.26.0:
+
+Features:
+ - Added Adagrad (without grad clipping) as 32-bit and 8-bit block-wise optimizer.
+ - Added AdamW (copy of Adam with weight decay init 1e-2). #10
+
+Bug fixes:
+ - Fixed a bug where weight decay was incorrectly applied to 32-bit Adam. #13
+ - Fixed an unsafe use of eval. #8
diff --git a/Makefile b/Makefile
--- a/Makefile
+++ b/Makefile
@@ -19,15 +19,16 @@ INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/inclu
 LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcuda -lcublas -lcurand -lcusparse -L $(CONDA_PREFIX)/lib
 
 # NVIDIA NVCC compilation flags
-COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler
-COMPUTE_CAPABILITY += -gencode arch=compute_37,code=sm_37 # Kepler
-COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
-COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
-COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
-COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
-COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
-COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
-COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
+#COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler
+#COMPUTE_CAPABILITY += -gencode arch=compute_37,code=sm_37 # Kepler
+#COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
+#COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
+#COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
+#COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
+#COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
+#COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
+#COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
+COMPUTE_CAPABILITY := -gencode arch=compute_75,code=sm_75 # Volta
 
 # CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
 CC_CUDA92 := -gencode arch=compute_30,code=sm_30
@@ -46,33 +47,33 @@ CC_CUDA11x += -gencode arch=compute_86,code=sm_86
 
 all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
 	$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
 	$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
-	$(GPP) -std=c++11 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
+	$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
 
 cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
 	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
 	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
-	$(GPP) -std=c++11 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
+	$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
 
 cuda10x: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
 	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
 	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
-	$(GPP) -std=c++11 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
+	$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
 
 cuda110: $(BUILD_DIR) env
 	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
 	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
-	$(GPP) -std=c++11 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
+	$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
 
 cuda11x: $(BUILD_DIR) env
 	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
 	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
-	$(GPP) -std=c++11 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
+	$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
 
 env:
 	@echo "ENVIRONMENT"
 	@echo "============================"
 	@echo "NVCC path: $(NVCC)"
-	@echo "GPP path: $(GPP)"
+	@echo "GPP path: $(GPP) VERSION: `$(GPP) --version | head -n 1`"
 	@echo "CUDA_HOME: $(CUDA_HOME)"
 	@echo "CONDA_PREFIX: $(CONDA_PREFIX)"
 	@echo "PATH: $(PATH)"
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 9fe1345..44116cc 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -19,6 +19,7 @@ str2optimizer32bit = {}
 str2optimizer32bit['adam'] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
 str2optimizer32bit['momentum'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
 str2optimizer32bit['rmsprop'] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16)
+str2optimizer32bit['adagrad'] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16)
 str2optimizer32bit['lars'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
 str2optimizer32bit['lamb'] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
@@ -33,6 +34,7 @@ str2optimizer8bit_blockwise = {}
 str2optimizer8bit_blockwise['adam'] = (lib.cadam_8bit_blockwise_fp32, lib.cadam_8bit_blockwise_fp16)
 str2optimizer8bit_blockwise['momentum'] = (lib.cmomentum_8bit_blockwise_fp32, lib.cmomentum_8bit_blockwise_fp16)
 str2optimizer8bit_blockwise['rmsprop'] = (lib.crmsprop_8bit_blockwise_fp32, lib.crmsprop_8bit_blockwise_fp16)
+str2optimizer8bit_blockwise['adagrad'] = (lib.cadagrad_8bit_blockwise_fp32, lib.cadagrad_8bit_blockwise_fp16)
 
 optimal_normal = [-0.9939730167388916, -0.8727636337280273, -0.8097418546676636, -0.7660024166107178, -0.7318882346153259, -0.6793879270553589, -0.657649040222168, -0.6385974884033203,
-0.6211113333702087, -0.5901028513908386, -0.5762918591499329, -0.5630806684494019, -0.5509274005889893, -0.5394591689109802, -0.5283197164535522, -0.517780065536499, -0.5074946284294128, -0.4980469048023224, -0.48867011070251465, -0.48003149032592773, -0.47125306725502014, -0.4629971981048584, -0.4547359049320221, -0.446626216173172, -0.43902668356895447, -0.43158355355262756, -0.4244747757911682, -0.4173796474933624, -0.41038978099823, -0.4055633544921875, -0.4035947024822235, -0.39701032638549805, -0.39057496190071106, -0.38439232110977173, -0.3782760500907898, -0.3721940815448761, -0.3661896586418152, -0.3604033589363098, -0.354605108499527, -0.34892538189888, -0.34320303797721863, -0.3376772701740265, -0.3323028087615967, -0.3269782066345215, -0.32166096568107605, -0.316457599401474, -0.3112771809101105, -0.3061025142669678, -0.30106794834136963, -0.2961243987083435, -0.2912728488445282, -0.28644347190856934, -0.28165507316589355, -0.2769731283187866, -0.2722635865211487, -0.26779335737228394, -0.26314786076545715, -0.2586647868156433, -0.2541804611682892, -0.2496625930070877, -0.24527113139629364, -0.24097171425819397, -0.23659978806972504, -0.23218469321727753, -0.22799566388130188, -0.22380566596984863, -0.21965542435646057, -0.2154538631439209, -0.2113603949546814, -0.20735277235507965, -0.20334717631340027, -0.19932441413402557, -0.19530178606510162, -0.19136647880077362, -0.18736697733402252, -0.18337111175060272, -0.17951400578022003, -0.1757056713104248, -0.17182783782482147, -0.1680615097284317, -0.16431649029254913, -0.16053077578544617, -0.15685945749282837, -0.15298527479171753, -0.1493264138698578, -0.14566898345947266, -0.14188314974308014, -0.13819937407970428, -0.1344561129808426, -0.1306886374950409, -0.1271020770072937, -0.12346585839986801, -0.11981867253780365, -0.11614970862865448, -0.11256207525730133, -0.10889036953449249, -0.10525048524141312, -0.1016591489315033, -0.09824034571647644, -0.09469068050384521, -0.0911419615149498, -0.08773849159479141, -0.08416644483804703, -0.08071305602788925, -0.07720902562141418, -0.07371306419372559, -0.07019119709730148, -0.06673648208379745, -0.06329209357500076, -0.059800852090120316, -0.0564190037548542, -0.05296570807695389, -0.049522045999765396, -0.04609023034572601, -0.04262964054942131, -0.039246633648872375, -0.03577171266078949, -0.03236335143446922, -0.028855687007308006, -0.02542758360505104, -0.022069433704018593, -0.018754752352833748, -0.015386369079351425, -0.01194947212934494, -0.008439815603196621, -0.004995611496269703, -0.0016682245768606663, 0.0, 0.0015510577941313386, 0.005062474869191647, 0.008417150937020779, 0.011741090565919876, 0.015184164978563786, 0.018582714721560478, 0.02204744517803192, 0.025471193715929985, 0.02889077737927437, 0.0323684960603714, 0.03579240292310715, 0.039281025528907776, 0.0427563451230526, 0.04619763046503067, 0.04968220740556717, 0.05326594039797783, 0.05679265409708023, 0.060245808213949203, 0.06372645497322083, 0.06721872836351395, 0.0706876739859581, 0.0742349922657013, 0.07774098962545395, 0.08123527467250824, 0.08468879014253616, 0.08810535818338394, 0.09155989438295364, 0.09498448669910431, 0.0985206812620163, 0.10206405073404312, 0.10563778132200241, 0.10921968519687653, 0.11284469068050385, 0.11653254181146622, 0.12008969485759735, 0.12368203699588776, 0.1272617131471634, 0.13089501857757568, 0.134552001953125, 0.1382799744606018, 0.14194637537002563, 0.14563234150409698, 0.14930322766304016, 0.15303383767604828, 0.1567956507205963, 0.16050070524215698, 
0.16431072354316711, 0.16813558340072632, 0.17204202711582184, 0.1758781224489212, 0.17973239719867706, 0.1836014688014984, 0.18753431737422943, 0.19138391315937042, 0.19535475969314575, 0.19931404292583466, 0.20333819091320038, 0.20738255977630615, 0.21152682602405548, 0.21568812429904938, 0.21978361904621124, 0.22393859922885895, 0.22814159095287323, 0.23241068422794342, 0.23675410449504852, 0.24123944342136383, 0.24569889903068542, 0.2500703036785126, 0.25904011726379395, 0.26349544525146484, 0.2682226300239563, 0.272907555103302, 0.2774306833744049, 0.28220856189727783, 0.2869136929512024, 0.2916390895843506, 0.29649388790130615, 0.30142995715141296, 0.3065022826194763, 0.3114383816719055, 0.31648796796798706, 0.3216581642627716, 0.32700115442276, 0.3322487473487854, 0.33778008818626404, 0.3431521952152252, 0.3487405776977539, 0.3543166518211365, 0.3601346015930176, 0.36605337262153625, 0.37217751145362854, 0.378179669380188, 0.3843980133533478, 0.3906566798686981, 0.39714935421943665, 0.40357843041419983, 0.4104187488555908, 0.4171563684940338, 0.42418959736824036, 0.43136918544769287, 0.4389212429523468, 0.44673123955726624, 0.45457619428634644, 0.4627031683921814, 0.47130417823791504, 0.4798591434955597, 0.48897242546081543, 0.4979848861694336, 0.5, 0.5076631307601929, 0.5177803635597229, 0.5282770991325378, 0.5392990112304688, 0.5506287813186646, 0.5632893443107605, 0.5764452815055847, 0.5903191566467285, 0.6051878333091736, 0.6209936141967773, 0.6382884979248047, 0.6573970913887024, 0.6795773506164551, 0.7037051916122437, 0.7327037453651428, 0.7677436470985413, 0.8111193776130676, 0.875165581703186, 1.0]
diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py
index 92c83b1..5e73414 100644
--- a/bitsandbytes/optim/__init__.py
+++ b/bitsandbytes/optim/__init__.py
@@ -3,8 +3,10 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 from .adam import Adam, Adam8bit, Adam32bit
+from .adamw import AdamW, AdamW8bit, AdamW32bit
 from .sgd import SGD, SGD8bit, SGD32bit
 from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
 from .lamb import LAMB, LAMB8bit, LAMB32bit
 from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
+from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
 from .optimizer import GlobalOptimManager
diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py
new file mode 100644
index 0000000..84ade3c
--- /dev/null
+++ b/bitsandbytes/optim/adagrad.py
@@ -0,0 +1,57 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+from bitsandbytes.optim.optimizer import Optimizer1State
+
+torch.optim.Adagrad
+
+class Adagrad(Optimizer1State):
+    def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
+            optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if initial_accumulator_value != 0.0:
+            raise ValueError('Initial accumulator value != 0.0 not supported!')
+        if lr_decay != 0.0:
+            raise ValueError('Lr Decay != 0.0 not supported!')
+        super(Adagrad, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
+                weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+
+class Adagrad8bit(Optimizer1State):
+    def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
+            optim_bits=8, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if initial_accumulator_value != 0.0:
+            raise ValueError('Initial accumulator value != 0.0 not supported!')
+        if lr_decay != 0.0:
+            raise ValueError('Lr Decay != 0.0 not supported!')
+        assert block_wise
+        super(Adagrad8bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
+                weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+
+class Adagrad32bit(Optimizer1State):
+    def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
+            optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if initial_accumulator_value != 0.0:
+            raise ValueError('Initial accumulator value != 0.0 not supported!')
+        if lr_decay != 0.0:
+            raise ValueError('Lr Decay != 0.0 not supported!')
+        super(Adagrad32bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
+                weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
index f3e5e81..ed1b9f0 100644
--- a/bitsandbytes/optim/adam.py
+++ b/bitsandbytes/optim/adam.py
@@ -33,7 +33,6 @@ class Adam32bit(Optimizer2State):
                 weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
 
-
 class AnalysisAdam(torch.optim.Optimizer):
     """Adam that performs 8-bit vs 32-bit error analysis.
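The AdamW classes added below reuse the existing 'adam' kernels and differ only in the default weight decay of 1e-2. A rough equivalence sketch, not part of the commit (the layer is a placeholder; a CUDA device is assumed):

    import torch
    import bitsandbytes as bnb

    layer = torch.nn.Linear(32, 32).cuda()
    # with these defaults the two constructions configure the same underlying optimizer
    opt_w = bnb.optim.AdamW(layer.parameters(), lr=1e-3)                    # weight_decay defaults to 1e-2
    opt_a = bnb.optim.Adam(layer.parameters(), lr=1e-3, weight_decay=1e-2)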
diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py
new file mode 100644
index 0000000..7761f3b
--- /dev/null
+++ b/bitsandbytes/optim/adamw.py
@@ -0,0 +1,29 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+from bitsandbytes.optim.optimizer import Optimizer2State
+import bitsandbytes.functional as F
+
+class AdamW(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+            weight_decay=1e-2, amsgrad=False, optim_bits=32, args=None,
+            min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super(AdamW, self).__init__('adam', params, lr, betas, eps,
+                weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+
+class AdamW8bit(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+            weight_decay=1e-2, amsgrad=False, args=None,
+            min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super(AdamW8bit, self).__init__('adam', params, lr, betas, eps,
+                weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+
+class AdamW32bit(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+            weight_decay=1e-2, amsgrad=False, args=None,
+            min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super(AdamW32bit, self).__init__('adam', params, lr, betas, eps,
+                weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index 4b70b5c..cfbd72e 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -242,8 +242,9 @@ class Optimizer2State(Optimizer8bit):
         if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {}".format(eps))
         if isinstance(betas, str):
-            betas = eval(betas)
-            print(betas, 'parsed')
+            # format: '(beta1, beta2)'
+            betas = betas.replace('(', '').replace(')', '').strip().split(',')
+            betas = [float(b) for b in betas]
         for i in range(len(betas)):
             if not 0.0 <= betas[i] < 1.0:
                 raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
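The optimizer.py hunk above replaces eval() on a user-supplied betas string with plain string parsing. A standalone sketch of that parsing logic (the function name is illustrative only, not part of the library):

    def parse_betas(betas: str):
        # expects the format '(beta1, beta2)', e.g. '(0.9, 0.995)'
        parts = betas.replace('(', '').replace(')', '').strip().split(',')
        return [float(b) for b in parts]

    assert parse_betas('(0.80, 0.95)') == [0.8, 0.95]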
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index d8dfee1..d0aabff 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -720,6 +720,9 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
                   s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
                   s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
                   p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2))));
+
+                  if(weight_decay > 0.0f)
+                      p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
               }
               break;
       }
@@ -790,6 +793,11 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
                   s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
                   s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
                   break;
+              case ADAGRAD:
+                  s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); // state update
+                  s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
+                  s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
+                  break;
           }
      }
@@ -884,6 +892,10 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
                     s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
                     p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
                     break;
+                case ADAGRAD:
+                    s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]);
+                    p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps);
+                    break;
             }
         }
     }
@@ -1653,6 +1665,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
              case RMSPROP:
                   s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
                   break;
+             case ADAGRAD:
+                  s1_vals[j] = s1_vals[j] + (g_val*g_val);
+                  break;
           }
       }
@@ -1688,6 +1703,10 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
                   g_val = g_vals[j];
                   p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
                   break;
+              case ADAGRAD:
+                  g_val = g_vals[j];
+                  p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
+                  break;
           }
       }
   }
@@ -1738,6 +1757,8 @@ MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
 MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
 MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
 MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
+MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
+MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
 
 #define MAKE_Optimizer32bit1State(oname, gtype) \
 template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
@@ -1747,6 +1768,8 @@ MAKE_Optimizer32bit1State(MOMENTUM, half)
 MAKE_Optimizer32bit1State(MOMENTUM, float)
 MAKE_Optimizer32bit1State(RMSPROP, half)
 MAKE_Optimizer32bit1State(RMSPROP, float)
+MAKE_Optimizer32bit1State(ADAGRAD, half)
+MAKE_Optimizer32bit1State(ADAGRAD, float)
 
 #define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \
 template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
@@ -1862,3 +1885,5 @@ MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
 MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
+MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
+MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)
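The ADAGRAD branches added to the kernels above implement the standard accumulator update s <- s + g^2 and step p <- p - lr * g / (sqrt(s) + eps), without gradient clipping. A per-element Python sketch of that math, for illustration only (names mirror the kernel variables, this is not library code):

    import math

    def adagrad_step(p, g, s1, lr, eps):
        # accumulate the squared gradient (optimizer state, state1)
        s1 = s1 + g * g
        # scale the raw gradient by the root of the accumulator
        p = p - lr * (g / (math.sqrt(s1) + eps))
        return p, s1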
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 182d6e6..9691241 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -199,6 +199,8 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
 			break;
 		case MOMENTUM:
     case RMSPROP:
+    case ADAGRAD:
+
       if(max_unorm > 0.0f)
 			{
 				CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
@@ -240,6 +242,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
 			break;
 		case MOMENTUM:
     case RMSPROP:
+    case ADAGRAD:
 			CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
 			kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
 			CUDA_CHECK_RETURN(cudaPeekAtLastError());
@@ -274,6 +277,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
 			break;
 		case MOMENTUM:
 		case RMSPROP:
+		case ADAGRAD:
 			blocks = n/BLOCKSIZE_1STATE;
 			blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1;
 			kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
@@ -321,6 +325,8 @@ MAKE_optimizer32bit(MOMENTUM, half)
 MAKE_optimizer32bit(MOMENTUM, float)
 MAKE_optimizer32bit(RMSPROP, half)
 MAKE_optimizer32bit(RMSPROP, float)
+MAKE_optimizer32bit(ADAGRAD, half)
+MAKE_optimizer32bit(ADAGRAD, float)
 
 #define MAKE_optimizerStatic8bit(name, gtype) \
 template void optimizerStatic8bit<gtype, name>(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
@@ -350,6 +356,8 @@ MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
 MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
 MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
 MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
+MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
+MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
 
 template void percentileClipping(float * g, float *gnorm_vec, int step, const int n);
 template void percentileClipping(half * g, float *gnorm_vec, int step, const int n);
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index 465b4a4..1bc13fb 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -36,6 +36,7 @@ typedef enum Optimizer_t
 	MOMENTUM = 1,
 	RMSPROP = 2,
   LARS = 3,
+  ADAGRAD = 4,
 } Optimizer_t;
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index 7d5e654..e0b0d59 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -29,6 +29,8 @@ MAKE_FUNC32(adam, ADAM, float, 32)
 MAKE_FUNC32(adam, ADAM, half, 16)
 MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
 MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
+MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
+MAKE_FUNC32(adagrad, ADAGRAD, half, 16)
 
 #define MAKE_FUNC8(fname, oname, gtype, gbits) \
 void fname##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
@@ -62,6 +64,8 @@ MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16)
 MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32)
 MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16)
 MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32)
+MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, 16)
+MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32)
 
 void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
@@ -102,6 +106,8 @@ extern "C"
 	MAKE_CFUNC32(momentum, half, 16)
 	MAKE_CFUNC32(rmsprop, float, 32)
 	MAKE_CFUNC32(rmsprop, half, 16)
+	MAKE_CFUNC32(adagrad, float, 32)
+	MAKE_CFUNC32(adagrad, half, 16)
 
 #define MAKE_CFUNC8(name, gtype, gbits) \
 void c##name##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
@@ -135,6 +141,8 @@ extern "C"
 	MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32)
 	MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16)
 	MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32)
+	MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16)
+	MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32)
 
 	void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
diff --git a/tests/test_optim.py b/tests/test_optim.py
index fc2456f..5464043 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -34,11 +34,13 @@ str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx,
 str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
 
 str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam)
+str2optimizers['adamw'] = (torch.optim.AdamW, bnb.optim.AdamW)
 str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
 str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False))
 str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9))
 str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
 str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False))
+str2optimizers['adagrad'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad(pxx, 0.01, block_wise=False))
 str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
 str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False))
 str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False))
@@ -46,28 +48,34 @@ str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_
 str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9))
 
 str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
+str2optimizers['adamw8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.AdamW8bit(pxx, block_wise=True))
 str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True))
 str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True))
+str2optimizers['adagrad8bit_blockwise'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad8bit(pxx, 0.01, block_wise=True))
 
 str2statenames = {}
 str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
+str2statenames['adamw'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
 str2statenames['momentum'] = [('momentum_buffer', 'state1')]
 str2statenames['lars'] = [('momentum_buffer', 'state1')]
 str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
 str2statenames['rmsprop'] = [('square_avg', 'state1')]
+str2statenames['adagrad'] = [('sum', 'state1')]
 str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
 str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
 str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
+str2statenames['adamw8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
 str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
 str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')]
 str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
 str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')]
 str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')]
+str2statenames['adagrad8bit_blockwise'] = [('sum', 'state1', 'qmap1', 'absmax1')]
 
 dim1 = [1024]
 dim2 = [32, 1024, 4097, 1]
 gtype = [torch.float32, torch.float16]
-optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb']
+optimizer_names = ['adam', 'adamw', 'momentum', 'rmsprop', 'lars', 'lamb', 'adagrad']
 values = list(product(dim1,dim2, gtype, optimizer_names))
 names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -82,7 +90,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
     bnb_optimizer = str2optimizers[optim_name][1]([p2])
 
     if gtype == torch.float32:
-        atol, rtol = 1e-6, 1e-5
+        atol, rtol = 2e-6, 1e-5
     else:
         atol, rtol = 1e-4, 1e-3
@@ -197,7 +205,7 @@ def test_global_config(dim1, dim2, gtype):
 dim1 = [1024]
 dim2 = [32, 1024, 4097]
 gtype = [torch.float32, torch.float16]
-optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise']
+optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'adamw8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise', 'adagrad8bit_blockwise']
 values = list(product(dim1,dim2, gtype, optimizer_names))
 names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -384,3 +392,18 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
     #assert s < 3.9
+
+def test_str_betas():
+    betas = (0.80, 0.95)
+    strbetas = '(0.80, 0.95)'
+
+    layer = torch.nn.Linear(10, 10)
+
+    base = bnb.optim.Adam(layer.parameters(), betas=betas)
+    strbase = bnb.optim.Adam(layer.parameters(), betas=strbetas)
+    assert base.defaults['betas'][0] == 0.8
+    assert base.defaults['betas'][1] == 0.95
+    assert strbase.defaults['betas'][0] == 0.8
+    assert strbase.defaults['betas'][1] == 0.95
+
+
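The extended tests compare each bitsandbytes optimizer against its torch reference. A condensed, stand-alone sketch of that check for the new Adagrad (a CUDA device is assumed; tolerances follow test_optimizer32bit above):

    import torch
    import bitsandbytes as bnb

    p1 = torch.randn(1024, 32, device='cuda')
    p2 = p1.clone()
    torch_opt = torch.optim.Adagrad([p1], lr=0.01)
    bnb_opt = bnb.optim.Adagrad([p2], lr=0.01, block_wise=False)

    for _ in range(10):
        g = torch.randn_like(p1) * 0.01
        p1.grad = g
        p2.grad = g.clone()
        torch_opt.step()
        bnb_opt.step()

    # both parameter copies should stay numerically close
    assert torch.allclose(p1, p2, atol=2e-6, rtol=1e-5)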