diff options
Diffstat (limited to 'csrc')
-rw-r--r-- | csrc/common.cpp | 39 | ||||
-rw-r--r-- | csrc/common.h | 23 | ||||
-rw-r--r-- | csrc/cpu_ops.cpp | 57 | ||||
-rw-r--r-- | csrc/cpu_ops.h | 9 | ||||
-rw-r--r-- | csrc/ops.cu | 122 | ||||
-rw-r--r-- | csrc/ops.cuh | 10 | ||||
-rw-r--r-- | csrc/pythonInterface.c | 11 |
7 files changed, 142 insertions, 129 deletions
diff --git a/csrc/common.cpp b/csrc/common.cpp new file mode 100644 index 0000000..972602b --- /dev/null +++ b/csrc/common.cpp @@ -0,0 +1,39 @@ +#include <common.h> +#include <float.h> + +void *quantize_block(void *arguments) { + // 1. find absmax in block + // 2. divide input value by absmax to normalize into [-1.0, 1.0] + // 3. do binary search to find the closest value + // 4. check minimal distance + // 5. store index + + struct quantize_block_args *args = (quantize_block_args *) arguments; + + // 1. find absmax in block + float absmax_block = -FLT_MAX; + for (int i = args->block_idx; i < args->block_end; i++) + absmax_block = fmax(absmax_block, fabs(args->A[i])); + + args->absmax[args->block_idx / BLOCK_SIZE] = absmax_block; + + for (int i = args->block_idx; i < args->block_end; i++) { + // 2. divide input value by absmax to normalize into [-1.0, 1.0] + // 3. do binary search to find the closest value + float normed_value = args->A[i] / absmax_block; + int idx = args->bin_searcher->scalar(normed_value); + + // 4. check minimal distance + // The binary search returns always the value to the left, which might not be the closest value + if (idx < 255) { + float dist_left = fabs(normed_value - (args->code[idx])); + float dist_right = fabs(normed_value - (args->code[idx + 1])); + if (dist_right < dist_left) { idx += 1; } + } + + // 5. store index + args->out[i] = (unsigned char) idx; + } + + return NULL; +} diff --git a/csrc/common.h b/csrc/common.h new file mode 100644 index 0000000..2f25a58 --- /dev/null +++ b/csrc/common.h @@ -0,0 +1,23 @@ +#include <BinSearch.h> + +#ifndef common +#define common + +using namespace BinSearch; + +struct quantize_block_args { + BinAlgo<Scalar, float, Direct2> *bin_searcher; + float *code; + float *A; + float *absmax; + unsigned char *out; + int block_end; + int block_idx; + int threadidx; +}; + +#define BLOCK_SIZE 4096 + +void *quantize_block(void *arguments); + +#endif diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp new file mode 100644 index 0000000..89de52d --- /dev/null +++ b/csrc/cpu_ops.cpp @@ -0,0 +1,57 @@ +#include <BinSearch.h> +#include <pthread.h> +#include <common.h> + +using namespace BinSearch; + +void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n) { + for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) { + int valid_items = n - block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx; + int block_end = block_idx + valid_items; + for (int i = block_idx; i < block_end; i++) + out[i] = code[A[i]] * absmax[block_idx / BLOCK_SIZE]; + } +} + +void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n) { + + // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below + code[0] = -1.0f; + + int num_blocks = n / BLOCK_SIZE; + num_blocks += n % BLOCK_SIZE == 0 ? 0 : 1; + + pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * num_blocks); + struct quantize_block_args **args = (quantize_block_args **) malloc(num_blocks * sizeof(quantize_block_args *)); + + for (int i = 0; i < num_blocks; i++) + args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args)); + + const uint32 elements_code = 256; + BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code); + + for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) { + int valid_items = n - block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx; + int block_end = block_idx + valid_items; + + struct quantize_block_args *arg = args[block_idx / BLOCK_SIZE]; + arg->bin_searcher = &bin_searcher; + arg->code = code; + arg->A = A; + arg->absmax = absmax; + arg->out = out; + arg->block_end = block_end; + arg->block_idx = block_idx; + arg->threadidx = block_idx / BLOCK_SIZE; + + pthread_create(&threads[block_idx / BLOCK_SIZE], NULL, &quantize_block, (void *) arg); + } + + for (int i = 0; i < num_blocks; i++) + int err = pthread_join(threads[i], NULL); + + free(threads); + for (int i = 0; i < num_blocks; i++) + free(args[i]); + free(args); +} diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h new file mode 100644 index 0000000..57145a9 --- /dev/null +++ b/csrc/cpu_ops.h @@ -0,0 +1,9 @@ +#ifndef BITSANDBYTES_CPU_OPS_H +#define BITSANDBYTES_CPU_OPS_H + + +void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n); + +void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n); + +#endif diff --git a/csrc/ops.cu b/csrc/ops.cu index 9691241..40c185c 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -8,123 +8,13 @@ #include <cub/device/device_scan.cuh> #include <limits> #include <BinSearch.h> +#include <common.h> using namespace BinSearch; using std::cout; using std::endl; -#define BLOCK_SIZE 4096 - -struct quantize_block_args -{ - BinAlgo<Scalar, float, Direct2> *bin_searcher; - float *code; - float *A; - float *absmax; - unsigned char *out; - int block_end; - int block_idx; - int threadidx; -}; - -void *quantize_block(void *arguments) -{ - // 1. find absmax in block - // 2. divide input value by absmax to normalize into [-1.0, 1.0] - // 3. do binary search to find the closest value - // 4. check minimal distance - // 5. store index - - struct quantize_block_args *args = (quantize_block_args*)arguments; - - // 1. find absmax in block - float absmax_block = -FLT_MAX; - for (int i = args->block_idx; i < args->block_end; i++) - absmax_block = fmax(absmax_block, fabs(args->A[i])); - - args->absmax[args->block_idx/BLOCK_SIZE] = absmax_block; - - for (int i = args->block_idx; i < args->block_end; i++) - { - // 2. divide input value by absmax to normalize into [-1.0, 1.0] - // 3. do binary search to find the closest value - float normed_value = args->A[i]/absmax_block; - int idx = args->bin_searcher->scalar(normed_value); - - // 4. check minimal distance - // The binary search returns always the value to the left, which might not be the closest value - if(idx < 255) - { - float dist_left = fabs(normed_value-(args->code[idx])); - float dist_right = fabs(normed_value-(args->code[idx+1])); - if(dist_right < dist_left){ idx+=1; } - } - - // 5. store index - args->out[i] = (unsigned char)idx; - } - - return NULL; -} - -void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n) -{ - - // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below - code[0] = -1.0f; - - int num_blocks = n/BLOCK_SIZE; - num_blocks += n % BLOCK_SIZE == 0 ? 0 : 1; - - pthread_t *threads = (pthread_t*)malloc(sizeof(pthread_t)*num_blocks); - struct quantize_block_args **args = (quantize_block_args**)malloc(num_blocks*sizeof(quantize_block_args*)); - - for(int i = 0; i < num_blocks; i++) - args[i] = (quantize_block_args*)malloc(sizeof(quantize_block_args)); - - const uint32 elements_code = 256; - BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code); - - for(int block_idx = 0; block_idx < n; block_idx+=BLOCK_SIZE) - { - int valid_items = n-block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx; - int block_end = block_idx + valid_items; - - struct quantize_block_args *arg = args[block_idx/BLOCK_SIZE]; - arg->bin_searcher = &bin_searcher; - arg->code = code; - arg->A = A; - arg->absmax = absmax; - arg->out = out; - arg->block_end = block_end; - arg->block_idx = block_idx; - arg->threadidx = block_idx/BLOCK_SIZE; - - pthread_create(&threads[block_idx/BLOCK_SIZE], NULL, &quantize_block, (void *)arg); - } - - for(int i = 0; i < num_blocks; i++) - int err = pthread_join(threads[i], NULL); - - free(threads); - for(int i = 0; i < num_blocks; i++) - free(args[i]); - free(args); -} - - -void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n) -{ - for(int block_idx = 0; block_idx < n; block_idx+=BLOCK_SIZE) - { - int valid_items = n-block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx; - int block_end = block_idx + valid_items; - for (int i = block_idx; i < block_end; i++) - out[i] = code[A[i]]*absmax[block_idx/BLOCK_SIZE]; - } -} - void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n) { int threads = 512; @@ -178,7 +68,7 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, +template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, float* state1, float* state2, float *unorm, float max_unorm, float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) @@ -189,7 +79,7 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, { case ADAM: if(max_unorm > 0.0f) - { + { CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); @@ -202,7 +92,7 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, case ADAGRAD: if(max_unorm > 0.0f) - { + { CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); @@ -218,7 +108,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2, float *unorm, float max_unorm, float param_norm, float beta1, float beta2, - float eps, int step, float lr, + float eps, int step, float lr, float* quantiles1, float* quantiles2, float* max1, float* max2, float* new_max1, float* new_max2, float weight_decay, @@ -261,7 +151,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, #define NUM_1STATE 8 template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g, - unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) { diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 1bc13fb..8fb4cec 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -68,16 +68,6 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n); -void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n); -void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n); - void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n); #endif - - - - - - - diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index e0b0d59..c2fed6b 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -3,7 +3,10 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. +#if BUILD_CUDA #include <ops.cuh> +#endif +#include <cpu_ops.h> // We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary. // We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but its better than to @@ -12,6 +15,7 @@ // UNMANGLED CALLS //=================================================================================== +#if BUILD_CUDA void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles<float>(A, code, offset, n); } void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles<half>(A, code, offset, n); } @@ -78,9 +82,11 @@ void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, un void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \ void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); } +#endif extern "C" { + #if BUILD_CUDA void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); } void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); } void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); } @@ -147,11 +153,10 @@ extern "C" void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); } void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); } + void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); } + #endif void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); } void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); } - - void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); } } - |