summaryrefslogtreecommitdiff
path: root/csrc/ops.cu
diff options
context:
space:
mode:
authorTim Dettmers <dettmers@cs.washington.edu>2021-10-20 18:37:44 -0700
committerTim Dettmers <dettmers@cs.washington.edu>2021-10-20 18:37:44 -0700
commitbb34fd50a1fec74e62beb6e23d51f0142c7d0ab6 (patch)
treea01ed945c348027480a9d0cefb6698dfd7259fb1 /csrc/ops.cu
parent8400b58cbbc06e0a434cfa71f76c2efd713473fc (diff)
Initial plumbing for skip_zeros.
Diffstat (limited to 'csrc/ops.cu')
-rw-r--r--csrc/ops.cu16
1 files changed, 8 insertions, 8 deletions
diff --git a/csrc/ops.cu b/csrc/ops.cu
index d460ab1..182d6e6 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -181,7 +181,7 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,
- const int step, const float lr, const float gnorm_scale, const int n)
+ const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n)
{
int blocks = n/4096;
blocks = n % 4096 == 0 ? blocks : blocks + 1;
@@ -194,7 +194,7 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
- kOptimizer32bit2State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
+ kOptimizer32bit2State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case MOMENTUM:
@@ -206,7 +206,7 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
- kOptimizer32bit1State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
+ kOptimizer32bit1State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
}
@@ -259,7 +259,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
- float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n)
+ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)
{
int blocks = 0;
@@ -269,7 +269,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
blocks = n/BLOCKSIZE_2STATE;
blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1;
kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
- quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, n);
+ quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case MOMENTUM:
@@ -277,7 +277,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
blocks = n/BLOCKSIZE_1STATE;
blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1;
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
- quantiles1, absmax1, weight_decay, gnorm_scale, n);
+ quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
}
@@ -313,7 +313,7 @@ template void dequantizeBlockwise<float>(float *code, unsigned char *A, float *a
template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \
const float beta1, const float beta2, const float eps, const float weight_decay, \
- const int step, const float lr, const float gnorm_scale, const int n);
+ const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
MAKE_optimizer32bit(ADAM, half)
MAKE_optimizer32bit(ADAM, float)
@@ -342,7 +342,7 @@ MAKE_optimizerStatic8bit(RMSPROP, float)
#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
- float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n); \
+ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \
MAKE_optimizerStatic8bitBlockwise(half, ADAM);
MAKE_optimizerStatic8bitBlockwise(float, ADAM);