From 8b3c0f355c779170d55a1975df981df9e53b59fa Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 10 Nov 2021 15:10:02 -0800 Subject: Added adagrad with tests (no clipping). --- csrc/kernels.cu | 22 ++++++++++++++++++++++ csrc/ops.cu | 8 ++++++++ csrc/ops.cuh | 1 + csrc/pythonInterface.c | 8 ++++++++ 4 files changed, 39 insertions(+) (limited to 'csrc') diff --git a/csrc/kernels.cu b/csrc/kernels.cu index d8dfee1..56f6a76 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -790,6 +790,11 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p, s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); // state update + s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; } } @@ -884,6 +889,10 @@ __global__ void kOptimizer32bit1State(T *g, T *p, s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps)); break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); + p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps); + break; } } } @@ -1653,6 +1662,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char case RMSPROP: s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + (g_val*g_val); + break; } } @@ -1688,6 +1700,10 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char g_val = g_vals[j]; p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); break; + case ADAGRAD: + g_val = g_vals[j]; + p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); + break; } } } @@ -1738,6 +1754,8 @@ MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half) MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) MAKE_PreconditionOptimizer32bit1State(RMSPROP, half) MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) #define MAKE_Optimizer32bit1State(oname, gtype) \ template __global__ void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ @@ -1747,6 +1765,8 @@ MAKE_Optimizer32bit1State(MOMENTUM, half) MAKE_Optimizer32bit1State(MOMENTUM, float) MAKE_Optimizer32bit1State(RMSPROP, half) MAKE_Optimizer32bit1State(RMSPROP, float) +MAKE_Optimizer32bit1State(ADAGRAD, half) +MAKE_Optimizer32bit1State(ADAGRAD, float) #define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \ template __global__ void kPreconditionOptimizer32bit2State(gtype* g, gtype* p, \ @@ -1862,3 +1882,5 @@ MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8) diff --git a/csrc/ops.cu b/csrc/ops.cu index 182d6e6..9691241 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -199,6 +199,8 @@ template void optimizer32bit(T* g, T* p, break; case MOMENTUM: case RMSPROP: + case ADAGRAD: + if(max_unorm > 0.0f) { CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); @@ -240,6 +242,7 @@ template void optimizerStatic8bit(T* p, T* g, break; case MOMENTUM: case RMSPROP: + case ADAGRAD: CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); @@ -274,6 +277,7 @@ template void optimizerStatic8bitBlockwise(T* p, T* g break; case MOMENTUM: case RMSPROP: + case ADAGRAD: blocks = n/BLOCKSIZE_1STATE; blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1; kOptimizerStatic8bit1StateBlockwise<<>>(p, g, state1, beta1, beta2, eps, step, lr, @@ -321,6 +325,8 @@ MAKE_optimizer32bit(MOMENTUM, half) MAKE_optimizer32bit(MOMENTUM, float) MAKE_optimizer32bit(RMSPROP, half) MAKE_optimizer32bit(RMSPROP, float) +MAKE_optimizer32bit(ADAGRAD, half) +MAKE_optimizer32bit(ADAGRAD, float) #define MAKE_optimizerStatic8bit(name, gtype) \ template void optimizerStatic8bit(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ @@ -350,6 +356,8 @@ MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM); MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); template void percentileClipping(half * g, float *gnorm_vec, int step, const int n); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 465b4a4..1bc13fb 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -36,6 +36,7 @@ typedef enum Optimizer_t MOMENTUM = 1, RMSPROP = 2, LARS = 3, + ADAGRAD = 4, } Optimizer_t; diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 7d5e654..e0b0d59 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -29,6 +29,8 @@ MAKE_FUNC32(adam, ADAM, float, 32) MAKE_FUNC32(adam, ADAM, half, 16) MAKE_FUNC32(rmsprop, RMSPROP, float, 32) MAKE_FUNC32(rmsprop, RMSPROP, half, 16) +MAKE_FUNC32(adagrad, ADAGRAD, float, 32) +MAKE_FUNC32(adagrad, ADAGRAD, half, 16) #define MAKE_FUNC8(fname, oname, gtype, gbits) \ void fname##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ @@ -62,6 +64,8 @@ MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16) MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32) MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16) MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32) +MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, 16) +MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32) void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } @@ -102,6 +106,8 @@ extern "C" MAKE_CFUNC32(momentum, half, 16) MAKE_CFUNC32(rmsprop, float, 32) MAKE_CFUNC32(rmsprop, half, 16) + MAKE_CFUNC32(adagrad, float, 32) + MAKE_CFUNC32(adagrad, half, 16) #define MAKE_CFUNC8(name, gtype, gbits) \ void c##name##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ @@ -135,6 +141,8 @@ extern "C" MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32) MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16) MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32) + MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16) + MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32) void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); } -- cgit v1.2.3