summaryrefslogtreecommitdiff
path: root/csrc/kernels.cu
diff options
context:
space:
mode:
Diffstat (limited to 'csrc/kernels.cu')
-rw-r--r--csrc/kernels.cu25
1 files changed, 25 insertions, 0 deletions
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index d8dfee1..d0aabff 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -720,6 +720,9 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2))));
+
+ if(weight_decay > 0.0f)
+ p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
}
break;
}
@@ -790,6 +793,11 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
break;
+ case ADAGRAD:
+ s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); // state update
+ s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
+ s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
+ break;
}
}
@@ -884,6 +892,10 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
break;
+ case ADAGRAD:
+ s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]);
+ p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps);
+ break;
}
}
}
@@ -1653,6 +1665,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
break;
+ case ADAGRAD:
+ s1_vals[j] = s1_vals[j] + (g_val*g_val);
+ break;
}
}
@@ -1688,6 +1703,10 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
g_val = g_vals[j];
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
break;
+ case ADAGRAD:
+ g_val = g_vals[j];
+ p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
+ break;
}
}
}
@@ -1738,6 +1757,8 @@ MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
+MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
+MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
#define MAKE_Optimizer32bit1State(oname, gtype) \
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
@@ -1747,6 +1768,8 @@ MAKE_Optimizer32bit1State(MOMENTUM, half)
MAKE_Optimizer32bit1State(MOMENTUM, float)
MAKE_Optimizer32bit1State(RMSPROP, half)
MAKE_Optimizer32bit1State(RMSPROP, float)
+MAKE_Optimizer32bit1State(ADAGRAD, half)
+MAKE_Optimizer32bit1State(ADAGRAD, float)
#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
@@ -1862,3 +1885,5 @@ MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
+MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
+MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)