diff options
author | justheuristic <justheuristic@gmail.com> | 2022-09-17 18:42:22 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-17 18:42:22 +0300 |
commit | 3634fc738bc20e4041c75544d3f678f61ce2348c (patch) | |
tree | 36bc3394748ce4141fa9ab9d1104ca6441ade65c /csrc/cpu_ops.cpp | |
parent | e2a75769f22bdc5465240c3f6701a1b002e8ab59 (diff) | |
parent | 9b5f2eda8fbd3f042c4af7ed1b870525d4668f2a (diff) |
Merge branch 'TimDettmers:main' into memory-efficient-backward
Diffstat (limited to 'csrc/cpu_ops.cpp')
-rw-r--r-- | csrc/cpu_ops.cpp | 90 |
1 files changed, 53 insertions, 37 deletions
diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 89de52d..2081e68 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -4,54 +4,70 @@ using namespace BinSearch; -void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n) { - for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) { - int valid_items = n - block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx; - int block_end = block_idx + valid_items; - for (int i = block_idx; i < block_end; i++) - out[i] = code[A[i]] * absmax[block_idx / BLOCK_SIZE]; +void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n) { + for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { + long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; + long long block_end = block_idx + valid_items; + for (long long i = block_idx; i < block_end; i++) + out[i] = code[A[i]] * absmax[block_idx / blocksize]; } } -void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n) { +void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n) +{ // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below code[0] = -1.0f; - int num_blocks = n / BLOCK_SIZE; - num_blocks += n % BLOCK_SIZE == 0 ? 0 : 1; - - pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * num_blocks); - struct quantize_block_args **args = (quantize_block_args **) malloc(num_blocks * sizeof(quantize_block_args *)); - - for (int i = 0; i < num_blocks; i++) - args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args)); + long long num_blocks = n / blocksize; + num_blocks += n % blocksize == 0 ? 0 : 1; const uint32 elements_code = 256; BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code); - for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) { - int valid_items = n - block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx; - int block_end = block_idx + valid_items; - - struct quantize_block_args *arg = args[block_idx / BLOCK_SIZE]; - arg->bin_searcher = &bin_searcher; - arg->code = code; - arg->A = A; - arg->absmax = absmax; - arg->out = out; - arg->block_end = block_end; - arg->block_idx = block_idx; - arg->threadidx = block_idx / BLOCK_SIZE; - - pthread_create(&threads[block_idx / BLOCK_SIZE], NULL, &quantize_block, (void *) arg); - } + int thread_wave_size = 256; + // we chunk the thresds into waves of 256 since the max limit is + // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size) + for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size) + { + long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset; + pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * valid_chunks); + + struct quantize_block_args **args = (quantize_block_args **) malloc(valid_chunks * sizeof(quantize_block_args *)); + + for(long long i = 0; i < valid_chunks; i++) + args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args)); - for (int i = 0; i < num_blocks; i++) - int err = pthread_join(threads[i], NULL); + int chunks_processed = 0; + for(long long block_idx = offset*blocksize; block_idx < n; block_idx += blocksize) + { + long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; + long long block_end = block_idx + valid_items; + + struct quantize_block_args *arg = args[chunks_processed]; + arg->bin_searcher = &bin_searcher; + arg->code = code; + arg->A = A; + arg->absmax = absmax; + arg->out = out; + arg->block_end = block_end; + arg->block_idx = block_idx; + arg->threadidx = block_idx / blocksize; + arg->blocksize = blocksize; + + pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg); + chunks_processed += 1; + if(chunks_processed == valid_chunks){ break; } + } + + for (int i = 0; i < valid_chunks; i++) + int err = pthread_join(threads[i], NULL); + + free(threads); + for (int i = 0; i < valid_chunks; i++) + free(args[i]); + free(args); + + } - free(threads); - for (int i = 0; i < num_blocks; i++) - free(args[i]); - free(args); } |