summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r-- tests/test_functional.py 43
1 files changed, 37 insertions, 6 deletions
diff --git a/tests/test_functional.py b/tests/test_functional.py
index bd4dafe..99885da 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -6,9 +6,11 @@ from itertools import product
import einops
import pytest
import torch
+import numpy as np
import bitsandbytes as bnb
from bitsandbytes import functional as F
+from scipy.stats import norm
torch.set_printoptions(
precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
@@ -2094,8 +2096,12 @@ def test_fp8_quant():
def test_few_bit_quant():
+ print('')
for bits in range(2, 9):
- for method in ['linear', 'fp8']:
+ print('='*30, bits, '='*30)
+ for method in ['linear', 'fp8', 'dynamic', 'quantile']:
+ abserrs = []
+ relerrs = []
code = None
if method == 'linear':
code = F.create_linear_map(True, bits=bits).cuda()
@@ -2103,10 +2109,21 @@ def test_few_bit_quant():
ebits = math.ceil(bits/2)
pbits = bits-ebits-1
code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
- print(ebits, pbits, bits)
- print(code)
+ elif method == 'dynamic':
+ code = F.create_dynamic_map(True, bits-0, bits).cuda()
+ elif method == 'quantile':
+ values = torch.randn(2048, 2048, device='cuda')
+ q = F.estimate_quantiles(values, offset= 1/(2*(2**bits)), num_quantiles=2**bits)
+ gap = 256-q.numel()
+ q = q.tolist()
+ for i in range(gap):
+ q.append(0)
+ q = torch.Tensor(q).cuda()
+
+ q /= q.abs().max()
+ code, idx = torch.sort(q)
+ print(method, (code==0).sum())
assert code.numel() == 256
- print(bits)
for i in range(10):
values = torch.randn(1, 32, device='cuda')
@@ -2127,11 +2144,25 @@ def test_few_bit_quant():
v2 = F.dequantize(q2, S2)
idx = torch.isclose(q1.int(), q2.int())
+ err2 = torch.abs(v2-values)
+ abserrs.append(err2.mean().item())
+ relerrs.append((err2/(1e-10+values).abs()).mean().item())
if idx.sum():
# some weird cases
err1 = torch.abs(v1-values).mean()
- err2 = torch.abs(v2-values).mean()
- assert err2 <= err1
+ assert err2.mean() <= err1
else:
torch.testing.assert_allclose(q1, q2)
+ print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
+
+
+def test_kbit_quantile_estimation():
+ for i in range(100):
+ data = torch.randn(1024, 1024, device='cuda')
+ for bits in range(2, 9):
+ p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits)
+ val1 = torch.Tensor(norm.ppf(p)).cuda()
+ val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
+ err = torch.abs(val1-val2).mean()
+ assert err < 0.035