summary | refs | log | tree | commit | diff
path: root/csrc/cpu_ops.cpp
diff options
context:
space:
mode:
author: Tim Dettmers <tim.dettmers@gmail.com> 2022-09-13 10:37:53 -0700
committer: Tim Dettmers <tim.dettmers@gmail.com> 2022-09-13 10:37:53 -0700
commit: c05dd42ddd123a491b4e9840ee0c7a69cf5aa3d8 (patch)
tree: bdcac851548262e486fcb877ab992133c7be3dbd /csrc/cpu_ops.cpp
parent: d8dbf3a9b587d9b559207feed93578810c9c2aaf (diff)
Fixed cpu blockwise quantization for small input tensors.
Diffstat (limited to 'csrc/cpu_ops.cpp')
-rw-r--r-- csrc/cpu_ops.cpp 13
1 file changed, 7 insertions, 6 deletions
diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp
index 303e8ed..2081e68 100644
--- a/csrc/cpu_ops.cpp
+++ b/csrc/cpu_ops.cpp
@@ -30,11 +30,12 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
// between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size)
{
- pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * thread_wave_size);
+ long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
+ pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * valid_chunks);
- struct quantize_block_args **args = (quantize_block_args **) malloc(thread_wave_size * sizeof(quantize_block_args *));
+ struct quantize_block_args **args = (quantize_block_args **) malloc(valid_chunks * sizeof(quantize_block_args *));
- for(long long i = 0; i < thread_wave_size; i++)
+ for(long long i = 0; i < valid_chunks; i++)
args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args));
int chunks_processed = 0;
@@ -56,14 +57,14 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg);
chunks_processed += 1;
- if(chunks_processed == thread_wave_size){ break; }
+ if(chunks_processed == valid_chunks){ break; }
}
- for (int i = 0; i < thread_wave_size; i++)
+ for (int i = 0; i < valid_chunks; i++)
int err = pthread_join(threads[i], NULL);
free(threads);
- for (int i = 0; i < thread_wave_size; i++)
+ for (int i = 0; i < valid_chunks; i++)
free(args[i]);
free(args);