From 08fa2e7b01dda8959a930295de9829516f8c77bc Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Mon, 7 Nov 2022 18:06:18 -0800
Subject: Fix bug in CPU blockwise quantization; faster GPU dequantization.

Make device handling branch-local: pre_call/post_call and the move of
the quantization code table onto the target device now happen only on
the CUDA paths, so the CPU paths keep their buffers on the CPU.
quantize_blockwise no longer forces a 4096 blocksize on CPU, and
dequantize_blockwise unpacks quant_state once and hands the unpacked
code/absmax tensors to the C kernels. Also drop a stray debug print of
the loaded CUDA binary.
---
 bitsandbytes/cextension.py |  1 -
 bitsandbytes/functional.py | 24 +++++++++++++-----------
 2 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index ead8502..264e899 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -94,7 +94,6 @@ class CUDASetup(object):
             else:
                 self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...")
                 self.lib = ct.cdll.LoadLibrary(binary_path)
-                print(self.lib)
         except Exception as ex:
             self.add_log_entry(str(ex))
             self.print_log_stack()
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 6278db9..fffbecf 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -458,16 +458,13 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
     """
 
-    prev_device = pre_call(A.device)
     if code is None:
         if "dynamic" not in name2qmap:
             name2qmap["dynamic"] = create_dynamic_map().to(A.device)
         code = name2qmap["dynamic"]
-        code = code.to(A.device)
 
     if absmax is None:
         n = A.numel()
-        blocksize = (blocksize if A.device.type == 'cuda' else 4096)
         blocks = n // blocksize
         blocks += 1 if n % blocksize > 0 else 0
         absmax = torch.zeros((blocks,), device=A.device)
@@ -477,8 +474,9 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
 
     if A.device.type != 'cpu':
         assert blocksize in [4096, 2048, 1024, 512]
-        is_on_gpu([code, A, absmax, out, rand])
         cblocksize = ct.c_int32(blocksize)
+        prev_device = pre_call(A.device)
+        code = code.to(A.device)
         if rand is not None:
             is_on_gpu([code, A, out, absmax, rand])
             assert blocksize==4096
@@ -498,11 +496,12 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
                 lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
             else:
                 raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+        post_call(A.device)
     else:
         # cpu
+        code = code.cpu()
         assert rand is None
         lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
-    post_call(A.device)
 
     return out, (absmax, code)
@@ -541,32 +540,35 @@ def dequantize_blockwise(
         Dequantized tensor (default: float32)
     """
     assert quant_state is not None or absmax is not None
-    device = pre_call(A.device)
     if code is None and quant_state is None:
         if "dynamic" not in name2qmap:
             name2qmap["dynamic"] = create_dynamic_map().to(A.device)
         code = name2qmap["dynamic"]
-        code = code.to(A.device)
 
     if out is None:
         out = torch.zeros_like(A, dtype=torch.float32)
 
     if quant_state is None:
         quant_state = (absmax, code)
+    else:
+        absmax, code = quant_state
 
     if A.device.type != 'cpu':
+        device = pre_call(A.device)
+        code = code.to(A.device)
         if blocksize not in [2048, 4096, 1024, 512]:
             raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512]")
         is_on_gpu([A, out])
         if out.dtype == torch.float32:
-            lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+            lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
         elif out.dtype == torch.float16:
-            lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
+            lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
         else:
             raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+        post_call(A.device)
     else:
-        lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
-    post_call(A.device)
+        code = code.cpu()
+        lib.cdequantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
 
     return out
--
cgit v1.2.3
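
For context, here is a minimal round trip through the two functions this
patch touches. It is a sketch, not part of the commit: it assumes the
bitsandbytes.functional API as it appears in the diff (quantize_blockwise
returning uint8 codes plus an (absmax, code) state, dequantize_blockwise
accepting a quant_state keyword), that bitsandbytes is installed, and it
exercises the CPU path this commit fixes.

    # Round-trip sketch; shapes and values are illustrative only.
    # With device="cuda" (and a loadable CUDA binary) the GPU branch runs instead.
    import torch
    from bitsandbytes.functional import dequantize_blockwise, quantize_blockwise

    A = torch.randn(16384, device="cpu")       # CPU float32 input, 4 blocks of 4096
    q, (absmax, code) = quantize_blockwise(A)  # uint8 codes + per-block absmax
    B = dequantize_blockwise(q, quant_state=(absmax, code))
    print((A - B).abs().max().item())          # small blockwise quantization error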