summaryrefslogtreecommitdiff
path: root/bitsandbytes/functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes/functional.py')
-rw-r--r--bitsandbytes/functional.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index d9249b1..662e806 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -503,7 +503,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
out = torch.zeros_like(A, dtype=torch.uint8)
if A.device.type != 'cpu':
- assert blocksize in [4096, 2048, 1024, 512]
+ assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
cblocksize = ct.c_int32(blocksize)
prev_device = pre_call(A.device)
code = code.to(A.device)
@@ -586,8 +586,8 @@ def dequantize_blockwise(
if A.device.type != 'cpu':
device = pre_call(A.device)
code = code.to(A.device)
- if blocksize not in [2048, 4096, 1024, 512]:
- raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512]")
+ if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
+ raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
is_on_gpu([A, out])
if out.dtype == torch.float32:
lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))