summaryrefslogtreecommitdiff
path: root/csrc/kernels.cuh
diff options
context:
space:
mode:
Diffstat (limited to 'csrc/kernels.cuh')
-rw-r--r--csrc/kernels.cuh8
1 files changed, 4 insertions, 4 deletions
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index 06ae1e4..0a3676c 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -27,7 +27,7 @@ template<typename T, int OPTIMIZER>
__global__ void kOptimizer32bit2State(T* g, T* p,
float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,
- const int step, const float lr, const float gnorm_scale, const int n);
+ const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
@@ -39,7 +39,7 @@ template<typename T, int OPTIMIZER>
__global__ void kOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float eps, const float weight_decay,
- const int step, const float lr, const float gnorm_scale, const int n);
+ const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template<typename T, int OPTIMIZER>
__global__ void
@@ -90,7 +90,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ voi
T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2,
const float beta1, const float beta2, const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
- float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const int n);
+ float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n);
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ void kOptimizerStatic8bit1StateBlockwise(
T* p, T* __restrict__ const g, unsigned char* state1,
@@ -99,7 +99,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ voi
float* __restrict__ const quantiles1,
float* absmax1,
float weight_decay,
- const float gnorm_scale, const int n);
+ const float gnorm_scale, const bool skip_zeros, const int n);
template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n);