From bb34fd50a1fec74e62beb6e23d51f0142c7d0ab6 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 20 Oct 2021 18:37:44 -0700 Subject: Initial plumbing for skip_zeros. --- csrc/ops.cu | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) (limited to 'csrc/ops.cu') diff --git a/csrc/ops.cu b/csrc/ops.cu index d460ab1..182d6e6 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -181,7 +181,7 @@ template void dequantizeBlockwise(float *code, unsigned char *A, flo template void optimizer32bit(T* g, T* p, float* state1, float* state2, float *unorm, float max_unorm, float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n) + const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) { int blocks = n/4096; blocks = n % 4096 == 0 ? blocks : blocks + 1; @@ -194,7 +194,7 @@ template void optimizer32bit(T* g, T* p, kPreconditionOptimizer32bit2State<<>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } - kOptimizer32bit2State<<>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + kOptimizer32bit2State<<>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; case MOMENTUM: @@ -206,7 +206,7 @@ template void optimizer32bit(T* g, T* p, CUDA_CHECK_RETURN(cudaPeekAtLastError()); } - kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, n); + kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; } @@ -259,7 +259,7 @@ template void optimizerStatic8bit(T* p, T* g, template void optimizerStatic8bitBlockwise(T* p, T* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, - float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n) + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) { int blocks = 0; @@ -269,7 +269,7 @@ template void optimizerStatic8bitBlockwise(T* p, T* g blocks = n/BLOCKSIZE_2STATE; blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1; kOptimizerStatic8bit2StateBlockwise<<>>(p, g, state1, state2, beta1, beta2, eps, step, lr, - quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, n); + quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; case MOMENTUM: @@ -277,7 +277,7 @@ template void optimizerStatic8bitBlockwise(T* p, T* g blocks = n/BLOCKSIZE_1STATE; blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1; kOptimizerStatic8bit1StateBlockwise<<>>(p, g, state1, beta1, beta2, eps, step, lr, - quantiles1, absmax1, weight_decay, gnorm_scale, n); + quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; } @@ -313,7 +313,7 @@ template void dequantizeBlockwise(float *code, unsigned char *A, float *a template void optimizer32bit(gtype* g, gtype* p, \ float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \ const float beta1, const float beta2, const float eps, const float weight_decay, \ - const int step, const float lr, const float gnorm_scale, const int n); + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); MAKE_optimizer32bit(ADAM, half) MAKE_optimizer32bit(ADAM, float) @@ -342,7 +342,7 @@ MAKE_optimizerStatic8bit(RMSPROP, float) #define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ template void optimizerStatic8bitBlockwise(gtype* p, gtype* g, \ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ - float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n); \ + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \ MAKE_optimizerStatic8bitBlockwise(half, ADAM); MAKE_optimizerStatic8bitBlockwise(float, ADAM); -- cgit v1.2.3