diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-11-06 13:05:25 -0800 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-11-06 13:05:25 -0800 |
commit | 2f2063bac212bcd6a515a88a12a9530b5730dabe (patch) | |
tree | 82a84ed0c97e2e53d15c91b1342ca081856e3e68 | |
parent | 98cbc4bc4f15f5c094cd8575ddb0380a19516099 (diff) |
Added k<256 quantile estimate.
-rw-r--r-- | bitsandbytes/functional.py | 61 | ||||
-rw-r--r-- | tests/test_functional.py | 43 |
2 files changed, 74 insertions, 30 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index ff48b7f..076414d 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -182,7 +182,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) -def create_dynamic_map(signed=True, n=7): +def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): """ Creates the dynamic quantiztion map. @@ -203,28 +203,32 @@ def create_dynamic_map(signed=True, n=7): # these are additional items that come from the case # where all the exponent bits are zero and no # indicator bit is present - additional_items = 2 ** (7 - n) - 1 + non_sign_bits = total_bits - (1 if signed else 0) + additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1 if not signed: additional_items = 2 * additional_items - for i in range(n): - fraction_items = ( - 2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1 - ) + for i in range(max_exponent_bits): + fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1)) boundaries = torch.linspace(0.1, 1, fraction_items) means = (boundaries[:-1] + boundaries[1:]) / 2.0 - data += ((10 ** (-(n - 1) + i)) * means).tolist() + data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() if signed: - data += (-(10 ** (-(n - 1) + i)) * means).tolist() + data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() - if additional_items > 0: - boundaries = torch.linspace(0.1, 1, additional_items + 1) - means = (boundaries[:-1] + boundaries[1:]) / 2.0 - data += ((10 ** (-(n - 1) + i)) * means).tolist() - if signed: - data += (-(10 ** (-(n - 1) + i)) * means).tolist() + if additional_items > 0: + boundaries = torch.linspace(0.1, 1, additional_items + 1) + means = (boundaries[:-1] + boundaries[1:]) / 2.0 + data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() + if signed: + data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() data.append(0) data.append(1.0) + + gap = 256 - len(data) + for i in range(gap): + data.append(0) + data.sort() return Tensor(data) @@ -371,9 +375,7 @@ def nvidia_transform( return out, new_state -def estimate_quantiles( - A: Tensor, out: Tensor = None, offset: float = 1 / 512 -) -> Tensor: +def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor: ''' Estimates 256 equidistant quantiles on the input tensor eCDF. @@ -393,25 +395,36 @@ def estimate_quantiles( out : torch.Tensor Tensor with the 256 estimated quantiles. offset : float - The offset for the first and last quantile from 0 and 1. Default: 1/512 + The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles) + num_quantiles : int + The number of equally spaced quantiles. Returns ------- torch.Tensor: The 256 quantiles in float32 datatype. ''' + if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.') + if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}") + if num_quantiles < 256 and offset == 1/(512): + # override default arguments + offset = 1/(2*num_quantiles) + if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device) is_on_gpu([A, out]) + device = pre_call(A.device) if A.dtype == torch.float32: - lib.cestimate_quantiles_fp32( - get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()) - ) + lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) elif A.dtype == torch.float16: - lib.cestimate_quantiles_fp16( - get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()) - ) + lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) else: raise NotImplementedError(f"Not supported data type {A.dtype}") + post_call(device) + + if num_quantiles < 256: + idx = torch.linspace(0, 255, num_quantiles).long().to(A.device) + out = out[idx] + return out diff --git a/tests/test_functional.py b/tests/test_functional.py index bd4dafe..99885da 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -6,9 +6,11 @@ from itertools import product import einops import pytest import torch +import numpy as np import bitsandbytes as bnb from bitsandbytes import functional as F +from scipy.stats import norm torch.set_printoptions( precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000 @@ -2094,8 +2096,12 @@ def test_fp8_quant(): def test_few_bit_quant(): + print('') for bits in range(2, 9): - for method in ['linear', 'fp8']: + print('='*30, bits, '='*30) + for method in ['linear', 'fp8', 'dynamic', 'quantile']: + abserrs = [] + relerrs = [] code = None if method == 'linear': code = F.create_linear_map(True, bits=bits).cuda() @@ -2103,10 +2109,21 @@ def test_few_bit_quant(): ebits = math.ceil(bits/2) pbits = bits-ebits-1 code = F.create_fp8_map(True, ebits, pbits, bits).cuda() - print(ebits, pbits, bits) - print(code) + elif method == 'dynamic': + code = F.create_dynamic_map(True, bits-0, bits).cuda() + elif method == 'quantile': + values = torch.randn(2048, 2048, device='cuda') + q = F.estimate_quantiles(values, offset= 1/(2*(2**bits)), num_quantiles=2**bits) + gap = 256-q.numel() + q = q.tolist() + for i in range(gap): + q.append(0) + q = torch.Tensor(q).cuda() + + q /= q.abs().max() + code, idx = torch.sort(q) + print(method, (code==0).sum()) assert code.numel() == 256 - print(bits) for i in range(10): values = torch.randn(1, 32, device='cuda') @@ -2127,11 +2144,25 @@ def test_few_bit_quant(): v2 = F.dequantize(q2, S2) idx = torch.isclose(q1.int(), q2.int()) + err2 = torch.abs(v2-values) + abserrs.append(err2.mean().item()) + relerrs.append((err2/(1e-10+values).abs()).mean().item()) if idx.sum(): # some weird cases err1 = torch.abs(v1-values).mean() - err2 = torch.abs(v2-values).mean() - assert err2 <= err1 + assert err2.mean() <= err1 else: torch.testing.assert_allclose(q1, q2) + print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs)) + + +def test_kbit_quantile_estimation(): + for i in range(100): + data = torch.randn(1024, 1024, device='cuda') + for bits in range(2, 9): + p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits) + val1 = torch.Tensor(norm.ppf(p)).cuda() + val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits) + err = torch.abs(val1-val2).mean() + assert err < 0.035 |