summaryrefslogtreecommitdiff
path: root/csrc
diff options
context:
space:
mode:
authorTim Dettmers <dettmers@cs.washington.edu>2021-10-20 19:15:47 -0700
committerTim Dettmers <dettmers@cs.washington.edu>2021-10-20 19:15:47 -0700
commita6eae2e7f2bf03f268fcb6b055201ff6827684c4 (patch)
treed2f72792251c9feaef1cf9dcddc3c79e6312a93a /csrc
parentbb34fd50a1fec74e62beb6e23d51f0142c7d0ab6 (diff)
Added skip_zeros; tests are passing.
Diffstat (limited to 'csrc')
-rw-r--r--csrc/kernels.cu122
-rw-r--r--csrc/pythonInterface.c2
2 files changed, 71 insertions, 53 deletions
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index f8f7b62..d1fa253 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -715,9 +715,12 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
switch(OPTIMIZER)
{
case ADAM:
- s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
- s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
- p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2))));
+ if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+ {
+ s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
+ s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
+ p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2))));
+ }
break;
}
}
@@ -865,21 +868,24 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
{
- switch(OPTIMIZER)
- {
- case MOMENTUM:
- if(step == 1)
- s1_vals[j] = (float)g_vals[j];
- else
- s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
-
- p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
- break;
- case RMSPROP:
- s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
- p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
- break;
- }
+ if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+ {
+ switch(OPTIMIZER)
+ {
+ case MOMENTUM:
+ if(step == 1)
+ s1_vals[j] = (float)g_vals[j];
+ else
+ s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
+
+ p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
+ break;
+ case RMSPROP:
+ s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
+ p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
+ break;
+ }
+ }
}
__syncthreads();
@@ -1469,11 +1475,14 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
{
g_val = float(g_vals[j]);
g_val *= gnorm_scale;
- s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
- s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val));
+ if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+ {
+ s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
+ s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val));
- s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE];
- s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val));
+ s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE];
+ s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val));
+ }
new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j]));
new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j]));
@@ -1509,9 +1518,12 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
# pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++)
{
- g_vals[j] = (T)(((float)g_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
- if(weight_decay > 0.0f)
- g_vals[j] = ((float)g_vals[j])*(1.0f-(lr*weight_decay));
+ if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+ {
+ g_vals[j] = (T)(((float)g_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
+ if(weight_decay > 0.0f)
+ g_vals[j] = ((float)g_vals[j])*(1.0f-(lr*weight_decay));
+ }
}
// store: 0.85/1.44 -> 2.48/1.57
@@ -1623,23 +1635,26 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
{
g_val = float(g_vals[j]);
g_val *= gnorm_scale;
- if(weight_decay > 0.0f)
- g_val += ((float)p_vals[j])*weight_decay;
-
- s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
-
- switch(OPTIMIZER)
- {
- case MOMENTUM:
- if(step == 1)
- s1_vals[j] = g_val;
- else
- s1_vals[j] = (s1_vals[j]*beta1) + g_val;
- break;
- case RMSPROP:
- s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
- break;
- }
+ if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+ {
+ if(weight_decay > 0.0f)
+ g_val += ((float)p_vals[j])*weight_decay;
+
+ s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
+
+ switch(OPTIMIZER)
+ {
+ case MOMENTUM:
+ if(step == 1)
+ s1_vals[j] = g_val;
+ else
+ s1_vals[j] = (s1_vals[j]*beta1) + g_val;
+ break;
+ case RMSPROP:
+ s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
+ break;
+ }
+ }
new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j]));
}
@@ -1662,16 +1677,19 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
# pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++)
{
- switch(OPTIMIZER)
- {
- case MOMENTUM:
- p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
- break;
- case RMSPROP:
- g_val = g_vals[j];
- p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
- break;
- }
+ if(!skip_zeros || (skip_zeros && g_vals[j] != (T)0.0))
+ {
+ switch(OPTIMIZER)
+ {
+ case MOMENTUM:
+ p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
+ break;
+ case RMSPROP:
+ g_val = g_vals[j];
+ p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
+ break;
+ }
+ }
}
// store: 0.85/1.44 -> 2.48/1.57
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index 67bf2e5..7d5e654 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -110,7 +110,7 @@ extern "C"
float eps, int step, float lr, \
float* quantiles1, float* quantiles2, \
float* max1, float* max2, float* new_max1, float* new_max2, \
- float weight_decay, float gnorm_scale, bool skip_zeros, int n) \
+ float weight_decay, float gnorm_scale, int n) \
{ \
name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \