diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-09-11 11:55:09 -0700 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-09-11 11:55:09 -0700 |
commit | 19a7adca7a6c9bf7061a384d7e9d9b13676a1a88 (patch) | |
tree | c6c29473641febdcf5598fb6ce7ced5452469117 /bitsandbytes | |
parent | f0ae860c86039d1c1e41166aaf2153a5bd9b9a89 (diff) |
Fixed 2^31 max size issue for cpu blockwise quant.
Diffstat (limited to 'bitsandbytes')
-rw-r--r-- | bitsandbytes/functional.py | 90 |
1 files changed, 14 insertions, 76 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 22200f2..c104ebd 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -369,13 +369,7 @@ def estimate_quantiles( return out -def quantize_blockwise( - A: Tensor, - code: Tensor = None, - absmax: Tensor = None, - rand=None, - out: Tensor = None, -) -> Tensor: +def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None, blocksize=4096) -> Tensor: """ Quantize tensor A in blocks of size 4096 values. @@ -412,9 +406,9 @@ def quantize_blockwise( if absmax is None: n = A.numel() - num_blocks = 4096 - blocks = n // num_blocks - blocks += 1 if n % num_blocks > 0 else 0 + blocksize = (blocksize if A.device.type == 'cpu' else 4096) + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 absmax = torch.zeros((blocks,), device=A.device) if out is None: @@ -426,46 +420,18 @@ def quantize_blockwise( assert rand.numel() >= 1024 rand_offset = random.randint(0, 1023) if A.dtype == torch.float32: - lib.cquantize_blockwise_stochastic_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - get_ptr(rand), - ct.c_int32(rand_offset), - ct.c_int(A.numel()), - ) + lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel())) elif A.dtype == torch.float16: - lib.cquantize_blockwise_stochastic_fp16( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - get_ptr(rand), - ct.c_int32(rand_offset), - ct.c_int(A.numel()), - ) + lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel())) else: raise ValueError( f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" ) else: if A.dtype == torch.float32: - lib.cquantize_blockwise_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(A.numel()), - ) + lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel())) elif A.dtype == torch.float16: - lib.cquantize_blockwise_fp16( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(A.numel()), - ) + lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel())) else: raise ValueError( f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" @@ -473,13 +439,7 @@ def quantize_blockwise( else: # cpu assert rand is None - lib.cquantize_blockwise_cpu_fp32( - get_ptr(code), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(A.numel()), - ) + lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) return out, (absmax, code) @@ -529,43 +489,21 @@ def dequantize_blockwise( if quant_state is None: quant_state = (absmax, code) - if blocksize not in [2048, 4096]: - raise ValueError( - f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]" - ) if A.device.type != 'cpu': + if blocksize not in [2048, 4096]: + raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]") is_on_gpu([A, out]) if out.dtype == torch.float32: - lib.cdequantize_blockwise_fp32( - get_ptr(quant_state[1]), - get_ptr(A), - get_ptr(quant_state[0]), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(A.numel()), - ) + lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) elif out.dtype == torch.float16: - lib.cdequantize_blockwise_fp16( - get_ptr(quant_state[1]), - get_ptr(A), - get_ptr(quant_state[0]), - get_ptr(out), - ct.c_int(blocksize), - ct.c_int(A.numel()), - ) + lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) else: raise ValueError( f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" ) else: - lib.cdequantize_blockwise_cpu_fp32( - get_ptr(quant_state[1]), - get_ptr(A), - get_ptr(quant_state[0]), - get_ptr(out), - ct.c_int(A.numel()), - ) + lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) return out |