From bb34fd50a1fec74e62beb6e23d51f0142c7d0ab6 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 20 Oct 2021 18:37:44 -0700 Subject: Initial plumbing for skip_zeros. --- csrc/kernels.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'csrc/kernels.cuh') diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 06ae1e4..0a3676c 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -27,7 +27,7 @@ template __global__ void kOptimizer32bit2State(T* g, T* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n); + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); template __global__ void kPreconditionOptimizer32bit1State(T* g, T* p, @@ -39,7 +39,7 @@ template __global__ void kOptimizer32bit1State(T* g, T* p, float* state1, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, const int n); + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); template __global__ void @@ -90,7 +90,7 @@ template __global__ voi T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, const float beta2, const float eps, const int step, const float lr, float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, - float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const int n); + float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n); template __global__ void kOptimizerStatic8bit1StateBlockwise( T* p, T* __restrict__ const g, unsigned char* state1, @@ -99,7 +99,7 @@ template __global__ voi float* __restrict__ const quantiles1, float* absmax1, float weight_decay, - const float gnorm_scale, const int n); + const float gnorm_scale, const bool skip_zeros, const int n); template __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n); -- cgit v1.2.3