summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bitsandbytes/functional.py61
-rw-r--r--tests/test_functional.py43
2 files changed, 74 insertions, 30 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index ff48b7f..076414d 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -182,7 +182,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
-def create_dynamic_map(signed=True, n=7):
+def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
"""
Creates the dynamic quantiztion map.
@@ -203,28 +203,32 @@ def create_dynamic_map(signed=True, n=7):
# these are additional items that come from the case
# where all the exponent bits are zero and no
# indicator bit is present
- additional_items = 2 ** (7 - n) - 1
+ non_sign_bits = total_bits - (1 if signed else 0)
+ additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
if not signed:
additional_items = 2 * additional_items
- for i in range(n):
- fraction_items = (
- 2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
- )
+ for i in range(max_exponent_bits):
+ fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1))
boundaries = torch.linspace(0.1, 1, fraction_items)
means = (boundaries[:-1] + boundaries[1:]) / 2.0
- data += ((10 ** (-(n - 1) + i)) * means).tolist()
+ data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
if signed:
- data += (-(10 ** (-(n - 1) + i)) * means).tolist()
+ data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
- if additional_items > 0:
- boundaries = torch.linspace(0.1, 1, additional_items + 1)
- means = (boundaries[:-1] + boundaries[1:]) / 2.0
- data += ((10 ** (-(n - 1) + i)) * means).tolist()
- if signed:
- data += (-(10 ** (-(n - 1) + i)) * means).tolist()
+ if additional_items > 0:
+ boundaries = torch.linspace(0.1, 1, additional_items + 1)
+ means = (boundaries[:-1] + boundaries[1:]) / 2.0
+ data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
+ if signed:
+ data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
data.append(0)
data.append(1.0)
+
+ gap = 256 - len(data)
+ for i in range(gap):
+ data.append(0)
+
data.sort()
return Tensor(data)
@@ -371,9 +375,7 @@ def nvidia_transform(
return out, new_state
-def estimate_quantiles(
- A: Tensor, out: Tensor = None, offset: float = 1 / 512
-) -> Tensor:
+def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
'''
Estimates 256 equidistant quantiles on the input tensor eCDF.
@@ -393,25 +395,36 @@ def estimate_quantiles(
out : torch.Tensor
Tensor with the 256 estimated quantiles.
offset : float
- The offset for the first and last quantile from 0 and 1. Default: 1/512
+ The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles)
+ num_quantiles : int
+ The number of equally spaced quantiles.
Returns
-------
torch.Tensor:
The 256 quantiles in float32 datatype.
'''
+ if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.')
+ if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}")
+ if num_quantiles < 256 and offset == 1/(512):
+ # override default arguments
+ offset = 1/(2*num_quantiles)
+
if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
is_on_gpu([A, out])
+ device = pre_call(A.device)
if A.dtype == torch.float32:
- lib.cestimate_quantiles_fp32(
- get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
- )
+ lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
elif A.dtype == torch.float16:
- lib.cestimate_quantiles_fp16(
- get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
- )
+ lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
else:
raise NotImplementedError(f"Not supported data type {A.dtype}")
+ post_call(device)
+
+ if num_quantiles < 256:
+ idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
+ out = out[idx]
+
return out
diff --git a/tests/test_functional.py b/tests/test_functional.py
index bd4dafe..99885da 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -6,9 +6,11 @@ from itertools import product
import einops
import pytest
import torch
+import numpy as np
import bitsandbytes as bnb
from bitsandbytes import functional as F
+from scipy.stats import norm
torch.set_printoptions(
precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
@@ -2094,8 +2096,12 @@ def test_fp8_quant():
def test_few_bit_quant():
+ print('')
for bits in range(2, 9):
- for method in ['linear', 'fp8']:
+ print('='*30, bits, '='*30)
+ for method in ['linear', 'fp8', 'dynamic', 'quantile']:
+ abserrs = []
+ relerrs = []
code = None
if method == 'linear':
code = F.create_linear_map(True, bits=bits).cuda()
@@ -2103,10 +2109,21 @@ def test_few_bit_quant():
ebits = math.ceil(bits/2)
pbits = bits-ebits-1
code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
- print(ebits, pbits, bits)
- print(code)
+ elif method == 'dynamic':
+ code = F.create_dynamic_map(True, bits-0, bits).cuda()
+ elif method == 'quantile':
+ values = torch.randn(2048, 2048, device='cuda')
+ q = F.estimate_quantiles(values, offset= 1/(2*(2**bits)), num_quantiles=2**bits)
+ gap = 256-q.numel()
+ q = q.tolist()
+ for i in range(gap):
+ q.append(0)
+ q = torch.Tensor(q).cuda()
+
+ q /= q.abs().max()
+ code, idx = torch.sort(q)
+ print(method, (code==0).sum())
assert code.numel() == 256
- print(bits)
for i in range(10):
values = torch.randn(1, 32, device='cuda')
@@ -2127,11 +2144,25 @@ def test_few_bit_quant():
v2 = F.dequantize(q2, S2)
idx = torch.isclose(q1.int(), q2.int())
+ err2 = torch.abs(v2-values)
+ abserrs.append(err2.mean().item())
+ relerrs.append((err2/(1e-10+values).abs()).mean().item())
if idx.sum():
# some weird cases
err1 = torch.abs(v1-values).mean()
- err2 = torch.abs(v2-values).mean()
- assert err2 <= err1
+ assert err2.mean() <= err1
else:
torch.testing.assert_allclose(q1, q2)
+ print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
+
+
+def test_kbit_quantile_estimation():
+ for i in range(100):
+ data = torch.randn(1024, 1024, device='cuda')
+ for bits in range(2, 9):
+ p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits)
+ val1 = torch.Tensor(norm.ppf(p)).cuda()
+ val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
+ err = torch.abs(val1-val2).mean()
+ assert err < 0.035