summary | refs | log | tree | commit | diff
path: root/bitsandbytes
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes')
-rw-r--r-- bitsandbytes/functional.py 61
1 files changed, 37 insertions, 24 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index ff48b7f..076414d 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -182,7 +182,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
-def create_dynamic_map(signed=True, n=7):
+def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
"""
Creates the dynamic quantization map.
@@ -203,28 +203,32 @@ def create_dynamic_map(signed=True, n=7):
# these are additional items that come from the case
# where all the exponent bits are zero and no
# indicator bit is present
- additional_items = 2 ** (7 - n) - 1
+ non_sign_bits = total_bits - (1 if signed else 0)
+ additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
if not signed:
additional_items = 2 * additional_items
- for i in range(n):
- fraction_items = (
- 2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
- )
+ for i in range(max_exponent_bits):
+ fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1))
boundaries = torch.linspace(0.1, 1, fraction_items)
means = (boundaries[:-1] + boundaries[1:]) / 2.0
- data += ((10 ** (-(n - 1) + i)) * means).tolist()
+ data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
if signed:
- data += (-(10 ** (-(n - 1) + i)) * means).tolist()
+ data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
- if additional_items > 0:
- boundaries = torch.linspace(0.1, 1, additional_items + 1)
- means = (boundaries[:-1] + boundaries[1:]) / 2.0
- data += ((10 ** (-(n - 1) + i)) * means).tolist()
- if signed:
- data += (-(10 ** (-(n - 1) + i)) * means).tolist()
+ if additional_items > 0:
+ boundaries = torch.linspace(0.1, 1, additional_items + 1)
+ means = (boundaries[:-1] + boundaries[1:]) / 2.0
+ data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
+ if signed:
+ data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
data.append(0)
data.append(1.0)
+
+ gap = 256 - len(data)
+ for i in range(gap):
+ data.append(0)
+
data.sort()
return Tensor(data)
@@ -371,9 +375,7 @@ def nvidia_transform(
return out, new_state
-def estimate_quantiles(
- A: Tensor, out: Tensor = None, offset: float = 1 / 512
-) -> Tensor:
+def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
'''
Estimates 256 equidistant quantiles on the input tensor eCDF.
@@ -393,25 +395,36 @@ def estimate_quantiles(
out : torch.Tensor
Tensor with the 256 estimated quantiles.
offset : float
- The offset for the first and last quantile from 0 and 1. Default: 1/512
+ The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles)
+ num_quantiles : int
+ The number of equally spaced quantiles.
Returns
-------
torch.Tensor:
The 256 quantiles in float32 datatype.
'''
+ if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.')
+ if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}")
+ if num_quantiles < 256 and offset == 1/(512):
+ # override default arguments
+ offset = 1/(2*num_quantiles)
+
if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
is_on_gpu([A, out])
+ device = pre_call(A.device)
if A.dtype == torch.float32:
- lib.cestimate_quantiles_fp32(
- get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
- )
+ lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
elif A.dtype == torch.float16:
- lib.cestimate_quantiles_fp16(
- get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
- )
+ lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
else:
raise NotImplementedError(f"Not supported data type {A.dtype}")
+ post_call(device)
+
+ if num_quantiles < 256:
+ idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
+ out = out[idx]
+
return out