author    justheuristic <justheuristic@gmail.com>    2022-09-17 18:42:22 +0300
committer GitHub <noreply@github.com>                2022-09-17 18:42:22 +0300
commit    3634fc738bc20e4041c75544d3f678f61ce2348c (patch)
tree      36bc3394748ce4141fa9ab9d1104ca6441ade65c /bitsandbytes
parent    e2a75769f22bdc5465240c3f6701a1b002e8ab59 (diff)
parent    9b5f2eda8fbd3f042c4af7ed1b870525d4668f2a (diff)
Merge branch 'TimDettmers:main' into memory-efficient-backward
Diffstat (limited to 'bitsandbytes')
-rw-r--r--  bitsandbytes/functional.py | 90
1 file changed, 14 insertions(+), 76 deletions(-)
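
In short, this merge gives quantize_blockwise an explicit blocksize keyword (honored on the CPU path; the CUDA kernels stay pinned to 4096), forwards blocksize to the CPU kernels as a 64-bit integer, and moves dequantize_blockwise's [2048, 4096] blocksize check inside the GPU-only branch. The per-block absmax buffer is sized by ceiling division, as in this standalone sketch (names are illustrative, not from the patch):

    import torch

    def num_blocks(n: int, blocksize: int = 4096) -> int:
        # one absmax scale per block; round up so a partial tail block still gets one
        return n // blocksize + (1 if n % blocksize > 0 else 0)

    A = torch.randn(10_000)
    absmax = torch.zeros((num_blocks(A.numel()),), device=A.device)  # 3 blocks: 4096 + 4096 + 1808
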
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 22200f2..c104ebd 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -369,13 +369,7 @@ def estimate_quantiles(
return out
-def quantize_blockwise(
- A: Tensor,
- code: Tensor = None,
- absmax: Tensor = None,
- rand=None,
- out: Tensor = None,
-) -> Tensor:
+def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None, blocksize=4096) -> Tensor:
"""
Quantize tensor A in blocks of size 4096 values.
@@ -412,9 +406,9 @@ def quantize_blockwise(
if absmax is None:
n = A.numel()
- num_blocks = 4096
- blocks = n // num_blocks
- blocks += 1 if n % num_blocks > 0 else 0
+ blocksize = (blocksize if A.device.type == 'cpu' else 4096)
+ blocks = n // blocksize
+ blocks += 1 if n % blocksize > 0 else 0
absmax = torch.zeros((blocks,), device=A.device)
if out is None:
@@ -426,46 +420,18 @@ def quantize_blockwise(
assert rand.numel() >= 1024
rand_offset = random.randint(0, 1023)
if A.dtype == torch.float32:
- lib.cquantize_blockwise_stochastic_fp32(
- get_ptr(code),
- get_ptr(A),
- get_ptr(absmax),
- get_ptr(out),
- get_ptr(rand),
- ct.c_int32(rand_offset),
- ct.c_int(A.numel()),
- )
+ lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
elif A.dtype == torch.float16:
- lib.cquantize_blockwise_stochastic_fp16(
- get_ptr(code),
- get_ptr(A),
- get_ptr(absmax),
- get_ptr(out),
- get_ptr(rand),
- ct.c_int32(rand_offset),
- ct.c_int(A.numel()),
- )
+ lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
else:
raise ValueError(
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
)
else:
if A.dtype == torch.float32:
- lib.cquantize_blockwise_fp32(
- get_ptr(code),
- get_ptr(A),
- get_ptr(absmax),
- get_ptr(out),
- ct.c_int(A.numel()),
- )
+ lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel()))
elif A.dtype == torch.float16:
- lib.cquantize_blockwise_fp16(
- get_ptr(code),
- get_ptr(A),
- get_ptr(absmax),
- get_ptr(out),
- ct.c_int(A.numel()),
- )
+ lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel()))
else:
raise ValueError(
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
@@ -473,13 +439,7 @@ def quantize_blockwise(
else:
# cpu
assert rand is None
- lib.cquantize_blockwise_cpu_fp32(
- get_ptr(code),
- get_ptr(A),
- get_ptr(absmax),
- get_ptr(out),
- ct.c_int(A.numel()),
- )
+ lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
return out, (absmax, code)
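
The stochastic-rounding branch above requires the optional rand tensor to hold at least 1024 values, from which the kernel reads at a random offset. A hedged calling sketch under that API (CUDA device assumed; not part of the patch):

    import torch
    from bitsandbytes import functional as F

    A = torch.randn(8192, device='cuda', dtype=torch.float16)
    rand = torch.rand(1024, device='cuda')       # >= 1024 values, as the assert demands
    out, (absmax, code) = F.quantize_blockwise(A, rand=rand)
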
@@ -529,43 +489,21 @@ def dequantize_blockwise(
if quant_state is None:
quant_state = (absmax, code)
- if blocksize not in [2048, 4096]:
- raise ValueError(
- f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]"
- )
if A.device.type != 'cpu':
+ if blocksize not in [2048, 4096]:
+ raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]")
is_on_gpu([A, out])
if out.dtype == torch.float32:
- lib.cdequantize_blockwise_fp32(
- get_ptr(quant_state[1]),
- get_ptr(A),
- get_ptr(quant_state[0]),
- get_ptr(out),
- ct.c_int(blocksize),
- ct.c_int(A.numel()),
- )
+ lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
elif out.dtype == torch.float16:
- lib.cdequantize_blockwise_fp16(
- get_ptr(quant_state[1]),
- get_ptr(A),
- get_ptr(quant_state[0]),
- get_ptr(out),
- ct.c_int(blocksize),
- ct.c_int(A.numel()),
- )
+ lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
else:
raise ValueError(
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
)
else:
- lib.cdequantize_blockwise_cpu_fp32(
- get_ptr(quant_state[1]),
- get_ptr(A),
- get_ptr(quant_state[0]),
- get_ptr(out),
- ct.c_int(A.numel()),
- )
+ lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
return out
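
A minimal round-trip sketch against the two functions this diff touches; keyword names beyond those visible above are assumptions:

    import torch
    from bitsandbytes import functional as F

    A = torch.randn(2 * 4096 + 100, device='cuda')    # two full blocks plus a tail
    q, quant_state = F.quantize_blockwise(A)          # returns out, (absmax, code)
    A_hat = F.dequantize_blockwise(q, quant_state=quant_state)
    print((A - A_hat).abs().max())                    # small blockwise quantization error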