summary | refs | log | tree | commit | diff
path: root/csrc/ops.cu
diff options
context:
space:
mode:
author: Tim Dettmers <dettmers@g3036.hyak.local> 2021-11-10 15:10:02 -0800
committer: Tim Dettmers <dettmers@g3036.hyak.local> 2021-11-10 15:10:02 -0800
commit: 8b3c0f355c779170d55a1975df981df9e53b59fa (patch)
tree: 0ebc5f8e869fb02e7dec90f809fbf07d778f9aca /csrc/ops.cu
parent: 22b2877c7f8277317a073ea7cf49231d33fe79fd (diff)
Added adagrad with tests (no clipping).
Diffstat (limited to 'csrc/ops.cu')
-rw-r--r-- csrc/ops.cu | 8
1 files changed, 8 insertions, 0 deletions
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 182d6e6..9691241 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -199,6 +199,8 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
break;
case MOMENTUM:
case RMSPROP:
+ case ADAGRAD:
+
if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
@@ -240,6 +242,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
break;
case MOMENTUM:
case RMSPROP:
+ case ADAGRAD:
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
@@ -274,6 +277,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
break;
case MOMENTUM:
case RMSPROP:
+ case ADAGRAD:
blocks = n/BLOCKSIZE_1STATE;
blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1;
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
@@ -321,6 +325,8 @@ MAKE_optimizer32bit(MOMENTUM, half)
MAKE_optimizer32bit(MOMENTUM, float)
MAKE_optimizer32bit(RMSPROP, half)
MAKE_optimizer32bit(RMSPROP, float)
+MAKE_optimizer32bit(ADAGRAD, half)
+MAKE_optimizer32bit(ADAGRAD, float)
#define MAKE_optimizerStatic8bit(name, gtype) \
template void optimizerStatic8bit<gtype, name>(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
@@ -350,6 +356,8 @@ MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
+MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
+MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
template void percentileClipping(float * g, float *gnorm_vec, int step, const int n);
template void percentileClipping(half * g, float *gnorm_vec, int step, const int n);