summaryrefslogtreecommitdiff
path: root/bitsandbytes
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-11-06 16:36:31 -0800
committerTim Dettmers <tim.dettmers@gmail.com>2022-11-06 16:36:31 -0800
commite0e697b150ba830d19a2f5fbeaf22f1349eddbe3 (patch)
tree493ff4d9969af01b2034ef98d94d2e2805049b81 /bitsandbytes
parent6bc2b992be0bb7511ea881f8ebbbd2ba7f1b5109 (diff)
Fixed blockwise test and logic.
Diffstat (limited to 'bitsandbytes')
-rw-r--r--bitsandbytes/functional.py10
1 files changed, 4 insertions, 6 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 49d4db1..aef6971 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -466,7 +466,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
if absmax is None:
n = A.numel()
- blocksize = (blocksize if A.device.type == 'cpu' else 4096)
+ blocksize = (blocksize if A.device.type == 'cuda' else 4096)
blocks = n // blocksize
blocks += 1 if n % blocksize > 0 else 0
absmax = torch.zeros((blocks,), device=A.device)
@@ -550,17 +550,15 @@ def dequantize_blockwise(
if A.device.type != 'cpu':
- if blocksize not in [2048, 4096]:
- raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]")
+ if blocksize not in [2048, 4096, 1024, 512]:
+ raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512]")
is_on_gpu([A, out])
if out.dtype == torch.float32:
lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
elif out.dtype == torch.float16:
lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
else:
- raise ValueError(
- f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
- )
+ raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
else:
lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))