From bb34fd50a1fec74e62beb6e23d51f0142c7d0ab6 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 20 Oct 2021 18:37:44 -0700 Subject: Initial plumbing for skip_zeros. --- bitsandbytes/functional.py | 25 +++++++++++++++++++------ bitsandbytes/optim/optimizer.py | 14 ++++++++++---- csrc/kernels.cu | 18 +++++++++--------- csrc/kernels.cuh | 8 ++++---- csrc/ops.cu | 16 ++++++++-------- csrc/ops.cuh | 5 +++-- csrc/pythonInterface.c | 18 +++++++++--------- tests/test_optim.py | 24 ++++++++++++++++++++++++ 8 files changed, 86 insertions(+), 42 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 65c697d..48ab40c 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -337,7 +337,7 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten beta1: float, eps: float, step: int, lr: float, state2: Tensor=None, beta2: float=0.0, weight_decay: float=0.0, gnorm_scale: float=1.0, - unorm_vec: Tensor=None, max_unorm: float=0.0) -> None: + unorm_vec: Tensor=None, max_unorm: float=0.0, skip_zeros=False) -> None: ''' Performs an inplace optimizer update with one or two optimizer states. @@ -369,6 +369,12 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten Optimizer beta2. gnorm_scale : float The factor to rescale the gradient to the max clip value. + unorm_vec : torch.Tensor + The tensor for the update norm. + max_unorm : float + The maximum update norm relative to the weight norm. + skip_zeros : bool + Whether to skip zero-valued gradients or not (default: False). ''' param_norm = 0.0 @@ -381,11 +387,11 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten if g.dtype == torch.float32 and state1.dtype == torch.float32: str2optimizer32bit[optimizer_name][0](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm), ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay), - ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_int32(g.numel())) + ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel())) elif g.dtype == torch.float16 and state1.dtype == torch.float32: str2optimizer32bit[optimizer_name][1](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm), ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay), - ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_int32(g.numel())) + ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel())) else: raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}') @@ -439,6 +445,10 @@ def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Ten Max value for the next Adam update of the second state. gnorm_scale : float The factor to rescale the gradient to the max clip value. + unorm_vec : torch.Tensor + The tensor for the update norm. + max_unorm : float + The maximum update norm relative to the weight norm. ''' param_norm = 0.0 @@ -468,19 +478,22 @@ def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Ten def optimizer_update_8bit_blockwise(optimizer_name: str, g: Tensor, p: Tensor, state1: Tensor, state2: Tensor, beta1: float, beta2: float, eps: float, step: int, lr: float, qmap1: Tensor, qmap2: Tensor, - absmax1: Tensor, absmax2: Tensor, weight_decay: float=0.0, gnorm_scale: float=1.0) -> None: + absmax1: Tensor, absmax2: Tensor, weight_decay: float=0.0, gnorm_scale: float=1.0, + skip_zeros=False) -> None: if g.dtype == torch.float32 and state1.dtype == torch.uint8: str2optimizer8bit_blockwise[optimizer_name][0](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2), - get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), ct.c_int32(g.numel())) + get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), ct.c_int32(g.numel())) elif g.dtype == torch.float16 and state1.dtype == torch.uint8: str2optimizer8bit_blockwise[optimizer_name][1](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2), - get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), ct.c_int32(g.numel())) + get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), ct.c_int32(g.numel())) else: raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}') diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index 6743c15..25512b1 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -220,6 +220,7 @@ class Optimizer8bit(torch.optim.Optimizer): config['percentile_clipping'] = self.args.percentile_clipping config['block_wise'] = self.args.block_wise config['max_unorm'] = self.args.max_unorm + config['skip_zeros'] = self.args.skip_zeros if (gindex, pindex) in self.mng.index2config: config.update(self.mng.index2config[(gindex, pindex)]) @@ -234,7 +235,8 @@ class Optimizer8bit(torch.optim.Optimizer): class Optimizer2State(Optimizer8bit): def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0): + min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0, + skip_zeros=False): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -259,6 +261,7 @@ class Optimizer2State(Optimizer8bit): args['percentile_clipping'] = percentile_clipping args['block_wise'] = block_wise args['max_unorm'] = max_unorm + args['skip_zeros'] = skip_zeros self.args = MockArgs(args) else: @@ -355,7 +358,8 @@ class Optimizer2State(Optimizer8bit): class Optimizer1State(Optimizer8bit): def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.0), eps=1e-8, weight_decay=0.0, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0): + min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0, + skip_zeros=False): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -377,6 +381,7 @@ class Optimizer1State(Optimizer8bit): args['percentile_clipping'] = percentile_clipping args['block_wise'] = block_wise args['max_unorm'] = max_unorm + args['skip_zeros'] = skip_zeros self.args = MockArgs(args) else: @@ -444,7 +449,8 @@ class Optimizer1State(Optimizer8bit): if state['state1'].dtype == torch.float: F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'], None, 0.0, config['weight_decay'], gnorm_scale, - state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm']) + state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'], + skip_zeros=False) elif state['state1'].dtype == torch.uint8 and not config['block_wise']: F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1], @@ -457,4 +463,4 @@ class Optimizer1State(Optimizer8bit): F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1], config['eps'], step, config['lr'], state['qmap1'], None, state['absmax1'], None, - config['weight_decay'], gnorm_scale=gnorm_scale) + config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=False) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 66a2c99..f8f7b62 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -654,7 +654,7 @@ __launch_bounds__(TH, 1) __global__ void kOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n) + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) { const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); @@ -809,7 +809,7 @@ __launch_bounds__(TH, 1) __global__ void kOptimizer32bit1State(T *g, T *p, float *state1, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n) + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) { const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); @@ -1383,7 +1383,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2, float weight_decay, - const float gnorm_scale, const int n) + const float gnorm_scale, const bool skip_zeros, const int n) { //const int n_full = n + (n%BLOCK_SIZE); @@ -1555,7 +1555,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char float* __restrict__ const quantiles1, float* absmax1, float weight_decay, - const float gnorm_scale, const int n) + const float gnorm_scale, const bool skip_zeros, const int n) { //const int n_full = n + (n%BLOCK_SIZE); @@ -1723,7 +1723,7 @@ MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) #define MAKE_Optimizer32bit1State(oname, gtype) \ template __global__ void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ - const float beta1, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const int n); \ + const float beta1, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \ MAKE_Optimizer32bit1State(MOMENTUM, half) MAKE_Optimizer32bit1State(MOMENTUM, float) @@ -1740,9 +1740,9 @@ MAKE_PreconditionOptimizer32bit2State(ADAM, half) MAKE_PreconditionOptimizer32bit2State(ADAM, float) template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const int n); + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const int n); + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); #define MAKE_PreconditionStatic8bit1State(oname, gtype) \ template __global__ void kPreconditionOptimizerStatic8bit1State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ @@ -1825,7 +1825,7 @@ template __global__ void kOptimizerStatic8bit2StateBlockwise __global__ void kOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n); + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); template __global__ void kPreconditionOptimizer32bit1State(T* g, T* p, @@ -39,7 +39,7 @@ template __global__ void kOptimizer32bit1State(T* g, T* p, float* state1, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n); + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); template __global__ void @@ -90,7 +90,7 @@ template __global__ voi T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, const float beta2, const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, - float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const int n); + float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n); template __global__ void kOptimizerStatic8bit1StateBlockwise( T* p, T* __restrict__ const g, unsigned char* state1, @@ -99,7 +99,7 @@ template __global__ voi float* __restrict__ const quantiles1, float* absmax1, float weight_decay, - const float gnorm_scale, const int n); + const float gnorm_scale, const bool skip_zeros, const int n); template __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n); diff --git a/csrc/ops.cu b/csrc/ops.cu index d460ab1..182d6e6 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -181,7 +181,7 @@ template void dequantizeBlockwise(float *code, unsigned char *A, flo template void optimizer32bit(T* g, T* p, float* state1, float* state2, float *unorm, float max_unorm, float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n) + const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) { int blocks = n/4096; blocks = n % 4096 == 0 ? blocks : blocks + 1; @@ -194,7 +194,7 @@ template void optimizer32bit(T* g, T* p, kPreconditionOptimizer32bit2State<<>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } - kOptimizer32bit2State<<>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + kOptimizer32bit2State<<>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; case MOMENTUM: @@ -206,7 +206,7 @@ template void optimizer32bit(T* g, T* p, CUDA_CHECK_RETURN(cudaPeekAtLastError()); } - kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, n); + kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; } @@ -259,7 +259,7 @@ template void optimizerStatic8bit(T* p, T* g, template void optimizerStatic8bitBlockwise(T* p, T* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, - float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n) + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) { int blocks = 0; @@ -269,7 +269,7 @@ template void optimizerStatic8bitBlockwise(T* p, T* g blocks = n/BLOCKSIZE_2STATE; blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1; kOptimizerStatic8bit2StateBlockwise<<>>(p, g, state1, state2, beta1, beta2, eps, step, lr, - quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, n); + quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; case MOMENTUM: @@ -277,7 +277,7 @@ template void optimizerStatic8bitBlockwise(T* p, T* g blocks = n/BLOCKSIZE_1STATE; blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1; kOptimizerStatic8bit1StateBlockwise<<>>(p, g, state1, beta1, beta2, eps, step, lr, - quantiles1, absmax1, weight_decay, gnorm_scale, n); + quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; } @@ -313,7 +313,7 @@ template void dequantizeBlockwise(float *code, unsigned char *A, float *a template void optimizer32bit(gtype* g, gtype* p, \ float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \ const float beta1, const float beta2, const float eps, const float weight_decay, \ - const int step, const float lr, const float gnorm_scale, const int n); + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); MAKE_optimizer32bit(ADAM, half) MAKE_optimizer32bit(ADAM, float) @@ -342,7 +342,7 @@ MAKE_optimizerStatic8bit(RMSPROP, float) #define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ template void optimizerStatic8bitBlockwise(gtype* p, gtype* g, \ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ - float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n); \ + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \ MAKE_optimizerStatic8bitBlockwise(half, ADAM); MAKE_optimizerStatic8bitBlockwise(float, ADAM); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index e6033cb..465b4a4 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -49,7 +49,7 @@ template void dequantizeBlockwise(float *code, unsigned char *A, flo template void optimizer32bit(T* g, T* p, float* state1, float* state2, float *unorm, float max_unorm, float param_norm, float beta1, float beta2, float eps, float weight_decay, - int step, float lr, const float gnorm_scale, int n); + int step, float lr, const float gnorm_scale, bool skip_zeros, int n); template void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2, float *unorm, float max_unorm, float param_norm, @@ -62,7 +62,8 @@ template void optimizerStatic8bit(T* p, T* g, unsigne template void optimizerStatic8bitBlockwise(T* p, T* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, - float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n); + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, + bool skip_zeros, int n); template void percentileClipping(T * g, float *gnorm_vec, int step, const int n); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index eacb849..67bf2e5 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -20,8 +20,8 @@ void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimate void fname##32bit_g##gbits(gtype *g, gtype *p, \ float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \ const float beta1, const float beta2, const float eps, const float weight_decay, \ - const int step, const float lr, float gnorm_scale, const int n) \ -{ optimizer32bit(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); } \ + const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n) \ +{ optimizer32bit(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ MAKE_FUNC32(momentum, MOMENTUM, float, 32) MAKE_FUNC32(momentum, MOMENTUM, half, 16) @@ -53,8 +53,8 @@ MAKE_FUNC8(rmsprop, RMSPROP, half, 16) #define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \ void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ - float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n)\ -{ optimizerStatic8bitBlockwise(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, n); }\ + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\ +{ optimizerStatic8bitBlockwise(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\ MAKE_BLOCKWISE8(adam, ADAM, half, 16) MAKE_BLOCKWISE8(adam, ADAM, float, 32) @@ -93,8 +93,8 @@ extern "C" void c##name##32bit_g##gbits(gtype *g, gtype *p, \ float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \ const float beta1, const float beta2, const float eps, const float weight_decay, \ - const int step, const float lr, const float gnorm_scale, const int n) \ - { name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); } \ + const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \ + { name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ MAKE_CFUNC32(adam, float, 32) MAKE_CFUNC32(adam, half, 16) @@ -110,7 +110,7 @@ extern "C" float eps, int step, float lr, \ float* quantiles1, float* quantiles2, \ float* max1, float* max2, float* new_max1, float* new_max2, \ - float weight_decay, float gnorm_scale, int n) \ + float weight_decay, float gnorm_scale, bool skip_zeros, int n) \ { \ name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \ quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \ @@ -126,8 +126,8 @@ extern "C" #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \ void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ - float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n) \ - { fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, n); } \ + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \ + { fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \ MAKE_CBLOCKWISE8(adam, ADAM, half, 16) MAKE_CBLOCKWISE8(adam, ADAM, float, 32) diff --git a/tests/test_optim.py b/tests/test_optim.py index 4d67b08..fc2456f 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -141,6 +141,7 @@ def test_global_config(dim1, dim2, gtype): eps = 1e-8 bnb.optim.GlobalOptimManager.get_instance().initialize() + bnb.optim.GlobalOptimManager.get_instance().override_config(p2, 'skip_zeros', True) bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8) bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3]) @@ -155,6 +156,8 @@ def test_global_config(dim1, dim2, gtype): else: atol, rtol = 1e-4, 1e-3 + original_p2 = p2[mask].clone() + for i in range(50): g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 @@ -163,11 +166,32 @@ def test_global_config(dim1, dim2, gtype): p2.grad = g2 p3.grad = g3 + if i > 30 and i % 10 == 0: + g1.data[mask] = 0.0 + g2.data[mask] = 0.0 + p1.grad = g1 + p2.grad = g2 + original_p1 = p1[mask].clone() + original_p2 = p2[mask].clone() + og_s1 = adam2.state[p2]['state1'][mask].clone() + og_s2 = adam2.state[p2]['state2'][mask].clone() + og_s11 = adam2.state[p1]['state1'][mask].clone() + og_s21 = adam2.state[p1]['state2'][mask].clone() + adam2.step() assert adam2.state[p3]['state1'].dtype == torch.uint8 assert adam2.state[p3]['state2'].dtype == torch.uint8 + if i > 30 and i % 10 == 0: + torch.testing.assert_allclose(original_p2, p2[mask]) + torch.testing.assert_allclose(adam2.state[p2]['state1'][mask], og_s1) + torch.testing.assert_allclose(adam2.state[p2]['state2'][mask], og_s2) + assert ((p1[mask]- original_p1)==0.0).sum() < p1.numel() + assert ((adam2.state[p1]['state1'][mask]- og_s11)==0.0).sum() == 0.0 + assert ((adam2.state[p1]['state2'][mask]- og_s21)==0.0).sum() == 0.0 + + dim1 = [1024] -- cgit v1.2.3