From bb34fd50a1fec74e62beb6e23d51f0142c7d0ab6 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 20 Oct 2021 18:37:44 -0700 Subject: Initial plumbing for skip_zeros. --- csrc/kernels.cu | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) (limited to 'csrc/kernels.cu') diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 66a2c99..f8f7b62 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -654,7 +654,7 @@ __launch_bounds__(TH, 1) __global__ void kOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n) + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) { const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); @@ -809,7 +809,7 @@ __launch_bounds__(TH, 1) __global__ void kOptimizer32bit1State(T *g, T *p, float *state1, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n) + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) { const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); @@ -1383,7 +1383,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2, float weight_decay, - const float gnorm_scale, const int n) + const float gnorm_scale, const bool skip_zeros, const int n) { //const int n_full = n + (n%BLOCK_SIZE); @@ -1555,7 +1555,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char float* __restrict__ const quantiles1, float* absmax1, float weight_decay, - const float gnorm_scale, const int n) + const float gnorm_scale, const bool skip_zeros, const int n) { //const int n_full = n + (n%BLOCK_SIZE); @@ -1723,7 +1723,7 @@ MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) #define MAKE_Optimizer32bit1State(oname, gtype) \ template __global__ void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ - const float beta1, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const int n); \ + const float beta1, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \ MAKE_Optimizer32bit1State(MOMENTUM, half) MAKE_Optimizer32bit1State(MOMENTUM, float) @@ -1740,9 +1740,9 @@ MAKE_PreconditionOptimizer32bit2State(ADAM, half) MAKE_PreconditionOptimizer32bit2State(ADAM, float) template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const int n); + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const int n); + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); #define MAKE_PreconditionStatic8bit1State(oname, gtype) \ template __global__ void kPreconditionOptimizerStatic8bit1State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ @@ -1825,7 +1825,7 @@ template __global__ void kOptimizerStatic8bit2StateBlockwise