diff options
Diffstat (limited to 'csrc')
-rw-r--r-- | csrc/common.cpp | 8 | ||||
-rw-r--r-- | csrc/common.h | 10 | ||||
-rw-r--r-- | csrc/cpu_ops.cpp | 89 | ||||
-rw-r--r-- | csrc/cpu_ops.h | 7 | ||||
-rw-r--r-- | csrc/pythonInterface.c | 4 |
5 files changed, 68 insertions, 50 deletions
diff --git a/csrc/common.cpp b/csrc/common.cpp index 972602b..52f0299 100644 --- a/csrc/common.cpp +++ b/csrc/common.cpp @@ -12,16 +12,16 @@ void *quantize_block(void *arguments) { // 1. find absmax in block float absmax_block = -FLT_MAX; - for (int i = args->block_idx; i < args->block_end; i++) + for (long long i = args->block_idx; i < args->block_end; i++) absmax_block = fmax(absmax_block, fabs(args->A[i])); - args->absmax[args->block_idx / BLOCK_SIZE] = absmax_block; + args->absmax[args->block_idx / args->blocksize] = absmax_block; - for (int i = args->block_idx; i < args->block_end; i++) { + for (long long i = args->block_idx; i < args->block_end; i++) { // 2. divide input value by absmax to normalize into [-1.0, 1.0] // 3. do binary search to find the closest value float normed_value = args->A[i] / absmax_block; - int idx = args->bin_searcher->scalar(normed_value); + long long idx = args->bin_searcher->scalar(normed_value); // 4. check minimal distance // The binary search returns always the value to the left, which might not be the closest value diff --git a/csrc/common.h b/csrc/common.h index 2f25a58..c99034e 100644 --- a/csrc/common.h +++ b/csrc/common.h @@ -5,18 +5,20 @@ using namespace BinSearch; +#define BLOCK_SIZE 16384 + struct quantize_block_args { BinAlgo<Scalar, float, Direct2> *bin_searcher; float *code; float *A; float *absmax; unsigned char *out; - int block_end; - int block_idx; - int threadidx; + long long block_end; + long long block_idx; + long long threadidx; + long long blocksize; }; -#define BLOCK_SIZE 4096 void *quantize_block(void *arguments); diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 89de52d..303e8ed 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -4,54 +4,69 @@ using namespace BinSearch; -void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n) { - for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) { - int valid_items = n - block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx; - int block_end = block_idx + valid_items; - for (int i = block_idx; i < block_end; i++) - out[i] = code[A[i]] * absmax[block_idx / BLOCK_SIZE]; +void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n) { + for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { + long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; + long long block_end = block_idx + valid_items; + for (long long i = block_idx; i < block_end; i++) + out[i] = code[A[i]] * absmax[block_idx / blocksize]; } } -void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n) { +void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n) +{ // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below code[0] = -1.0f; - int num_blocks = n / BLOCK_SIZE; - num_blocks += n % BLOCK_SIZE == 0 ? 0 : 1; - - pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * num_blocks); - struct quantize_block_args **args = (quantize_block_args **) malloc(num_blocks * sizeof(quantize_block_args *)); - - for (int i = 0; i < num_blocks; i++) - args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args)); + long long num_blocks = n / blocksize; + num_blocks += n % blocksize == 0 ? 0 : 1; const uint32 elements_code = 256; BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code); - for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) { - int valid_items = n - block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx; - int block_end = block_idx + valid_items; - - struct quantize_block_args *arg = args[block_idx / BLOCK_SIZE]; - arg->bin_searcher = &bin_searcher; - arg->code = code; - arg->A = A; - arg->absmax = absmax; - arg->out = out; - arg->block_end = block_end; - arg->block_idx = block_idx; - arg->threadidx = block_idx / BLOCK_SIZE; - - pthread_create(&threads[block_idx / BLOCK_SIZE], NULL, &quantize_block, (void *) arg); - } + int thread_wave_size = 256; + // we chunk the thresds into waves of 256 since the max limit is + // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size) + for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size) + { + pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * thread_wave_size); + + struct quantize_block_args **args = (quantize_block_args **) malloc(thread_wave_size * sizeof(quantize_block_args *)); + + for(long long i = 0; i < thread_wave_size; i++) + args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args)); - for (int i = 0; i < num_blocks; i++) - int err = pthread_join(threads[i], NULL); + int chunks_processed = 0; + for(long long block_idx = offset*blocksize; block_idx < n; block_idx += blocksize) + { + long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; + long long block_end = block_idx + valid_items; + + struct quantize_block_args *arg = args[chunks_processed]; + arg->bin_searcher = &bin_searcher; + arg->code = code; + arg->A = A; + arg->absmax = absmax; + arg->out = out; + arg->block_end = block_end; + arg->block_idx = block_idx; + arg->threadidx = block_idx / blocksize; + arg->blocksize = blocksize; + + pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg); + chunks_processed += 1; + if(chunks_processed == thread_wave_size){ break; } + } + + for (int i = 0; i < thread_wave_size; i++) + int err = pthread_join(threads[i], NULL); + + free(threads); + for (int i = 0; i < thread_wave_size; i++) + free(args[i]); + free(args); + + } - free(threads); - for (int i = 0; i < num_blocks; i++) - free(args[i]); - free(args); } diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 57145a9..2ddf81e 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -1,9 +1,10 @@ #ifndef BITSANDBYTES_CPU_OPS_H #define BITSANDBYTES_CPU_OPS_H +#include <iostream> +#include <stdio.h> -void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n); - -void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n); +void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n); +void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n); #endif diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 0707674..58e26a9 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -287,7 +287,7 @@ extern "C" void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); } #endif - void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); } - void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); } + void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); } } |