diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2021-11-28 21:18:11 -0800 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2021-11-28 21:18:11 -0800 |
commit | 2f8083bd8b084290f888fe59b329d98ebd6dd468 (patch) | |
tree | da534579bd762e93cd42b69a5e14c36f4b643979 /csrc | |
parent | ca2078a697ae3adfb84255ae398f79623dc4ea2a (diff) |
Added AdamW. #10 #13
Diffstat (limited to 'csrc')
-rw-r--r-- | csrc/kernels.cu | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 56f6a76..d0aabff 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -720,6 +720,9 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
                     s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
                     s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
                     p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2))));
+
+                    if(weight_decay > 0.0f)
+                        p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
                 }
                 break;
         }