Diffstat (limited to 'bitsandbytes')
-rw-r--r-- | bitsandbytes/functional.py | 61
1 file changed, 37 insertions, 24 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index ff48b7f..076414d 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -182,7 +182,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
-def create_dynamic_map(signed=True, n=7):
+def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
     """
     Creates the dynamic quantiztion map.
@@ -203,28 +203,32 @@ def create_dynamic_map(signed=True, n=7):
     # these are additional items that come from the case
     # where all the exponent bits are zero and no
     # indicator bit is present
-    additional_items = 2 ** (7 - n) - 1
+    non_sign_bits = total_bits - (1 if signed else 0)
+    additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
     if not signed:
         additional_items = 2 * additional_items
-    for i in range(n):
-        fraction_items = (
-            2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
-        )
+    for i in range(max_exponent_bits):
+        fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1))
         boundaries = torch.linspace(0.1, 1, fraction_items)
         means = (boundaries[:-1] + boundaries[1:]) / 2.0
-        data += ((10 ** (-(n - 1) + i)) * means).tolist()
+        data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
         if signed:
-            data += (-(10 ** (-(n - 1) + i)) * means).tolist()
+            data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()

-        if additional_items > 0:
-            boundaries = torch.linspace(0.1, 1, additional_items + 1)
-            means = (boundaries[:-1] + boundaries[1:]) / 2.0
-            data += ((10 ** (-(n - 1) + i)) * means).tolist()
-            if signed:
-                data += (-(10 ** (-(n - 1) + i)) * means).tolist()
+    if additional_items > 0:
+        boundaries = torch.linspace(0.1, 1, additional_items + 1)
+        means = (boundaries[:-1] + boundaries[1:]) / 2.0
+        data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
+        if signed:
+            data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()

     data.append(0)
     data.append(1.0)
+
+    gap = 256 - len(data)
+    for i in range(gap):
+        data.append(0)
+
     data.sort()
     return Tensor(data)
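The default arguments reproduce the previous hard-coded 8-bit behavior, while the new parameters allow smaller dynamic data types. A minimal usage sketch (assuming a bitsandbytes build that contains this change; the variable names are illustrative):

from bitsandbytes.functional import create_dynamic_map

# Defaults correspond to the old fixed n=7, 8-bit map:
# 254 signed values plus 0 and 1.0 give exactly 256 entries.
code8 = create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8)
assert code8.numel() == 256

# A smaller data type produces fewer distinct values; the new
# gap-filling loop (gap = 256 - len(data)) pads the returned
# tensor with zeros so it still holds 256 entries.
code5 = create_dynamic_map(signed=True, max_exponent_bits=3, total_bits=5)
assert code5.numel() == 256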
@@ -371,9 +375,7 @@ def nvidia_transform(
     return out, new_state


-def estimate_quantiles(
-    A: Tensor, out: Tensor = None, offset: float = 1 / 512
-) -> Tensor:
+def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
     '''
     Estimates 256 equidistant quantiles on the input tensor eCDF.
@@ -393,25 +395,36 @@ def estimate_quantiles(
     out : torch.Tensor
         Tensor with the 256 estimated quantiles.
     offset : float
-        The offset for the first and last quantile from 0 and 1. Default: 1/512
+        The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles)
+    num_quantiles : int
+        The number of equally spaced quantiles.

     Returns
     -------
     torch.Tensor:
         The 256 quantiles in float32 datatype.
     '''
+    if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.')
+    if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}")
+    if num_quantiles < 256 and offset == 1/(512):
+        # override default arguments
+        offset = 1/(2*num_quantiles)
+
     if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
     is_on_gpu([A, out])
+    device = pre_call(A.device)
     if A.dtype == torch.float32:
-        lib.cestimate_quantiles_fp32(
-            get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
-        )
+        lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
     elif A.dtype == torch.float16:
-        lib.cestimate_quantiles_fp16(
-            get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
-        )
+        lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
     else:
         raise NotImplementedError(f"Not supported data type {A.dtype}")
+    post_call(device)
+
+    if num_quantiles < 256:
+        idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
+        out = out[idx]
+
     return out
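The new num_quantiles argument subsamples the 256 estimated quantiles, and when offset is left at its default the offset is recomputed as 1/(2*num_quantiles). A short usage sketch (the input tensor is made up, and a CUDA device is assumed since the quantile kernel runs on the GPU):

import torch
from bitsandbytes.functional import estimate_quantiles

A = torch.randn(4096, device='cuda')              # needs at least 256 values

q256 = estimate_quantiles(A)                      # default: 256 quantiles
q128 = estimate_quantiles(A, num_quantiles=128)   # offset auto-adjusted to 1/256

assert q256.numel() == 256
assert q128.numel() == 128   # subsampled via torch.linspace index selection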