From 8b3c0f355c779170d55a1975df981df9e53b59fa Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 10 Nov 2021 15:10:02 -0800 Subject: Added adagrad with tests (no clipping). --- csrc/ops.cu | 8 ++++++++ 1 file changed, 8 insertions(+) (limited to 'csrc/ops.cu') diff --git a/csrc/ops.cu b/csrc/ops.cu index 182d6e6..9691241 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -199,6 +199,8 @@ template void optimizer32bit(T* g, T* p, break; case MOMENTUM: case RMSPROP: + case ADAGRAD: + if(max_unorm > 0.0f) { CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); @@ -240,6 +242,7 @@ template void optimizerStatic8bit(T* p, T* g, break; case MOMENTUM: case RMSPROP: + case ADAGRAD: CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); @@ -274,6 +277,7 @@ template void optimizerStatic8bitBlockwise(T* p, T* g break; case MOMENTUM: case RMSPROP: + case ADAGRAD: blocks = n/BLOCKSIZE_1STATE; blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1; kOptimizerStatic8bit1StateBlockwise<<>>(p, g, state1, beta1, beta2, eps, step, lr, @@ -321,6 +325,8 @@ MAKE_optimizer32bit(MOMENTUM, half) MAKE_optimizer32bit(MOMENTUM, float) MAKE_optimizer32bit(RMSPROP, half) MAKE_optimizer32bit(RMSPROP, float) +MAKE_optimizer32bit(ADAGRAD, half) +MAKE_optimizer32bit(ADAGRAD, float) #define MAKE_optimizerStatic8bit(name, gtype) \ template void optimizerStatic8bit(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ @@ -350,6 +356,8 @@ MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM); MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); template void percentileClipping(half * g, float *gnorm_vec, int step, const int n); -- cgit v1.2.3