author | justheuristic <justheuristic@gmail.com> | 2022-09-17 18:42:22 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-17 18:42:22 +0300 |
commit | 3634fc738bc20e4041c75544d3f678f61ce2348c (patch) | |
tree | 36bc3394748ce4141fa9ab9d1104ca6441ade65c /bitsandbytes | |
parent | e2a75769f22bdc5465240c3f6701a1b002e8ab59 (diff) | |
parent | 9b5f2eda8fbd3f042c4af7ed1b870525d4668f2a (diff) | |
Merge branch 'TimDettmers:main' into memory-efficient-backward
Diffstat (limited to 'bitsandbytes')
-rw-r--r-- | bitsandbytes/functional.py | 90 |
1 file changed, 14 insertions, 76 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 22200f2..c104ebd 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -369,13 +369,7 @@ def estimate_quantiles(
     return out
 
 
-def quantize_blockwise(
-    A: Tensor,
-    code: Tensor = None,
-    absmax: Tensor = None,
-    rand=None,
-    out: Tensor = None,
-) -> Tensor:
+def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None, blocksize=4096) -> Tensor:
     """
     Quantize tensor A in blocks of size 4096 values.
@@ -412,9 +406,9 @@ def quantize_blockwise(
 
     if absmax is None:
         n = A.numel()
-        num_blocks = 4096
-        blocks = n // num_blocks
-        blocks += 1 if n % num_blocks > 0 else 0
+        blocksize = (blocksize if A.device.type == 'cpu' else 4096)
+        blocks = n // blocksize
+        blocks += 1 if n % blocksize > 0 else 0
         absmax = torch.zeros((blocks,), device=A.device)
 
     if out is None:
@@ -426,46 +420,18 @@ def quantize_blockwise(
         assert rand.numel() >= 1024
         rand_offset = random.randint(0, 1023)
         if A.dtype == torch.float32:
-            lib.cquantize_blockwise_stochastic_fp32(
-                get_ptr(code),
-                get_ptr(A),
-                get_ptr(absmax),
-                get_ptr(out),
-                get_ptr(rand),
-                ct.c_int32(rand_offset),
-                ct.c_int(A.numel()),
-            )
+            lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
         elif A.dtype == torch.float16:
-            lib.cquantize_blockwise_stochastic_fp16(
-                get_ptr(code),
-                get_ptr(A),
-                get_ptr(absmax),
-                get_ptr(out),
-                get_ptr(rand),
-                ct.c_int32(rand_offset),
-                ct.c_int(A.numel()),
-            )
+            lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
         else:
             raise ValueError(
                 f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
             )
     else:
         if A.dtype == torch.float32:
-            lib.cquantize_blockwise_fp32(
-                get_ptr(code),
-                get_ptr(A),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int(A.numel()),
-            )
+            lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel()))
         elif A.dtype == torch.float16:
-            lib.cquantize_blockwise_fp16(
-                get_ptr(code),
-                get_ptr(A),
-                get_ptr(absmax),
-                get_ptr(out),
-                ct.c_int(A.numel()),
-            )
+            lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel()))
         else:
             raise ValueError(
                 f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
             )
@@ -473,13 +439,7 @@ def quantize_blockwise(
     else:
         # cpu
        assert rand is None
-        lib.cquantize_blockwise_cpu_fp32(
-            get_ptr(code),
-            get_ptr(A),
-            get_ptr(absmax),
-            get_ptr(out),
-            ct.c_int(A.numel()),
-        )
+        lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
 
     return out, (absmax, code)
@@ -529,43 +489,21 @@ def dequantize_blockwise(
     if quant_state is None:
         quant_state = (absmax, code)
 
-    if blocksize not in [2048, 4096]:
-        raise ValueError(
-            f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]"
-        )
     if A.device.type != 'cpu':
+        if blocksize not in [2048, 4096]:
+            raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]")
         is_on_gpu([A, out])
         if out.dtype == torch.float32:
-            lib.cdequantize_blockwise_fp32(
-                get_ptr(quant_state[1]),
-                get_ptr(A),
-                get_ptr(quant_state[0]),
-                get_ptr(out),
-                ct.c_int(blocksize),
-                ct.c_int(A.numel()),
-            )
+            lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
         elif out.dtype == torch.float16:
-            lib.cdequantize_blockwise_fp16(
-                get_ptr(quant_state[1]),
-                get_ptr(A),
-                get_ptr(quant_state[0]),
-                get_ptr(out),
-                ct.c_int(blocksize),
-                ct.c_int(A.numel()),
-            )
+            lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
         else:
             raise ValueError(
                 f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
             )
     else:
-        lib.cdequantize_blockwise_cpu_fp32(
-            get_ptr(quant_state[1]),
-            get_ptr(A),
-            get_ptr(quant_state[0]),
-            get_ptr(out),
-            ct.c_int(A.numel()),
-        )
+        lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
 
     return out
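In practice, this merge means `quantize_blockwise` now accepts a `blocksize` argument (honored on the CPU path; the CUDA path still forces 4096) and returns `out, (absmax, code)`, which `dequantize_blockwise` reverses. Below is a minimal round-trip sketch, assuming the keyword names visible in this diff (`absmax`, `code`, `blocksize`); it is an illustration, not an official example from the repo:

```python
import torch
import bitsandbytes.functional as F

# CPU tensor: per this diff, only the CPU path honors a non-default blocksize;
# on CUDA devices quantize_blockwise still forces blocksize = 4096.
A = torch.randn(10_000)

# quantize_blockwise returns the 8-bit tensor plus its quantization state
out, (absmax, code) = F.quantize_blockwise(A, blocksize=2048)

# dequantize with the same state and blocksize to invert the mapping
A_deq = F.dequantize_blockwise(out, absmax=absmax, code=code, blocksize=2048)

print((A - A_deq).abs().max())  # small but nonzero: blockwise 8-bit is lossy
```

Note that the old top-level `blocksize not in [2048, 4096]` check in `dequantize_blockwise` now applies only to the GPU branch, so the CPU kernels can be called with other block sizes.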
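Replacing the hard-coded `num_blocks = 4096` with `blocksize` keeps the same ceiling-division logic for sizing `absmax` (one float scale per block). A standalone illustration of that arithmetic, equivalent to the two-line form in the diff:

```python
n, blocksize = 10_000, 4096  # e.g. A.numel() and the default block size

# as written in the diff: floor division, plus one block for any remainder
blocks = n // blocksize
blocks += 1 if n % blocksize > 0 else 0

# equivalent one-liners for the same ceiling division
assert blocks == -(-n // blocksize) == (n + blocksize - 1) // blocksize
print(blocks)  # 3 -> absmax is allocated as torch.zeros((3,))
```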
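The CPU entry points (`cquantize_blockwise_cpu_fp32`, `cdequantize_blockwise_cpu_fp32`) now receive `ct.c_longlong` for the block size and element count instead of `ct.c_int`. A short sketch of why the wider type matters; the overflow threshold here is plain ctypes/two's-complement behavior, not something stated in this diff:

```python
import ctypes as ct

# A signed 32-bit int caps out at 2**31 - 1 elements; large models can exceed it.
n = 3 * 2**30  # 3,221,225,472 elements

print(ct.c_int(n).value)       # wraps around: -1073741824
print(ct.c_longlong(n).value)  # preserved:    3221225472
```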