From 8258b4364a21a4da2572cb644d0926080c3268da Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Fri, 1 Jul 2022 17:16:10 +0300 Subject: Add a CPU-only build option --- csrc/common.cpp | 39 +++++ csrc/common.h | 23 +++ csrc/cpu_ops.cpp | 57 +++++++ csrc/cpu_ops.h | 9 + csrc/ops.cu | 451 ++++++++++++++++++++----------------------------- csrc/ops.cuh | 10 -- csrc/pythonInterface.c | 118 +++++++------ 7 files changed, 377 insertions(+), 330 deletions(-) create mode 100644 csrc/common.cpp create mode 100644 csrc/common.h create mode 100644 csrc/cpu_ops.cpp create mode 100644 csrc/cpu_ops.h (limited to 'csrc') diff --git a/csrc/common.cpp b/csrc/common.cpp new file mode 100644 index 0000000..972602b --- /dev/null +++ b/csrc/common.cpp @@ -0,0 +1,39 @@ +#include +#include + +void *quantize_block(void *arguments) { + // 1. find absmax in block + // 2. divide input value by absmax to normalize into [-1.0, 1.0] + // 3. do binary search to find the closest value + // 4. check minimal distance + // 5. store index + + struct quantize_block_args *args = (quantize_block_args *) arguments; + + // 1. find absmax in block + float absmax_block = -FLT_MAX; + for (int i = args->block_idx; i < args->block_end; i++) + absmax_block = fmax(absmax_block, fabs(args->A[i])); + + args->absmax[args->block_idx / BLOCK_SIZE] = absmax_block; + + for (int i = args->block_idx; i < args->block_end; i++) { + // 2. divide input value by absmax to normalize into [-1.0, 1.0] + // 3. do binary search to find the closest value + float normed_value = args->A[i] / absmax_block; + int idx = args->bin_searcher->scalar(normed_value); + + // 4. check minimal distance + // The binary search returns always the value to the left, which might not be the closest value + if (idx < 255) { + float dist_left = fabs(normed_value - (args->code[idx])); + float dist_right = fabs(normed_value - (args->code[idx + 1])); + if (dist_right < dist_left) { idx += 1; } + } + + // 5. store index + args->out[i] = (unsigned char) idx; + } + + return NULL; +} diff --git a/csrc/common.h b/csrc/common.h new file mode 100644 index 0000000..35f2463 --- /dev/null +++ b/csrc/common.h @@ -0,0 +1,23 @@ +#include + +#ifndef common +#define common + +using namespace BinSearch; + +struct quantize_block_args { + BinAlgo *bin_searcher; + float *code; + float *A; + float *absmax; + unsigned char *out; + int block_end; + int block_idx; + int threadidx; +}; + +#define BLOCK_SIZE 4096 + +void *quantize_block(void *arguments); + +#endif \ No newline at end of file diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp new file mode 100644 index 0000000..11a2615 --- /dev/null +++ b/csrc/cpu_ops.cpp @@ -0,0 +1,57 @@ +#include +#include +#include + +using namespace BinSearch; + +void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n) { + for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) { + int valid_items = n - block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx; + int block_end = block_idx + valid_items; + for (int i = block_idx; i < block_end; i++) + out[i] = code[A[i]] * absmax[block_idx / BLOCK_SIZE]; + } +} + +void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n) { + + // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below + code[0] = -1.0f; + + int num_blocks = n / BLOCK_SIZE; + num_blocks += n % BLOCK_SIZE == 0 ? 0 : 1; + + pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * num_blocks); + struct quantize_block_args **args = (quantize_block_args **) malloc(num_blocks * sizeof(quantize_block_args *)); + + for (int i = 0; i < num_blocks; i++) + args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args)); + + const uint32 elements_code = 256; + BinAlgo bin_searcher(code, elements_code); + + for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) { + int valid_items = n - block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx; + int block_end = block_idx + valid_items; + + struct quantize_block_args *arg = args[block_idx / BLOCK_SIZE]; + arg->bin_searcher = &bin_searcher; + arg->code = code; + arg->A = A; + arg->absmax = absmax; + arg->out = out; + arg->block_end = block_end; + arg->block_idx = block_idx; + arg->threadidx = block_idx / BLOCK_SIZE; + + pthread_create(&threads[block_idx / BLOCK_SIZE], NULL, &quantize_block, (void *) arg); + } + + for (int i = 0; i < num_blocks; i++) + int err = pthread_join(threads[i], NULL); + + free(threads); + for (int i = 0; i < num_blocks; i++) + free(args[i]); + free(args); +} \ No newline at end of file diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h new file mode 100644 index 0000000..57145a9 --- /dev/null +++ b/csrc/cpu_ops.h @@ -0,0 +1,9 @@ +#ifndef BITSANDBYTES_CPU_OPS_H +#define BITSANDBYTES_CPU_OPS_H + + +void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n); + +void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n); + +#endif diff --git a/csrc/ops.cu b/csrc/ops.cu index 9691241..464ea2e 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -8,251 +8,141 @@ #include #include #include +#include using namespace BinSearch; using std::cout; using std::endl; -#define BLOCK_SIZE 4096 - -struct quantize_block_args -{ - BinAlgo *bin_searcher; - float *code; - float *A; - float *absmax; - unsigned char *out; - int block_end; - int block_idx; - int threadidx; -}; - -void *quantize_block(void *arguments) -{ - // 1. find absmax in block - // 2. divide input value by absmax to normalize into [-1.0, 1.0] - // 3. do binary search to find the closest value - // 4. check minimal distance - // 5. store index - - struct quantize_block_args *args = (quantize_block_args*)arguments; - - // 1. find absmax in block - float absmax_block = -FLT_MAX; - for (int i = args->block_idx; i < args->block_end; i++) - absmax_block = fmax(absmax_block, fabs(args->A[i])); - - args->absmax[args->block_idx/BLOCK_SIZE] = absmax_block; - - for (int i = args->block_idx; i < args->block_end; i++) - { - // 2. divide input value by absmax to normalize into [-1.0, 1.0] - // 3. do binary search to find the closest value - float normed_value = args->A[i]/absmax_block; - int idx = args->bin_searcher->scalar(normed_value); - - // 4. check minimal distance - // The binary search returns always the value to the left, which might not be the closest value - if(idx < 255) - { - float dist_left = fabs(normed_value-(args->code[idx])); - float dist_right = fabs(normed_value-(args->code[idx+1])); - if(dist_right < dist_left){ idx+=1; } - } - - // 5. store index - args->out[i] = (unsigned char)idx; - } - - return NULL; +void histogramScatterAdd2D(float *histogram, int *index1, int *index2, float *src, int maxidx1, int n) { + int threads = 512; + int blocks = n / threads; + blocks = n % threads == 0 ? blocks : blocks + 1; + kHistogramScatterAdd2D<<>>(histogram, index1, index2, src, maxidx1, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n) -{ - - // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below - code[0] = -1.0f; - - int num_blocks = n/BLOCK_SIZE; - num_blocks += n % BLOCK_SIZE == 0 ? 0 : 1; - - pthread_t *threads = (pthread_t*)malloc(sizeof(pthread_t)*num_blocks); - struct quantize_block_args **args = (quantize_block_args**)malloc(num_blocks*sizeof(quantize_block_args*)); - - for(int i = 0; i < num_blocks; i++) - args[i] = (quantize_block_args*)malloc(sizeof(quantize_block_args)); - - const uint32 elements_code = 256; - BinAlgo bin_searcher(code, elements_code); - - for(int block_idx = 0; block_idx < n; block_idx+=BLOCK_SIZE) - { - int valid_items = n-block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx; - int block_end = block_idx + valid_items; - - struct quantize_block_args *arg = args[block_idx/BLOCK_SIZE]; - arg->bin_searcher = &bin_searcher; - arg->code = code; - arg->A = A; - arg->absmax = absmax; - arg->out = out; - arg->block_end = block_end; - arg->block_idx = block_idx; - arg->threadidx = block_idx/BLOCK_SIZE; - - pthread_create(&threads[block_idx/BLOCK_SIZE], NULL, &quantize_block, (void *)arg); - } - - for(int i = 0; i < num_blocks; i++) - int err = pthread_join(threads[i], NULL); - - free(threads); - for(int i = 0; i < num_blocks; i++) - free(args[i]); - free(args); -} - - -void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n) -{ - for(int block_idx = 0; block_idx < n; block_idx+=BLOCK_SIZE) - { - int valid_items = n-block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx; - int block_end = block_idx + valid_items; - for (int i = block_idx; i < block_end; i++) - out[i] = code[A[i]]*absmax[block_idx/BLOCK_SIZE]; - } +template +void estimateQuantiles(T *A, float *code, float offset, int n) { + int blocks = n / 4096; + blocks = n % 4096 == 0 ? blocks : blocks + 1; + CUDA_CHECK_RETURN(cudaMemset(code, 0, 256 * sizeof(float))); + kEstimateQuantiles < T ><<>>(A, code, offset, std::numeric_limits::max(), n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n) -{ - int threads = 512; - int blocks = n/threads; - blocks = n % threads == 0 ? blocks : blocks + 1; - kHistogramScatterAdd2D<<>>(histogram, index1, index2, src, maxidx1, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); +void quantize(float *code, float *A, unsigned char *out, int n) { + int blocks = n / 1024; + blocks = n % 1024 == 0 ? blocks : blocks + 1; + kQuantize<<>>(code, A, out, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template void estimateQuantiles(T *A, float *code, float offset, int n) -{ - int blocks = n/4096; - blocks = n % 4096 == 0 ? blocks : blocks + 1; - CUDA_CHECK_RETURN(cudaMemset(code, 0, 256*sizeof(float))); - kEstimateQuantiles<<>>(A, code, offset, std::numeric_limits::max(), n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); +void dequantize(float *code, unsigned char *A, float *out, int n) { + int blocks = n / 1024; + blocks = n % 1024 == 0 ? blocks : blocks + 1; + kDequantize<<>>(code, A, out, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -void quantize(float *code, float *A, unsigned char *out, int n) -{ - int blocks = n/1024; - blocks = n % 1024 == 0 ? blocks : blocks + 1; - kQuantize<<>>(code, A, out, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); +template +void quantizeBlockwise(float *code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n) { + int blocks = n / 4096; + blocks = n % 4096 == 0 ? blocks : blocks + 1; + kQuantizeBlockwise < T, 4096, 4, STOCHASTIC ><<>>(code, A, absmax, out, rand, rand_offset, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -void dequantize(float *code, unsigned char *A, float *out, int n) -{ - int blocks = n/1024; - blocks = n % 1024 == 0 ? blocks : blocks + 1; - kDequantize<<>>(code, A, out, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); +template +void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) { + int blocks = n / blocksize; + blocks = n % blocksize == 0 ? blocks : blocks + 1; + if (blocksize == 4096) + kDequantizeBlockwise < T, 4096, 1024, 4 ><<>>(code, A, absmax, out, n); + else if (blocksize == 2048) + kDequantizeBlockwise < T, 2048, 512, 4 ><<>>(code, A, absmax, out, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n) -{ - int blocks = n/4096; - blocks = n % 4096 == 0 ? blocks : blocks + 1; - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); -} - -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) -{ - int blocks = n/blocksize; - blocks = n % blocksize == 0 ? blocks : blocks + 1; - if(blocksize == 4096) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 2048) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); -} - -template void optimizer32bit(T* g, T* p, - float* state1, float* state2, float *unorm, float max_unorm, float param_norm, - const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) -{ - int blocks = n/4096; - blocks = n % 4096 == 0 ? blocks : blocks + 1; - switch(OPTIMIZER) - { - case ADAM: - if(max_unorm > 0.0f) - { - CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); - kPreconditionOptimizer32bit2State<<>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - } - kOptimizer32bit2State<<>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - break; - case MOMENTUM: - case RMSPROP: - case ADAGRAD: - - if(max_unorm > 0.0f) - { - CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); - kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - } - - kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - break; - } +template +void optimizer32bit(T *g, T *p, + float *state1, float *state2, float *unorm, float max_unorm, float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) { + int blocks = n / 4096; + blocks = n % 4096 == 0 ? blocks : blocks + 1; + switch (OPTIMIZER) { + case ADAM: + if (max_unorm > 0.0f) { + CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float))); + kPreconditionOptimizer32bit2State < T, OPTIMIZER, 4096, + 8 ><<>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + } + kOptimizer32bit2State < T, + OPTIMIZER ><<>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + + if (max_unorm > 0.0f) { + CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float))); + kPreconditionOptimizer32bit1State < T, OPTIMIZER, 4096, + 8 ><<>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + } + + kOptimizer32bit1State < T, + OPTIMIZER ><<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + } } -template void optimizerStatic8bit(T* p, T* g, - unsigned char* state1, unsigned char* state2, - float *unorm, float max_unorm, float param_norm, - float beta1, float beta2, - float eps, int step, float lr, - float* quantiles1, float* quantiles2, - float* max1, float* max2, float* new_max1, float* new_max2, - float weight_decay, - const float gnorm_scale, int n) -{ - int blocks = n/4096; - blocks = n % 4096 == 0 ? blocks : blocks + 1; - - if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); } - - switch(OPTIMIZER) - { - case ADAM: - CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); - CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float))); - kPreconditionOptimizerStatic8bit2State<<>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - kOptimizerStatic8bit2State<<>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, - quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - break; - case MOMENTUM: - case RMSPROP: - case ADAGRAD: - CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); - kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr, - quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - break; - default: - break; - } +template +void optimizerStatic8bit(T *p, T *g, + unsigned char *state1, unsigned char *state2, + float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, + float eps, int step, float lr, + float *quantiles1, float *quantiles2, + float *max1, float *max2, float *new_max1, float *new_max2, + float weight_decay, + const float gnorm_scale, int n) { + int blocks = n / 4096; + blocks = n % 4096 == 0 ? blocks : blocks + 1; + + if (max_unorm > 0.0f) { CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float))); } + + switch (OPTIMIZER) { + case ADAM: + CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float))); + CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1 * sizeof(float))); + kPreconditionOptimizerStatic8bit2State < T, + OPTIMIZER ><<>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + kOptimizerStatic8bit2State < T, + OPTIMIZER ><<>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float))); + kPreconditionOptimizerStatic8bit1State < T, + OPTIMIZER ><<>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + kOptimizerStatic8bit1State < T, OPTIMIZER ><<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr, + quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + default: + break; + } } #define BLOCKSIZE_2STATE 2048 @@ -260,42 +150,43 @@ template void optimizerStatic8bit(T* p, T* g, #define BLOCKSIZE_1STATE 2048 #define NUM_1STATE 8 -template void optimizerStatic8bitBlockwise(T* p, T* g, - unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, - float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) -{ - - int blocks = 0; - switch(OPTIMIZER) - { - case ADAM: - blocks = n/BLOCKSIZE_2STATE; - blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1; - kOptimizerStatic8bit2StateBlockwise<<>>(p, g, state1, state2, beta1, beta2, eps, step, lr, - quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - break; - case MOMENTUM: - case RMSPROP: - case ADAGRAD: - blocks = n/BLOCKSIZE_1STATE; - blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1; - kOptimizerStatic8bit1StateBlockwise<<>>(p, g, state1, beta1, beta2, eps, step, lr, - quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - break; - } +template +void optimizerStatic8bitBlockwise(T *p, T *g, + unsigned char *state1, unsigned char *state2, float beta1, float beta2, float eps, int step, float lr, + float *quantiles1, float *quantiles2, float *absmax1, float *absmax2, float weight_decay, + const float gnorm_scale, bool skip_zeros, int n) { + + int blocks = 0; + switch (OPTIMIZER) { + case ADAM: + blocks = n / BLOCKSIZE_2STATE; + blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1; + kOptimizerStatic8bit2StateBlockwise < T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE ><<>>(p, g, state1, state2, beta1, beta2, eps, step, lr, + quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + blocks = n / BLOCKSIZE_1STATE; + blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1; + kOptimizerStatic8bit1StateBlockwise < T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE ><<>>(p, g, state1, beta1, beta2, eps, step, lr, + quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + } } - -template void percentileClipping(T * g, float *gnorm_vec, int step, const int n) -{ - int blocks = n/2048; - blocks = n % 2048 == 0 ? blocks : blocks + 1; - CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float))); - kPercentileClipping<<>>(g, gnorm_vec, step, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); +template +void percentileClipping(T *g, float *gnorm_vec, int step, const int n) { + int blocks = n / 2048; + blocks = n % 2048 == 0 ? blocks : blocks + 1; + CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1 * sizeof(float))); + kPercentileClipping < T, 2048, 4 ><<>>(g, gnorm_vec, step, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -304,13 +195,23 @@ template void percentileClipping(T * g, float *gnorm_vec, int step, //============================================================== template void estimateQuantiles(half *A, float *code, float offset, int n); + template void estimateQuantiles(float *A, float *code, float offset, int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); +template void +quantizeBlockwise(float *code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n); + +template void +quantizeBlockwise(float *code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n); + +template void +quantizeBlockwise(float *code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n); + +template void +quantizeBlockwise(float *code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n); + template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); + template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); #define MAKE_optimizer32bit(name, gtype) \ @@ -320,12 +221,19 @@ template void optimizer32bit(gtype* g, gtype* p, \ const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); MAKE_optimizer32bit(ADAM, half) + MAKE_optimizer32bit(ADAM, float) + MAKE_optimizer32bit(MOMENTUM, half) + MAKE_optimizer32bit(MOMENTUM, float) + MAKE_optimizer32bit(RMSPROP, half) + MAKE_optimizer32bit(RMSPROP, float) + MAKE_optimizer32bit(ADAGRAD, half) + MAKE_optimizer32bit(ADAGRAD, float) #define MAKE_optimizerStatic8bit(name, gtype) \ @@ -338,11 +246,17 @@ template void optimizerStatic8bit(gtype* p, gtype* g, unsigned char float weight_decay, \ const float gnorm_scale, int n); \ + MAKE_optimizerStatic8bit(ADAM, half) + MAKE_optimizerStatic8bit(ADAM, float) + MAKE_optimizerStatic8bit(MOMENTUM, half) + MAKE_optimizerStatic8bit(MOMENTUM, float) + MAKE_optimizerStatic8bit(RMSPROP, half) + MAKE_optimizerStatic8bit(RMSPROP, float) #define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ @@ -350,14 +264,23 @@ template void optimizerStatic8bitBlockwise(gtype* p, gtype* g unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \ + MAKE_optimizerStatic8bitBlockwise(half, ADAM); + MAKE_optimizerStatic8bitBlockwise(float, ADAM); + MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); + MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM); + MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); + MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); + MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); + MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); -template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); -template void percentileClipping(half * g, float *gnorm_vec, int step, const int n); +template void percentileClipping(float *g, float *gnorm_vec, int step, const int n); + +template void percentileClipping(half *g, float *gnorm_vec, int step, const int n); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 1bc13fb..8fb4cec 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -68,16 +68,6 @@ template void optimizerStatic8bitBlockwise(T* p, T* g template void percentileClipping(T * g, float *gnorm_vec, int step, const int n); -void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n); -void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n); - void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n); #endif - - - - - - - diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index e0b0d59..229b7ed 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -3,7 +3,10 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. +#if BUILD_CUDA #include +#endif +#include // We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary. // We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but its better than to @@ -12,6 +15,7 @@ // UNMANGLED CALLS //=================================================================================== +#if BUILD_CUDA void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); } void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); } @@ -34,15 +38,15 @@ MAKE_FUNC32(adagrad, ADAGRAD, half, 16) #define MAKE_FUNC8(fname, oname, gtype, gbits) \ void fname##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ - float *unorm, float max_unorm, float param_norm, \ + float *unorm, float max_unorm, float param_norm, \ float beta1, float beta2, \ float eps, int step, float lr, \ float* quantiles1, float* quantiles2, \ float* max1, float* max2, float* new_max1, float* new_max2, \ float weight_decay, float gnorm_scale, int n) \ { \ - optimizerStatic8bit(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \ - quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \ + optimizerStatic8bit(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \ + quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \ } \ MAKE_FUNC8(adam, ADAM, float, 32) @@ -78,39 +82,41 @@ void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, un void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } \ void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } +#endif extern "C" { - void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); } - void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); } - void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); } - void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); } - void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, n); } - void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, n); } - void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); } - void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); } - - void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } - void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } - - #define MAKE_CFUNC32(name, gtype, gbits) \ - void c##name##32bit_g##gbits(gtype *g, gtype *p, \ - float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \ - const float beta1, const float beta2, const float eps, const float weight_decay, \ - const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \ - { name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ - - MAKE_CFUNC32(adam, float, 32) - MAKE_CFUNC32(adam, half, 16) - MAKE_CFUNC32(momentum, float, 32) - MAKE_CFUNC32(momentum, half, 16) - MAKE_CFUNC32(rmsprop, float, 32) - MAKE_CFUNC32(rmsprop, half, 16) - MAKE_CFUNC32(adagrad, float, 32) - MAKE_CFUNC32(adagrad, half, 16) - - #define MAKE_CFUNC8(name, gtype, gbits) \ - void c##name##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ +#if BUILD_CUDA +void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); } +void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); } +void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); } +void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); } +void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, n); } +void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, n); } +void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); } +void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); } + +void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } +void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } + +#define MAKE_CFUNC32(name, gtype, gbits) \ + void c##name##32bit_g##gbits(gtype *g, gtype *p, \ + float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \ + { name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ + +MAKE_CFUNC32(adam, float, 32) +MAKE_CFUNC32(adam, half, 16) +MAKE_CFUNC32(momentum, float, 32) +MAKE_CFUNC32(momentum, half, 16) +MAKE_CFUNC32(rmsprop, float, 32) +MAKE_CFUNC32(rmsprop, half, 16) +MAKE_CFUNC32(adagrad, float, 32) +MAKE_CFUNC32(adagrad, half, 16) + +#define MAKE_CFUNC8(name, gtype, gbits) \ + void c##name##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ float *unorm, float max_unorm, float param_norm, \ float beta1, float beta2, \ float eps, int step, float lr, \ @@ -118,40 +124,40 @@ extern "C" float* max1, float* max2, float* new_max1, float* new_max2, \ float weight_decay, float gnorm_scale, int n) \ { \ - name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \ - quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \ + name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \ + quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \ } \ - MAKE_CFUNC8(adam, float, 32) - MAKE_CFUNC8(adam, half, 16) - MAKE_CFUNC8(momentum, float, 32) - MAKE_CFUNC8(momentum, half, 16) - MAKE_CFUNC8(rmsprop, float, 32) - MAKE_CFUNC8(rmsprop, half, 16) +MAKE_CFUNC8(adam, float, 32) +MAKE_CFUNC8(adam, half, 16) +MAKE_CFUNC8(momentum, float, 32) +MAKE_CFUNC8(momentum, half, 16) +MAKE_CFUNC8(rmsprop, float, 32) +MAKE_CFUNC8(rmsprop, half, 16) - #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \ +#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \ void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \ { fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \ - MAKE_CBLOCKWISE8(adam, ADAM, half, 16) - MAKE_CBLOCKWISE8(adam, ADAM, float, 32) - MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, 16) - MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32) - MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16) - MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32) - MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16) - MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32) +MAKE_CBLOCKWISE8(adam, ADAM, half, 16) +MAKE_CBLOCKWISE8(adam, ADAM, float, 32) +MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, 16) +MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32) +MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16) +MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32) +MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16) +MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32) - void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); } - void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); } +void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); } +void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); } +void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); } +#endif - void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); } - void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); } - - void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); } +void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); } +void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); } } -- cgit v1.2.3 From 31ce1b3708751016bf5e14beff7ae0a99c975991 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Fri, 1 Jul 2022 17:36:30 +0300 Subject: Reduce diff --- csrc/common.h | 2 +- csrc/cpu_ops.cpp | 2 +- csrc/ops.cu | 290 ++++++++++++++++++++++++------------------------------- 3 files changed, 129 insertions(+), 165 deletions(-) (limited to 'csrc') diff --git a/csrc/common.h b/csrc/common.h index 35f2463..2f25a58 100644 --- a/csrc/common.h +++ b/csrc/common.h @@ -20,4 +20,4 @@ struct quantize_block_args { void *quantize_block(void *arguments); -#endif \ No newline at end of file +#endif diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 11a2615..89de52d 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -54,4 +54,4 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int for (int i = 0; i < num_blocks; i++) free(args[i]); free(args); -} \ No newline at end of file +} diff --git a/csrc/ops.cu b/csrc/ops.cu index 464ea2e..b2a1105 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -46,103 +46,100 @@ void dequantize(float *code, unsigned char *A, float *out, int n) { CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template -void quantizeBlockwise(float *code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n) { - int blocks = n / 4096; - blocks = n % 4096 == 0 ? blocks : blocks + 1; - kQuantizeBlockwise < T, 4096, 4, STOCHASTIC ><<>>(code, A, absmax, out, rand, rand_offset, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n) +{ + int blocks = n/4096; + blocks = n % 4096 == 0 ? blocks : blocks + 1; + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template -void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) { - int blocks = n / blocksize; - blocks = n % blocksize == 0 ? blocks : blocks + 1; - if (blocksize == 4096) - kDequantizeBlockwise < T, 4096, 1024, 4 ><<>>(code, A, absmax, out, n); - else if (blocksize == 2048) - kDequantizeBlockwise < T, 2048, 512, 4 ><<>>(code, A, absmax, out, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) +{ + int blocks = n/blocksize; + blocks = n % blocksize == 0 ? blocks : blocks + 1; + if(blocksize == 4096) + kDequantizeBlockwise<<>>(code, A, absmax, out, n); + else if(blocksize == 2048) + kDequantizeBlockwise<<>>(code, A, absmax, out, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template -void optimizer32bit(T *g, T *p, - float *state1, float *state2, float *unorm, float max_unorm, float param_norm, - const float beta1, const float beta2, const float eps, const float weight_decay, - const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) { - int blocks = n / 4096; - blocks = n % 4096 == 0 ? blocks : blocks + 1; - switch (OPTIMIZER) { - case ADAM: - if (max_unorm > 0.0f) { - CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float))); - kPreconditionOptimizer32bit2State < T, OPTIMIZER, 4096, - 8 ><<>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - } - kOptimizer32bit2State < T, - OPTIMIZER ><<>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - break; - case MOMENTUM: - case RMSPROP: - case ADAGRAD: - - if (max_unorm > 0.0f) { - CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float))); - kPreconditionOptimizer32bit1State < T, OPTIMIZER, 4096, - 8 ><<>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - } - - kOptimizer32bit1State < T, - OPTIMIZER ><<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - break; - } +template void optimizer32bit(T* g, T* p, + float* state1, float* state2, float *unorm, float max_unorm, float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) +{ + int blocks = n/4096; + blocks = n % 4096 == 0 ? blocks : blocks + 1; + switch(OPTIMIZER) + { + case ADAM: + if(max_unorm > 0.0f) + { + CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); + kPreconditionOptimizer32bit2State<<>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + } + kOptimizer32bit2State<<>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + + if(max_unorm > 0.0f) + { + CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); + kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + } + + kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + } } -template -void optimizerStatic8bit(T *p, T *g, - unsigned char *state1, unsigned char *state2, - float *unorm, float max_unorm, float param_norm, - float beta1, float beta2, - float eps, int step, float lr, - float *quantiles1, float *quantiles2, - float *max1, float *max2, float *new_max1, float *new_max2, - float weight_decay, - const float gnorm_scale, int n) { - int blocks = n / 4096; - blocks = n % 4096 == 0 ? blocks : blocks + 1; - - if (max_unorm > 0.0f) { CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float))); } - - switch (OPTIMIZER) { - case ADAM: - CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float))); - CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1 * sizeof(float))); - kPreconditionOptimizerStatic8bit2State < T, - OPTIMIZER ><<>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - kOptimizerStatic8bit2State < T, - OPTIMIZER ><<>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, - quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - break; - case MOMENTUM: - case RMSPROP: - case ADAGRAD: - CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float))); - kPreconditionOptimizerStatic8bit1State < T, - OPTIMIZER ><<>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - kOptimizerStatic8bit1State < T, OPTIMIZER ><<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr, - quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - break; - default: - break; - } +template void optimizerStatic8bit(T* p, T* g, + unsigned char* state1, unsigned char* state2, + float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, + float eps, int step, float lr, + float* quantiles1, float* quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, int n) +{ + int blocks = n/4096; + blocks = n % 4096 == 0 ? blocks : blocks + 1; + + if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); } + + switch(OPTIMIZER) + { + case ADAM: + CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); + CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float))); + kPreconditionOptimizerStatic8bit2State<<>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + kOptimizerStatic8bit2State<<>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); + kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr, + quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + default: + break; + } } #define BLOCKSIZE_2STATE 2048 @@ -150,43 +147,42 @@ void optimizerStatic8bit(T *p, T *g, #define BLOCKSIZE_1STATE 2048 #define NUM_1STATE 8 -template -void optimizerStatic8bitBlockwise(T *p, T *g, - unsigned char *state1, unsigned char *state2, float beta1, float beta2, float eps, int step, float lr, - float *quantiles1, float *quantiles2, float *absmax1, float *absmax2, float weight_decay, - const float gnorm_scale, bool skip_zeros, int n) { - - int blocks = 0; - switch (OPTIMIZER) { - case ADAM: - blocks = n / BLOCKSIZE_2STATE; - blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1; - kOptimizerStatic8bit2StateBlockwise < T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE ><<>>(p, g, state1, state2, beta1, beta2, eps, step, lr, - quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - break; - case MOMENTUM: - case RMSPROP: - case ADAGRAD: - blocks = n / BLOCKSIZE_1STATE; - blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1; - kOptimizerStatic8bit1StateBlockwise < T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE ><<>>(p, g, state1, beta1, beta2, eps, step, lr, - quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); - break; - } +template void optimizerStatic8bitBlockwise(T* p, T* g, + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) +{ + + int blocks = 0; + switch(OPTIMIZER) + { + case ADAM: + blocks = n/BLOCKSIZE_2STATE; + blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1; + kOptimizerStatic8bit2StateBlockwise<<>>(p, g, state1, state2, beta1, beta2, eps, step, lr, + quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + blocks = n/BLOCKSIZE_1STATE; + blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1; + kOptimizerStatic8bit1StateBlockwise<<>>(p, g, state1, beta1, beta2, eps, step, lr, + quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + } } -template -void percentileClipping(T *g, float *gnorm_vec, int step, const int n) { - int blocks = n / 2048; - blocks = n % 2048 == 0 ? blocks : blocks + 1; - CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1 * sizeof(float))); - kPercentileClipping < T, 2048, 4 ><<>>(g, gnorm_vec, step, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + +template void percentileClipping(T * g, float *gnorm_vec, int step, const int n) +{ + int blocks = n/2048; + blocks = n % 2048 == 0 ? blocks : blocks + 1; + CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float))); + kPercentileClipping<<>>(g, gnorm_vec, step, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -195,23 +191,13 @@ void percentileClipping(T *g, float *gnorm_vec, int step, const int n) { //============================================================== template void estimateQuantiles(half *A, float *code, float offset, int n); - template void estimateQuantiles(float *A, float *code, float offset, int n); -template void -quantizeBlockwise(float *code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n); - -template void -quantizeBlockwise(float *code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n); - -template void -quantizeBlockwise(float *code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n); - -template void -quantizeBlockwise(float *code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n); - +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); - template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); #define MAKE_optimizer32bit(name, gtype) \ @@ -221,19 +207,12 @@ template void optimizer32bit(gtype* g, gtype* p, \ const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); MAKE_optimizer32bit(ADAM, half) - MAKE_optimizer32bit(ADAM, float) - MAKE_optimizer32bit(MOMENTUM, half) - MAKE_optimizer32bit(MOMENTUM, float) - MAKE_optimizer32bit(RMSPROP, half) - MAKE_optimizer32bit(RMSPROP, float) - MAKE_optimizer32bit(ADAGRAD, half) - MAKE_optimizer32bit(ADAGRAD, float) #define MAKE_optimizerStatic8bit(name, gtype) \ @@ -246,17 +225,11 @@ template void optimizerStatic8bit(gtype* p, gtype* g, unsigned char float weight_decay, \ const float gnorm_scale, int n); \ - MAKE_optimizerStatic8bit(ADAM, half) - MAKE_optimizerStatic8bit(ADAM, float) - MAKE_optimizerStatic8bit(MOMENTUM, half) - MAKE_optimizerStatic8bit(MOMENTUM, float) - MAKE_optimizerStatic8bit(RMSPROP, half) - MAKE_optimizerStatic8bit(RMSPROP, float) #define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ @@ -264,23 +237,14 @@ template void optimizerStatic8bitBlockwise(gtype* p, gtype* g unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \ - MAKE_optimizerStatic8bitBlockwise(half, ADAM); - MAKE_optimizerStatic8bitBlockwise(float, ADAM); - MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); - MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM); - MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); - MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); - MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); - MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); -template void percentileClipping(float *g, float *gnorm_vec, int step, const int n); - -template void percentileClipping(half *g, float *gnorm_vec, int step, const int n); +template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); +template void percentileClipping(half * g, float *gnorm_vec, int step, const int n); -- cgit v1.2.3 From 4d1d5b569f55dd613bea26714eb1ad931a10be35 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Fri, 1 Jul 2022 17:40:02 +0300 Subject: Reduce diff --- csrc/pythonInterface.c | 115 ++++++++++++++++++++++++------------------------- 1 file changed, 57 insertions(+), 58 deletions(-) (limited to 'csrc') diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 229b7ed..1f690c5 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -38,15 +38,15 @@ MAKE_FUNC32(adagrad, ADAGRAD, half, 16) #define MAKE_FUNC8(fname, oname, gtype, gbits) \ void fname##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ - float *unorm, float max_unorm, float param_norm, \ + float *unorm, float max_unorm, float param_norm, \ float beta1, float beta2, \ float eps, int step, float lr, \ float* quantiles1, float* quantiles2, \ float* max1, float* max2, float* new_max1, float* new_max2, \ float weight_decay, float gnorm_scale, int n) \ { \ - optimizerStatic8bit(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \ - quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \ + optimizerStatic8bit(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \ + quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \ } \ MAKE_FUNC8(adam, ADAM, float, 32) @@ -86,37 +86,37 @@ void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, floa extern "C" { -#if BUILD_CUDA -void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); } -void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); } -void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); } -void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); } -void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, n); } -void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, n); } -void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); } -void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); } - -void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } -void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } - -#define MAKE_CFUNC32(name, gtype, gbits) \ - void c##name##32bit_g##gbits(gtype *g, gtype *p, \ - float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \ - const float beta1, const float beta2, const float eps, const float weight_decay, \ - const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \ - { name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ - -MAKE_CFUNC32(adam, float, 32) -MAKE_CFUNC32(adam, half, 16) -MAKE_CFUNC32(momentum, float, 32) -MAKE_CFUNC32(momentum, half, 16) -MAKE_CFUNC32(rmsprop, float, 32) -MAKE_CFUNC32(rmsprop, half, 16) -MAKE_CFUNC32(adagrad, float, 32) -MAKE_CFUNC32(adagrad, half, 16) - -#define MAKE_CFUNC8(name, gtype, gbits) \ - void c##name##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ + if #BUILD_CUDA + void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); } + void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); } + void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); } + void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); } + void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, n); } + void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, n); } + void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); } + void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); } + + void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } + + #define MAKE_CFUNC32(name, gtype, gbits) \ + void c##name##32bit_g##gbits(gtype *g, gtype *p, \ + float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \ + { name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ + + MAKE_CFUNC32(adam, float, 32) + MAKE_CFUNC32(adam, half, 16) + MAKE_CFUNC32(momentum, float, 32) + MAKE_CFUNC32(momentum, half, 16) + MAKE_CFUNC32(rmsprop, float, 32) + MAKE_CFUNC32(rmsprop, half, 16) + MAKE_CFUNC32(adagrad, float, 32) + MAKE_CFUNC32(adagrad, half, 16) + + #define MAKE_CFUNC8(name, gtype, gbits) \ + void c##name##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ float *unorm, float max_unorm, float param_norm, \ float beta1, float beta2, \ float eps, int step, float lr, \ @@ -124,40 +124,39 @@ MAKE_CFUNC32(adagrad, half, 16) float* max1, float* max2, float* new_max1, float* new_max2, \ float weight_decay, float gnorm_scale, int n) \ { \ - name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \ - quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \ + name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \ + quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \ } \ -MAKE_CFUNC8(adam, float, 32) -MAKE_CFUNC8(adam, half, 16) -MAKE_CFUNC8(momentum, float, 32) -MAKE_CFUNC8(momentum, half, 16) -MAKE_CFUNC8(rmsprop, float, 32) -MAKE_CFUNC8(rmsprop, half, 16) + MAKE_CFUNC8(adam, float, 32) + MAKE_CFUNC8(adam, half, 16) + MAKE_CFUNC8(momentum, float, 32) + MAKE_CFUNC8(momentum, half, 16) + MAKE_CFUNC8(rmsprop, float, 32) + MAKE_CFUNC8(rmsprop, half, 16) -#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \ + #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \ void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \ { fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \ -MAKE_CBLOCKWISE8(adam, ADAM, half, 16) -MAKE_CBLOCKWISE8(adam, ADAM, float, 32) -MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, 16) -MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32) -MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16) -MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32) -MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16) -MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32) + MAKE_CBLOCKWISE8(adam, ADAM, half, 16) + MAKE_CBLOCKWISE8(adam, ADAM, float, 32) + MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, 16) + MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32) + MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16) + MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32) + MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16) + MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32) -void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); } -void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); } -void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); } -#endif + void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); } + void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); } + void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); } -void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); } -void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); } + #endif + void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); } + void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); } } - -- cgit v1.2.3 From 575aa698fa53df2f5c584413aed7bf7714f86039 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Fri, 1 Jul 2022 17:41:48 +0300 Subject: Reduce diff --- csrc/ops.cu | 45 ++++++++++++++++++++------------------------- csrc/pythonInterface.c | 2 +- 2 files changed, 21 insertions(+), 26 deletions(-) (limited to 'csrc') diff --git a/csrc/ops.cu b/csrc/ops.cu index b2a1105..dbb50be 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -15,35 +15,30 @@ using namespace BinSearch; using std::cout; using std::endl; -void histogramScatterAdd2D(float *histogram, int *index1, int *index2, float *src, int maxidx1, int n) { - int threads = 512; - int blocks = n / threads; - blocks = n % threads == 0 ? blocks : blocks + 1; - kHistogramScatterAdd2D<<>>(histogram, index1, index2, src, maxidx1, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); -} - -template -void estimateQuantiles(T *A, float *code, float offset, int n) { - int blocks = n / 4096; - blocks = n % 4096 == 0 ? blocks : blocks + 1; - CUDA_CHECK_RETURN(cudaMemset(code, 0, 256 * sizeof(float))); - kEstimateQuantiles < T ><<>>(A, code, offset, std::numeric_limits::max(), n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); +void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n) +{ + int threads = 512; + int blocks = n/threads; + blocks = n % threads == 0 ? blocks : blocks + 1; + kHistogramScatterAdd2D<<>>(histogram, index1, index2, src, maxidx1, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -void quantize(float *code, float *A, unsigned char *out, int n) { - int blocks = n / 1024; - blocks = n % 1024 == 0 ? blocks : blocks + 1; - kQuantize<<>>(code, A, out, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); +template void estimateQuantiles(T *A, float *code, float offset, int n) +{ + int blocks = n/4096; + blocks = n % 4096 == 0 ? blocks : blocks + 1; + CUDA_CHECK_RETURN(cudaMemset(code, 0, 256*sizeof(float))); + kEstimateQuantiles<<>>(A, code, offset, std::numeric_limits::max(), n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -void dequantize(float *code, unsigned char *A, float *out, int n) { - int blocks = n / 1024; - blocks = n % 1024 == 0 ? blocks : blocks + 1; - kDequantize<<>>(code, A, out, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); +void quantize(float *code, float *A, unsigned char *out, int n) +{ + int blocks = n/1024; + blocks = n % 1024 == 0 ? blocks : blocks + 1; + kQuantize<<>>(code, A, out, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n) diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 1f690c5..c2fed6b 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -86,7 +86,7 @@ void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, floa extern "C" { - if #BUILD_CUDA + #if BUILD_CUDA void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); } void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); } void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); } -- cgit v1.2.3 From 025824d29b38f6b981bbcea8a61bc23e7f2b3e02 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Fri, 1 Jul 2022 17:42:58 +0300 Subject: Reduce diff --- csrc/ops.cu | 8 ++++++++ 1 file changed, 8 insertions(+) (limited to 'csrc') diff --git a/csrc/ops.cu b/csrc/ops.cu index dbb50be..40c185c 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -41,6 +41,14 @@ void quantize(float *code, float *A, unsigned char *out, int n) CUDA_CHECK_RETURN(cudaPeekAtLastError()); } +void dequantize(float *code, unsigned char *A, float *out, int n) +{ + int blocks = n/1024; + blocks = n % 1024 == 0 ? blocks : blocks + 1; + kDequantize<<>>(code, A, out, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n) { int blocks = n/4096; -- cgit v1.2.3