summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-09-11 11:55:09 -0700
committerTim Dettmers <tim.dettmers@gmail.com>2022-09-11 11:55:09 -0700
commit19a7adca7a6c9bf7061a384d7e9d9b13676a1a88 (patch)
treec6c29473641febdcf5598fb6ce7ced5452469117
parentf0ae860c86039d1c1e41166aaf2153a5bd9b9a89 (diff)
Fixed 2^31 max size issue for cpu blockwise quant.
-rw-r--r--bitsandbytes/functional.py90
-rw-r--r--csrc/common.cpp8
-rw-r--r--csrc/common.h10
-rw-r--r--csrc/cpu_ops.cpp89
-rw-r--r--csrc/cpu_ops.h7
-rw-r--r--csrc/pythonInterface.c4
-rw-r--r--tests/test_functional.py27
7 files changed, 107 insertions, 128 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 22200f2..c104ebd 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -369,13 +369,7 @@ def estimate_quantiles(
return out
-def quantize_blockwise(
- A: Tensor,
- code: Tensor = None,
- absmax: Tensor = None,
- rand=None,
- out: Tensor = None,
-) -> Tensor:
+def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None, blocksize=4096) -> Tensor:
"""
Quantize tensor A in blocks of size 4096 values.
@@ -412,9 +406,9 @@ def quantize_blockwise(
if absmax is None:
n = A.numel()
- num_blocks = 4096
- blocks = n // num_blocks
- blocks += 1 if n % num_blocks > 0 else 0
+ blocksize = (blocksize if A.device.type == 'cpu' else 4096)
+ blocks = n // blocksize
+ blocks += 1 if n % blocksize > 0 else 0
absmax = torch.zeros((blocks,), device=A.device)
if out is None:
@@ -426,46 +420,18 @@ def quantize_blockwise(
assert rand.numel() >= 1024
rand_offset = random.randint(0, 1023)
if A.dtype == torch.float32:
- lib.cquantize_blockwise_stochastic_fp32(
- get_ptr(code),
- get_ptr(A),
- get_ptr(absmax),
- get_ptr(out),
- get_ptr(rand),
- ct.c_int32(rand_offset),
- ct.c_int(A.numel()),
- )
+ lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
elif A.dtype == torch.float16:
- lib.cquantize_blockwise_stochastic_fp16(
- get_ptr(code),
- get_ptr(A),
- get_ptr(absmax),
- get_ptr(out),
- get_ptr(rand),
- ct.c_int32(rand_offset),
- ct.c_int(A.numel()),
- )
+ lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
else:
raise ValueError(
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
)
else:
if A.dtype == torch.float32:
- lib.cquantize_blockwise_fp32(
- get_ptr(code),
- get_ptr(A),
- get_ptr(absmax),
- get_ptr(out),
- ct.c_int(A.numel()),
- )
+ lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
elif A.dtype == torch.float16:
- lib.cquantize_blockwise_fp16(
- get_ptr(code),
- get_ptr(A),
- get_ptr(absmax),
- get_ptr(out),
- ct.c_int(A.numel()),
- )
+ lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
else:
raise ValueError(
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
@@ -473,13 +439,7 @@ def quantize_blockwise(
else:
# cpu
assert rand is None
- lib.cquantize_blockwise_cpu_fp32(
- get_ptr(code),
- get_ptr(A),
- get_ptr(absmax),
- get_ptr(out),
- ct.c_int(A.numel()),
- )
+ lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
return out, (absmax, code)
@@ -529,43 +489,21 @@ def dequantize_blockwise(
if quant_state is None:
quant_state = (absmax, code)
- if blocksize not in [2048, 4096]:
- raise ValueError(
- f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]"
- )
if A.device.type != 'cpu':
+ if blocksize not in [2048, 4096]:
+ raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]")
is_on_gpu([A, out])
if out.dtype == torch.float32:
- lib.cdequantize_blockwise_fp32(
- get_ptr(quant_state[1]),
- get_ptr(A),
- get_ptr(quant_state[0]),
- get_ptr(out),
- ct.c_int(blocksize),
- ct.c_int(A.numel()),
- )
+ lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
elif out.dtype == torch.float16:
- lib.cdequantize_blockwise_fp16(
- get_ptr(quant_state[1]),
- get_ptr(A),
- get_ptr(quant_state[0]),
- get_ptr(out),
- ct.c_int(blocksize),
- ct.c_int(A.numel()),
- )
+ lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
else:
raise ValueError(
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
)
else:
- lib.cdequantize_blockwise_cpu_fp32(
- get_ptr(quant_state[1]),
- get_ptr(A),
- get_ptr(quant_state[0]),
- get_ptr(out),
- ct.c_int(A.numel()),
- )
+ lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
return out
diff --git a/csrc/common.cpp b/csrc/common.cpp
index 972602b..52f0299 100644
--- a/csrc/common.cpp
+++ b/csrc/common.cpp
@@ -12,16 +12,16 @@ void *quantize_block(void *arguments) {
// 1. find absmax in block
float absmax_block = -FLT_MAX;
- for (int i = args->block_idx; i < args->block_end; i++)
+ for (long long i = args->block_idx; i < args->block_end; i++)
absmax_block = fmax(absmax_block, fabs(args->A[i]));
- args->absmax[args->block_idx / BLOCK_SIZE] = absmax_block;
+ args->absmax[args->block_idx / args->blocksize] = absmax_block;
- for (int i = args->block_idx; i < args->block_end; i++) {
+ for (long long i = args->block_idx; i < args->block_end; i++) {
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
// 3. do binary search to find the closest value
float normed_value = args->A[i] / absmax_block;
- int idx = args->bin_searcher->scalar(normed_value);
+ long long idx = args->bin_searcher->scalar(normed_value);
// 4. check minimal distance
// The binary search returns always the value to the left, which might not be the closest value
diff --git a/csrc/common.h b/csrc/common.h
index 2f25a58..c99034e 100644
--- a/csrc/common.h
+++ b/csrc/common.h
@@ -5,18 +5,20 @@
using namespace BinSearch;
+#define BLOCK_SIZE 16384
+
struct quantize_block_args {
BinAlgo<Scalar, float, Direct2> *bin_searcher;
float *code;
float *A;
float *absmax;
unsigned char *out;
- int block_end;
- int block_idx;
- int threadidx;
+ long long block_end;
+ long long block_idx;
+ long long threadidx;
+ long long blocksize;
};
-#define BLOCK_SIZE 4096
void *quantize_block(void *arguments);
diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp
index 89de52d..303e8ed 100644
--- a/csrc/cpu_ops.cpp
+++ b/csrc/cpu_ops.cpp
@@ -4,54 +4,69 @@
using namespace BinSearch;
-void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n) {
- for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) {
- int valid_items = n - block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
- int block_end = block_idx + valid_items;
- for (int i = block_idx; i < block_end; i++)
- out[i] = code[A[i]] * absmax[block_idx / BLOCK_SIZE];
+void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n) {
+ for (long long block_idx = 0; block_idx < n; block_idx += blocksize) {
+ long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
+ long long block_end = block_idx + valid_items;
+ for (long long i = block_idx; i < block_end; i++)
+ out[i] = code[A[i]] * absmax[block_idx / blocksize];
}
}
-void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n) {
+void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n)
+{
// the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below
code[0] = -1.0f;
- int num_blocks = n / BLOCK_SIZE;
- num_blocks += n % BLOCK_SIZE == 0 ? 0 : 1;
-
- pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * num_blocks);
- struct quantize_block_args **args = (quantize_block_args **) malloc(num_blocks * sizeof(quantize_block_args *));
-
- for (int i = 0; i < num_blocks; i++)
- args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args));
+ long long num_blocks = n / blocksize;
+ num_blocks += n % blocksize == 0 ? 0 : 1;
const uint32 elements_code = 256;
BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code);
- for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) {
- int valid_items = n - block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
- int block_end = block_idx + valid_items;
-
- struct quantize_block_args *arg = args[block_idx / BLOCK_SIZE];
- arg->bin_searcher = &bin_searcher;
- arg->code = code;
- arg->A = A;
- arg->absmax = absmax;
- arg->out = out;
- arg->block_end = block_end;
- arg->block_idx = block_idx;
- arg->threadidx = block_idx / BLOCK_SIZE;
-
- pthread_create(&threads[block_idx / BLOCK_SIZE], NULL, &quantize_block, (void *) arg);
- }
+ int thread_wave_size = 256;
+ // we chunk the thresds into waves of 256 since the max limit is
+ // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size)
+ for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size)
+ {
+ pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * thread_wave_size);
+
+ struct quantize_block_args **args = (quantize_block_args **) malloc(thread_wave_size * sizeof(quantize_block_args *));
+
+ for(long long i = 0; i < thread_wave_size; i++)
+ args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args));
- for (int i = 0; i < num_blocks; i++)
- int err = pthread_join(threads[i], NULL);
+ int chunks_processed = 0;
+ for(long long block_idx = offset*blocksize; block_idx < n; block_idx += blocksize)
+ {
+ long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx;
+ long long block_end = block_idx + valid_items;
+
+ struct quantize_block_args *arg = args[chunks_processed];
+ arg->bin_searcher = &bin_searcher;
+ arg->code = code;
+ arg->A = A;
+ arg->absmax = absmax;
+ arg->out = out;
+ arg->block_end = block_end;
+ arg->block_idx = block_idx;
+ arg->threadidx = block_idx / blocksize;
+ arg->blocksize = blocksize;
+
+ pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg);
+ chunks_processed += 1;
+ if(chunks_processed == thread_wave_size){ break; }
+ }
+
+ for (int i = 0; i < thread_wave_size; i++)
+ int err = pthread_join(threads[i], NULL);
+
+ free(threads);
+ for (int i = 0; i < thread_wave_size; i++)
+ free(args[i]);
+ free(args);
+
+ }
- free(threads);
- for (int i = 0; i < num_blocks; i++)
- free(args[i]);
- free(args);
}
diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h
index 57145a9..2ddf81e 100644
--- a/csrc/cpu_ops.h
+++ b/csrc/cpu_ops.h
@@ -1,9 +1,10 @@
#ifndef BITSANDBYTES_CPU_OPS_H
#define BITSANDBYTES_CPU_OPS_H
+#include <iostream>
+#include <stdio.h>
-void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n);
-
-void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n);
+void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n);
+void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n);
#endif
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index 0707674..58e26a9 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -287,7 +287,7 @@ extern "C"
void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
#endif
- void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); }
- void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); }
+ void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
+ void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
}
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 14cc21e..d07affe 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1815,14 +1815,14 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
batch_size = 1
seqdim = 1
values = []
-#values.append((batch_size, seqdim, 768, 4 * 768))
+values.append((batch_size, seqdim, 768, 4 * 768))
# values.append((batch_size, seqdim, 1024, 4*1024))
# values.append((batch_size, seqdim, 1536, 4*1536))
# values.append((batch_size, seqdim, 2048, 4*2048))
# values.append((batch_size, seqdim, 2560, 4*2560))
# values.append((batch_size, seqdim, 4096, 4*4096))
# values.append((batch_size, seqdim, 5140, 4*5140))
-values.append((batch_size, seqdim, 12288, 4*12288))
+#values.append((batch_size, seqdim, 12288, 4*12288))
names = [
"batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
]
@@ -2125,3 +2125,26 @@ def test_extract_outliers():
assert outliers2.shape[1] == idx.numel()
torch.testing.assert_allclose(outliers1, outliers2)
+
+
+
+def test_blockwise_cpu_large():
+ diffs = []
+ reldiffs = []
+ batch = 128
+ seq = 128
+ hidden = 14336
+ for blocksize in [4096, 16384]:
+ for i in range(2):
+ A1 = torch.randn(batch, seq, hidden, device='cpu')
+ t0 = time.time()
+ C, S = F.quantize_blockwise(A1, blocksize=blocksize)
+ A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
+ print(time.time() - t0)
+ diff = torch.abs(A1 - A2)
+ reldiff = diff / torch.abs(A1 + 1e-8)
+ diffs.append(diff.mean().item())
+ reldiffs.append(reldiff.mean().item())
+ assert diffs[-1] < 0.011
+ # print(sum(diffs)/len(diffs))
+ # print(sum(reldiffs)/len(reldiffs))