Diffstat (limited to 'bitsandbytes/functional.py')
-rw-r--r--  bitsandbytes/functional.py  20
1 file changed, 10 insertions(+), 10 deletions(-)
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 076414d..49d4db1 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -130,10 +130,10 @@ class Cusparse_Context(object):
return cls._instance
-def create_linear_map(signed=True, bits=8):
+def create_linear_map(signed=True, total_bits=8):
sign = (-1.0 if signed else 0.0)
- values = torch.linspace(sign, 1.0, 2**bits)
+ values = torch.linspace(sign, 1.0, 2**total_bits)
gap = 256 - values.numel()
if gap == 0:
return values
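For context, the renamed total_bits parameter sets how many code values the linear map holds (2**total_bits). A minimal usage sketch, assuming create_linear_map is imported from bitsandbytes.functional and behaves as the hunk above shows for total_bits=8 (where gap == 0, so the padding path is never taken):

    import torch
    from bitsandbytes.functional import create_linear_map

    # Signed 8-bit map: 256 evenly spaced code values spanning [-1.0, 1.0].
    code = create_linear_map(signed=True, total_bits=8)
    assert code.numel() == 256
    assert torch.isclose(code[0], torch.tensor(-1.0))
    assert torch.isclose(code[-1], torch.tensor(1.0))

    # With signed=False the map spans [0.0, 1.0] instead.
    code_unsigned = create_linear_map(signed=False, total_bits=8)
    assert torch.isclose(code_unsigned[0], torch.tensor(0.0))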
@@ -457,6 +457,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
The quantization state to undo the quantization.
"""
+
if code is None:
if "dynamic" not in name2qmap:
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
@@ -474,8 +475,11 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
out = torch.zeros_like(A, dtype=torch.uint8)
if A.device.type != 'cpu':
+ assert blocksize in [4096, 2048, 1024, 512]
is_on_gpu([code, A, absmax, out, rand])
+ cblocksize = ct.c_int32(blocksize)
if rand is not None:
+ assert blocksize==4096
assert rand.numel() >= 1024
rand_offset = random.randint(0, 1023)
if A.dtype == torch.float32:
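The asserts added in this hunk pin down which block sizes the GPU kernels accept and keep stochastic rounding (rand is not None) on the original 4096 block size. A hypothetical pre-flight check mirroring those constraints (the helper name and the per-block absmax count are illustrative assumptions, not part of the library API):

    def check_blockwise_args(n_elements, blocksize=4096, stochastic=False):
        # Mirrors the asserts this commit adds on the non-CPU path.
        if blocksize not in (4096, 2048, 1024, 512):
            raise ValueError(f"unsupported blocksize: {blocksize}")
        if stochastic and blocksize != 4096:
            raise ValueError("stochastic rounding only supports blocksize=4096")
        # Assumption: absmax holds one scale per block (not shown in this diff).
        return (n_elements + blocksize - 1) // blocksize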
@@ -483,18 +487,14 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
elif A.dtype == torch.float16:
lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
else:
- raise ValueError(
- f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
- )
+ raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
else:
if A.dtype == torch.float32:
- lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
+ lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
elif A.dtype == torch.float16:
- lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
+ lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
else:
- raise ValueError(
- f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
- )
+ raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
else:
# cpu
assert rand is None
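Taken together, the change threads the block size through to the fp16/fp32 CUDA kernels as a ct.c_int32. A hedged round-trip sketch, assuming blocksize is now a keyword argument of quantize_blockwise, that dequantize_blockwise accepts a matching blocksize (neither signature is fully visible in this diff), and that a CUDA device is available:

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

    # Quantize with one of the block sizes the new GPU-path assert allows.
    out, state = F.quantize_blockwise(A, blocksize=2048)

    # Undo the quantization and inspect the worst-case error.
    A_restored = F.dequantize_blockwise(out, state, blocksize=2048)
    print((A - A_restored).abs().max())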