From 6bc2b992be0bb7511ea881f8ebbbd2ba7f1b5109 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 6 Nov 2022 16:27:48 -0800 Subject: Added blocksizes 2048, 1024, and 512 to blockwise quant. --- bitsandbytes/cextension.py | 11 +++++- bitsandbytes/functional.py | 20 +++++----- csrc/kernels.cu | 22 ++++++++--- csrc/ops.cu | 33 +++++++++++++---- csrc/ops.cuh | 2 +- csrc/pythonInterface.c | 12 +++--- tests/test_functional.py | 92 +++++++++++++++++++++++----------------------- 7 files changed, 113 insertions(+), 79 deletions(-) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 8125202..ead8502 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -52,8 +52,13 @@ class CUDASetup(object): self.add_log_entry('python setup.py install') def initialize(self): - self.cuda_setup_log = [] + self.has_printed = False self.lib = None + self.run_cuda_setup() + + def run_cuda_setup(self): + self.initialized = True + self.cuda_setup_log = [] from .cuda_setup.main import evaluate_cuda_setup binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup() @@ -89,7 +94,9 @@ class CUDASetup(object): else: self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...") self.lib = ct.cdll.LoadLibrary(binary_path) - except: + print(self.lib) + except Exception as ex: + self.add_log_entry(str(ex)) self.print_log_stack() def add_log_entry(self, msg, is_warning=False): diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 076414d..49d4db1 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -130,10 +130,10 @@ class Cusparse_Context(object): return cls._instance -def create_linear_map(signed=True, bits=8): +def create_linear_map(signed=True, total_bits=8): sign = (-1.0 if signed else 0.0) - values = torch.linspace(sign, 1.0, 2**bits) + values = torch.linspace(sign, 1.0, 2**total_bits) gap = 256 - values.numel() if gap == 0: return values @@ -457,6 +457,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra The quantization state to undo the quantization. """ + if code is None: if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) @@ -474,8 +475,11 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra out = torch.zeros_like(A, dtype=torch.uint8) if A.device.type != 'cpu': + assert blocksize in [4096, 2048, 1024, 512] is_on_gpu([code, A, absmax, out, rand]) + cblocksize = ct.c_int32(blocksize) if rand is not None: + assert blocksize==4096 assert rand.numel() >= 1024 rand_offset = random.randint(0, 1023) if A.dtype == torch.float32: @@ -483,18 +487,14 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra elif A.dtype == torch.float16: lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel())) else: - raise ValueError( - f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" - ) + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") else: if A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel())) + lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) elif A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel())) + lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) else: - raise ValueError( - f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" - ) + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") else: # cpu assert rand is None diff --git a/csrc/kernels.cu b/csrc/kernels.cu index f01b4e1..9d9653c 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -428,16 +428,16 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c } template -__launch_bounds__(TH, 4) +//__launch_bounds__(TH, 4) __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n) { const int n_full = gridDim.x * BLOCK_SIZE; int valid_items = 0; const int base_idx = (blockIdx.x * BLOCK_SIZE); - T vals[NUM]; - float rand_vals[NUM]; - unsigned char qvals[NUM]; + T vals[NUM_PER_TH]; + float rand_vals[NUM_PER_TH]; + unsigned char qvals[NUM_PER_TH]; //float local_abs_max = -FLT_MAX; float local_abs_max = 0.0f; int local_rand_idx = 0; @@ -517,8 +517,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c int valid_items = 0; const int base_idx = (blockIdx.x * BLOCK_SIZE); - T vals[NUM]; - unsigned char qvals[NUM]; + T vals[NUM_PER_TH]; + unsigned char qvals[NUM_PER_TH]; float local_abs_max = -FLT_MAX; typedef cub::BlockLoad LoadChar; @@ -2791,11 +2791,21 @@ template __global__ void kQuantizeBlockwise(float * code, half template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); diff --git a/csrc/ops.cu b/csrc/ops.cu index e49c94b..b121fc2 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -50,11 +50,23 @@ void dequantize(float *code, unsigned char *A, float *out, int n) CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n) +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) { - int num_blocks = n/4096; - num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + int num_blocks = n/blocksize; + num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; + if(STOCHASTIC == 1) + assert(blocksize == 4096); + + if(blocksize == 4096) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 2048) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 1024) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 512) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -66,6 +78,11 @@ template void dequantizeBlockwise(float *code, unsigned char *A, flo kDequantizeBlockwise<<>>(code, A, absmax, out, n); else if(blocksize == 2048) kDequantizeBlockwise<<>>(code, A, absmax, out, n); + else if(blocksize == 1024) + kDequantizeBlockwise<<>>(code, A, absmax, out, n); + else if(blocksize == 512) + kDequantizeBlockwise<<>>(code, A, absmax, out, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -659,10 +676,10 @@ template void transformRowToFormat(char * A, char *out, int rows, template void estimateQuantiles(half *A, float *code, float offset, int n); template void estimateQuantiles(float *A, float *code, float offset, int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index acfdb06..66e3843 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -128,7 +128,7 @@ template void estimateQuantiles(T *A, float *code, float offset, in void quantize(float *code, float *A, unsigned char *out, int n); void dequantize(float *code, unsigned char *A, float *out, int n); -template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); template void optimizer32bit(T* g, T* p, diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 58e26a9..5bac30e 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -75,10 +75,10 @@ MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32) void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } -void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, n); } -void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, n); } -void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, n); } -void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, n); } +void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, 4096, n); } +void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, 4096, n); } void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } \ void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } @@ -140,8 +140,8 @@ extern "C" void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); } void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); } void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); } - void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, n); } - void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, n); } + void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); } void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); } diff --git a/tests/test_functional.py b/tests/test_functional.py index 99885da..b525dff 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -151,30 +151,41 @@ def test_dynamic_quantization(): def test_dynamic_blockwise_quantization(): - diffs = [] - reldiffs = [] - for i in range(100): - A1 = torch.randn(1024, 1024, device="cuda") - C, S = F.quantize_blockwise(A1) - A2 = F.dequantize_blockwise(C, S) - diff = torch.abs(A1 - A2) - reldiff = diff / torch.abs(A1 + 1e-8) - diffs.append(diff.mean().item()) - reldiffs.append(reldiff.mean().item()) - assert diffs[-1] < 0.011 - # print(sum(diffs)/len(diffs)) - # print(sum(reldiffs)/len(reldiffs)) - - diffs = [] - for i in range(100): - A1 = torch.rand(1024, 1024, device="cuda") - C, S = F.quantize_blockwise(A1) - A2 = F.dequantize_blockwise(C, S) - diff = torch.abs(A1 - A2).mean().item() - assert diff < 0.0033 - diffs.append(diff) - torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) - # print(sum(diffs)/len(diffs)) + #print('') + for blocksize in [4096, 2048, 1024, 512]: + diffs = [] + reldiffs = [] + for i in range(100): + A1 = torch.randn(1024, 1024, device="cuda") + C, S = F.quantize_blockwise(A1) + A2 = F.dequantize_blockwise(C, S) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + abserr = sum(diffs)/len(diffs) + relerr = sum(reldiffs)/len(reldiffs) + assert abserr < 0.011 + assert relerr < 0.018 + #print('randn', blocksize, sum(diffs)/len(diffs)) + #print('randn', blocksize, sum(reldiffs)/len(reldiffs)) + + diffs = [] + for i in range(100): + A1 = torch.rand(1024, 1024, device="cuda") + C, S = F.quantize_blockwise(A1) + A2 = F.dequantize_blockwise(C, S) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) + abserr = sum(diffs)/len(diffs) + relerr = sum(reldiffs)/len(reldiffs) + assert abserr < 0.0035 + assert relerr < 0.015 + #print('rand', blocksize, sum(diffs)/len(diffs)) + #print('rand', blocksize, sum(reldiffs)/len(reldiffs)) def test_dynamic_blockwise_stochastic_quantization(): @@ -1618,17 +1629,6 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): # print(time.time() - t0) -def test_layout(): - a1 = torch.rand(16, 64, device="cuda", dtype=torch.float16) - a1 = torch.arange(16 * 64, device="cuda").reshape(16, 64).byte() - a2, s2 = F.transform(a1, "col_turing") - print(a2.shape) - - print(a1.flatten()[8 * 64 : 8 * 64 + 32]) - for i in range(4): - print(a2.flatten()[i * 8 * 32 : i * 8 * 32 + 32], 0) - - def test_coo2csr(): threshold = 1 A = torch.randn(128, 128).half().cuda() @@ -2062,8 +2062,8 @@ def test_fp8_quant(): abserr.append(diff.mean().item()) relerr.append(reldiff.mean().item()) #assert diff < 0.0075 - print(sum(abserr)/len(abserr)) - print(sum(relerr)/len(relerr)) + #print(sum(abserr)/len(abserr)) + #print(sum(relerr)/len(relerr)) abserr = [] relerr = [] @@ -2076,8 +2076,8 @@ def test_fp8_quant(): abserr.append(diff.mean().item()) relerr.append(reldiff.mean().item()) #assert diff < 0.0075 - print(sum(abserr)/len(abserr)) - print(sum(relerr)/len(relerr)) + #print(sum(abserr)/len(abserr)) + #print(sum(relerr)/len(relerr)) abserr = [] relerr = [] @@ -2090,21 +2090,21 @@ def test_fp8_quant(): abserr.append(diff.mean().item()) relerr.append(reldiff.mean().item()) #assert diff < 0.0075 - print(3, sum(abserr)/len(abserr)) - print(3, sum(relerr)/len(relerr)) + #print(3, sum(abserr)/len(abserr)) + #print(3, sum(relerr)/len(relerr)) def test_few_bit_quant(): - print('') + #print('') for bits in range(2, 9): - print('='*30, bits, '='*30) + #print('='*30, bits, '='*30) for method in ['linear', 'fp8', 'dynamic', 'quantile']: abserrs = [] relerrs = [] code = None if method == 'linear': - code = F.create_linear_map(True, bits=bits).cuda() + code = F.create_linear_map(True, total_bits=bits).cuda() elif method == 'fp8': ebits = math.ceil(bits/2) pbits = bits-ebits-1 @@ -2122,7 +2122,7 @@ def test_few_bit_quant(): q /= q.abs().max() code, idx = torch.sort(q) - print(method, (code==0).sum()) + #print(method, (code==0).sum()) assert code.numel() == 256 for i in range(10): @@ -2154,7 +2154,7 @@ def test_few_bit_quant(): else: torch.testing.assert_allclose(q1, q2) - print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs)) + #print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs)) def test_kbit_quantile_estimation(): -- cgit v1.2.3