summaryrefslogtreecommitdiff
path: root/csrc
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2021-11-28 21:18:11 -0800
committerTim Dettmers <tim.dettmers@gmail.com>2021-11-28 21:18:11 -0800
commit2f8083bd8b084290f888fe59b329d98ebd6dd468 (patch)
treeda534579bd762e93cd42b69a5e14c36f4b643979 /csrc
parentca2078a697ae3adfb84255ae398f79623dc4ea2a (diff)
Added AdamW. #10 #13
Diffstat (limited to 'csrc')
-rw-r--r--csrc/kernels.cu3
1 files changed, 3 insertions, 0 deletions
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 56f6a76..d0aabff 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -720,6 +720,9 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2))));
+
+ if(weight_decay > 0.0f)
+ p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
}
break;
}