From a6eae2e7f2bf03f268fcb6b055201ff6827684c4 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 20 Oct 2021 19:15:47 -0700 Subject: Added skip_zeros; tests are passing. --- csrc/kernels.cu | 122 ++++++++++++++++++++++++++++--------------------- csrc/pythonInterface.c | 2 +- 2 files changed, 71 insertions(+), 53 deletions(-) (limited to 'csrc') diff --git a/csrc/kernels.cu b/csrc/kernels.cu index f8f7b62..d1fa253 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -715,9 +715,12 @@ __global__ void kOptimizer32bit2State(T* g, T* p, switch(OPTIMIZER) { case ADAM: - s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); - s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); - p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2)))); + if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0)) + { + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); + s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); + p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2)))); + } break; } } @@ -865,21 +868,24 @@ __global__ void kOptimizer32bit1State(T *g, T *p, # pragma unroll 4 for(unsigned int j = 0; j < NUM_PER_THREAD; j++) { - switch(OPTIMIZER) - { - case MOMENTUM: - if(step == 1) - s1_vals[j] = (float)g_vals[j]; - else - s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); - - p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j])); - break; - case RMSPROP: - s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); - p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps)); - break; - } + if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0)) + { + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + + p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j])); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); + p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps)); + break; + } + } } __syncthreads(); @@ -1469,11 +1475,14 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char { g_val = float(g_vals[j]); g_val *= gnorm_scale; - s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; - s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); + if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0)) + { + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; + s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); - s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE]; - s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); + s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE]; + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); + } new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j])); @@ -1509,9 +1518,12 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char # pragma unroll N_PER_TH for(unsigned int j = 0; j < N_PER_TH; j++) { - g_vals[j] = (T)(((float)g_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); - if(weight_decay > 0.0f) - g_vals[j] = ((float)g_vals[j])*(1.0f-(lr*weight_decay)); + if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0)) + { + g_vals[j] = (T)(((float)g_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); + if(weight_decay > 0.0f) + g_vals[j] = ((float)g_vals[j])*(1.0f-(lr*weight_decay)); + } } // store: 0.85/1.44 -> 2.48/1.57 @@ -1623,23 +1635,26 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char { g_val = float(g_vals[j]); g_val *= gnorm_scale; - if(weight_decay > 0.0f) - g_val += ((float)p_vals[j])*weight_decay; - - s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; - - switch(OPTIMIZER) - { - case MOMENTUM: - if(step == 1) - s1_vals[j] = g_val; - else - s1_vals[j] = (s1_vals[j]*beta1) + g_val; - break; - case RMSPROP: - s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); - break; - } + if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0)) + { + if(weight_decay > 0.0f) + g_val += ((float)p_vals[j])*weight_decay; + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; + + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = g_val; + else + s1_vals[j] = (s1_vals[j]*beta1) + g_val; + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + break; + } + } new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); } @@ -1662,16 +1677,19 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char # pragma unroll N_PER_TH for(unsigned int j = 0; j < N_PER_TH; j++) { - switch(OPTIMIZER) - { - case MOMENTUM: - p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); - break; - case RMSPROP: - g_val = g_vals[j]; - p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); - break; - } + if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0)) + { + switch(OPTIMIZER) + { + case MOMENTUM: + p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); + break; + case RMSPROP: + g_val = g_vals[j]; + p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); + break; + } + } } // store: 0.85/1.44 -> 2.48/1.57 diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 67bf2e5..7d5e654 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -110,7 +110,7 @@ extern "C" float eps, int step, float lr, \ float* quantiles1, float* quantiles2, \ float* max1, float* max2, float* new_max1, float* new_max2, \ - float weight_decay, float gnorm_scale, bool skip_zeros, int n) \ + float weight_decay, float gnorm_scale, int n) \ { \ name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \ quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \ -- cgit v1.2.3