summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bitsandbytes/functional.py6
1 files changed, 6 insertions, 0 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index aef6971..6278db9 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -458,6 +458,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
"""
+ prev_device = pre_call(A.device)
if code is None:
if "dynamic" not in name2qmap:
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
@@ -479,6 +480,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
is_on_gpu([code, A, absmax, out, rand])
cblocksize = ct.c_int32(blocksize)
if rand is not None:
+ is_on_gpu([code, A, out, absmax, rand])
assert blocksize==4096
assert rand.numel() >= 1024
rand_offset = random.randint(0, 1023)
@@ -489,6 +491,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
else:
+ is_on_gpu([code, A, out, absmax])
if A.dtype == torch.float32:
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
elif A.dtype == torch.float16:
@@ -499,6 +502,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
# cpu
assert rand is None
lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
+ post_call(A.device)
return out, (absmax, code)
@@ -537,6 +541,7 @@ def dequantize_blockwise(
Dequantized tensor (default: float32)
"""
assert quant_state is not None or absmax is not None
+ device = pre_call(A.device)
if code is None and quant_state is None:
if "dynamic" not in name2qmap:
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
@@ -561,6 +566,7 @@ def dequantize_blockwise(
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
else:
lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
+ post_call(A.device)
return out