diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-11-06 13:05:25 -0800 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-11-06 13:05:25 -0800 |
commit | 2f2063bac212bcd6a515a88a12a9530b5730dabe (patch) | |
tree | 82a84ed0c97e2e53d15c91b1342ca081856e3e68 /tests | |
parent | 98cbc4bc4f15f5c094cd8575ddb0380a19516099 (diff) |
Added k<256 quantile estimate.
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_functional.py | 43 |
1 files changed, 37 insertions, 6 deletions
diff --git a/tests/test_functional.py b/tests/test_functional.py index bd4dafe..99885da 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -6,9 +6,11 @@ from itertools import product import einops import pytest import torch +import numpy as np import bitsandbytes as bnb from bitsandbytes import functional as F +from scipy.stats import norm torch.set_printoptions( precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000 @@ -2094,8 +2096,12 @@ def test_fp8_quant(): def test_few_bit_quant(): + print('') for bits in range(2, 9): - for method in ['linear', 'fp8']: + print('='*30, bits, '='*30) + for method in ['linear', 'fp8', 'dynamic', 'quantile']: + abserrs = [] + relerrs = [] code = None if method == 'linear': code = F.create_linear_map(True, bits=bits).cuda() @@ -2103,10 +2109,21 @@ def test_few_bit_quant(): ebits = math.ceil(bits/2) pbits = bits-ebits-1 code = F.create_fp8_map(True, ebits, pbits, bits).cuda() - print(ebits, pbits, bits) - print(code) + elif method == 'dynamic': + code = F.create_dynamic_map(True, bits-0, bits).cuda() + elif method == 'quantile': + values = torch.randn(2048, 2048, device='cuda') + q = F.estimate_quantiles(values, offset= 1/(2*(2**bits)), num_quantiles=2**bits) + gap = 256-q.numel() + q = q.tolist() + for i in range(gap): + q.append(0) + q = torch.Tensor(q).cuda() + + q /= q.abs().max() + code, idx = torch.sort(q) + print(method, (code==0).sum()) assert code.numel() == 256 - print(bits) for i in range(10): values = torch.randn(1, 32, device='cuda') @@ -2127,11 +2144,25 @@ def test_few_bit_quant(): v2 = F.dequantize(q2, S2) idx = torch.isclose(q1.int(), q2.int()) + err2 = torch.abs(v2-values) + abserrs.append(err2.mean().item()) + relerrs.append((err2/(1e-10+values).abs()).mean().item()) if idx.sum(): # some weird cases err1 = torch.abs(v1-values).mean() - err2 = torch.abs(v2-values).mean() - assert err2 <= err1 + assert err2.mean() <= err1 else: torch.testing.assert_allclose(q1, q2) + print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs)) + + +def test_kbit_quantile_estimation(): + for i in range(100): + data = torch.randn(1024, 1024, device='cuda') + for bits in range(2, 9): + p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits) + val1 = torch.Tensor(norm.ppf(p)).cuda() + val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits) + err = torch.abs(val1-val2).mean() + assert err < 0.035 |