diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-09-13 10:37:53 -0700 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-09-13 10:37:53 -0700 |
commit | c05dd42ddd123a491b4e9840ee0c7a69cf5aa3d8 (patch) | |
tree | bdcac851548262e486fcb877ab992133c7be3dbd /csrc | |
parent | d8dbf3a9b587d9b559207feed93578810c9c2aaf (diff) |
Fixed cpu blockwise quantization for small input tensors.
Diffstat (limited to 'csrc')
-rw-r--r-- | csrc/cpu_ops.cpp | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp index 303e8ed..2081e68 100644 --- a/csrc/cpu_ops.cpp +++ b/csrc/cpu_ops.cpp @@ -30,11 +30,12 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size) for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size) { - pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * thread_wave_size); + long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset; + pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * valid_chunks); - struct quantize_block_args **args = (quantize_block_args **) malloc(thread_wave_size * sizeof(quantize_block_args *)); + struct quantize_block_args **args = (quantize_block_args **) malloc(valid_chunks * sizeof(quantize_block_args *)); - for(long long i = 0; i < thread_wave_size; i++) + for(long long i = 0; i < valid_chunks; i++) args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args)); int chunks_processed = 0; @@ -56,14 +57,14 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg); chunks_processed += 1; - if(chunks_processed == thread_wave_size){ break; } + if(chunks_processed == valid_chunks){ break; } } - for (int i = 0; i < thread_wave_size; i++) + for (int i = 0; i < valid_chunks; i++) int err = pthread_join(threads[i], NULL); free(threads); - for (int i = 0; i < thread_wave_size; i++) + for (int i = 0; i < valid_chunks; i++) free(args[i]); free(args); |