summaryrefslogtreecommitdiff
path: root/csrc/pythonInterface.c
diff options
context:
space:
mode:
Diffstat (limited to 'csrc/pythonInterface.c')
-rw-r--r--csrc/pythonInterface.c18
1 files changed, 9 insertions, 9 deletions
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index eacb849..67bf2e5 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -20,8 +20,8 @@ void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimate
void fname##32bit_g##gbits(gtype *g, gtype *p, \
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
const float beta1, const float beta2, const float eps, const float weight_decay, \
- const int step, const float lr, float gnorm_scale, const int n) \
-{ optimizer32bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); } \
+ const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n) \
+{ optimizer32bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
MAKE_FUNC32(momentum, MOMENTUM, float, 32)
MAKE_FUNC32(momentum, MOMENTUM, half, 16)
@@ -53,8 +53,8 @@ MAKE_FUNC8(rmsprop, RMSPROP, half, 16)
#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
- float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n)\
-{ optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, n); }\
+ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\
+{ optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\
MAKE_BLOCKWISE8(adam, ADAM, half, 16)
MAKE_BLOCKWISE8(adam, ADAM, float, 32)
@@ -93,8 +93,8 @@ extern "C"
void c##name##32bit_g##gbits(gtype *g, gtype *p, \
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
const float beta1, const float beta2, const float eps, const float weight_decay, \
- const int step, const float lr, const float gnorm_scale, const int n) \
- { name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); } \
+ const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \
+ { name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
MAKE_CFUNC32(adam, float, 32)
MAKE_CFUNC32(adam, half, 16)
@@ -110,7 +110,7 @@ extern "C"
float eps, int step, float lr, \
float* quantiles1, float* quantiles2, \
float* max1, float* max2, float* new_max1, float* new_max2, \
- float weight_decay, float gnorm_scale, int n) \
+ float weight_decay, float gnorm_scale, bool skip_zeros, int n) \
{ \
name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
@@ -126,8 +126,8 @@ extern "C"
#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
- float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n) \
- { fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, n); } \
+ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \
+ { fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \
MAKE_CBLOCKWISE8(adam, ADAM, half, 16)
MAKE_CBLOCKWISE8(adam, ADAM, float, 32)