summaryrefslogtreecommitdiff
path: root/csrc/cpu_ops.cpp
diff options
context:
space:
mode:
authorjustheuristic <justheuristic@gmail.com>2022-09-17 18:42:22 +0300
committerGitHub <noreply@github.com>2022-09-17 18:42:22 +0300
commit3634fc738bc20e4041c75544d3f678f61ce2348c (patch)
tree36bc3394748ce4141fa9ab9d1104ca6441ade65c /csrc/cpu_ops.cpp
parente2a75769f22bdc5465240c3f6701a1b002e8ab59 (diff)
parent9b5f2eda8fbd3f042c4af7ed1b870525d4668f2a (diff)
Merge branch 'TimDettmers:main' into memory-efficient-backward
Diffstat (limited to 'csrc/cpu_ops.cpp')
-rw-r--r--csrc/cpu_ops.cpp90
1 files changed, 53 insertions, 37 deletions
diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp
index 89de52d..2081e68 100644
--- a/csrc/cpu_ops.cpp
+++ b/csrc/cpu_ops.cpp
@@ -4,54 +4,70 @@
using namespace BinSearch;
-void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n) {
- for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) {
- int valid_items = n - block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
- int block_end = block_idx + valid_items;
- for (int i = block_idx; i < block_end; i++)
- out[i] = code[A[i]] * absmax[block_idx / BLOCK_SIZE];
+void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n) {
+ for (long long block_idx = 0; block_idx < n; block_idx += blocksize) {
+ long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
+ long long block_end = block_idx + valid_items;
+ for (long long i = block_idx; i < block_end; i++)
+ out[i] = code[A[i]] * absmax[block_idx / blocksize];
}
}
-void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n) {
+void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n)
+{
// the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below
code[0] = -1.0f;
- int num_blocks = n / BLOCK_SIZE;
- num_blocks += n % BLOCK_SIZE == 0 ? 0 : 1;
-
- pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * num_blocks);
- struct quantize_block_args **args = (quantize_block_args **) malloc(num_blocks * sizeof(quantize_block_args *));
-
- for (int i = 0; i < num_blocks; i++)
- args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args));
+ long long num_blocks = n / blocksize;
+ num_blocks += n % blocksize == 0 ? 0 : 1;
const uint32 elements_code = 256;
BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code);
- for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) {
- int valid_items = n - block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
- int block_end = block_idx + valid_items;
-
- struct quantize_block_args *arg = args[block_idx / BLOCK_SIZE];
- arg->bin_searcher = &bin_searcher;
- arg->code = code;
- arg->A = A;
- arg->absmax = absmax;
- arg->out = out;
- arg->block_end = block_end;
- arg->block_idx = block_idx;
- arg->threadidx = block_idx / BLOCK_SIZE;
-
- pthread_create(&threads[block_idx / BLOCK_SIZE], NULL, &quantize_block, (void *) arg);
- }
+ int thread_wave_size = 256;
+ // we chunk the thresds into waves of 256 since the max limit is
+ // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
+ for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size)
+ {
+ long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset;
+ pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * valid_chunks);
+
+ struct quantize_block_args **args = (quantize_block_args **) malloc(valid_chunks * sizeof(quantize_block_args *));
+
+ for(long long i = 0; i < valid_chunks; i++)
+ args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args));
- for (int i = 0; i < num_blocks; i++)
- int err = pthread_join(threads[i], NULL);
+ int chunks_processed = 0;
+ for(long long block_idx = offset*blocksize; block_idx < n; block_idx += blocksize)
+ {
+ long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
+ long long block_end = block_idx + valid_items;
+
+ struct quantize_block_args *arg = args[chunks_processed];
+ arg->bin_searcher = &bin_searcher;
+ arg->code = code;
+ arg->A = A;
+ arg->absmax = absmax;
+ arg->out = out;
+ arg->block_end = block_end;
+ arg->block_idx = block_idx;
+ arg->threadidx = block_idx / blocksize;
+ arg->blocksize = blocksize;
+
+ pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg);
+ chunks_processed += 1;
+ if(chunks_processed == valid_chunks){ break; }
+ }
+
+ for (int i = 0; i < valid_chunks; i++)
+ int err = pthread_join(threads[i], NULL);
+
+ free(threads);
+ for (int i = 0; i < valid_chunks; i++)
+ free(args[i]);
+ free(args);
+
+ }
- free(threads);
- for (int i = 0; i < num_blocks; i++)
- free(args[i]);
- free(args);
}