From c059bd284832d09bc51cf82c377642b26a48ef28 Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Sun, 20 Nov 2022 14:18:15 -0800
Subject: Added additional blocksizes: {64, 128, 256}.

---
 bitsandbytes/functional.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

(limited to 'bitsandbytes')

diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index d9249b1..662e806 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -503,7 +503,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
         out = torch.zeros_like(A, dtype=torch.uint8)

     if A.device.type != 'cpu':
-        assert blocksize in [4096, 2048, 1024, 512]
+        assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
         cblocksize = ct.c_int32(blocksize)
         prev_device = pre_call(A.device)
         code = code.to(A.device)
@@ -586,8 +586,8 @@ def dequantize_blockwise(
     if A.device.type != 'cpu':
         device = pre_call(A.device)
         code = code.to(A.device)
-        if blocksize not in [2048, 4096, 1024, 512]:
-            raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512]")
+        if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
+            raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
         is_on_gpu([A, out])
         if out.dtype == torch.float32:
             lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
--
cgit v1.2.3