summaryrefslogtreecommitdiff
path: root/csrc/ops.cuh
diff options
context:
space:
mode:
Diffstat (limited to 'csrc/ops.cuh')
-rw-r--r--csrc/ops.cuh5
1 files changed, 3 insertions, 2 deletions
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index e6033cb..465b4a4 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -49,7 +49,7 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
float beta1, float beta2, float eps, float weight_decay,
- int step, float lr, const float gnorm_scale, int n);
+ int step, float lr, const float gnorm_scale, bool skip_zeros, int n);
template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2,
float *unorm, float max_unorm, float param_norm,
@@ -62,7 +62,8 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, unsigne
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
- float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n);
+ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale,
+ bool skip_zeros, int n);
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n);