From 2f8083bd8b084290f888fe59b329d98ebd6dd468 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 28 Nov 2021 21:18:11 -0800 Subject: Added AdamW. #10 #13 --- csrc/kernels.cu | 3 +++ 1 file changed, 3 insertions(+) (limited to 'csrc') diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 56f6a76..d0aabff 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -720,6 +720,9 @@ __global__ void kOptimizer32bit2State(T* g, T* p, s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2)))); + + if(weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); } break; } -- cgit v1.2.3