diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2021-10-05 19:16:20 -0700 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2021-10-05 19:16:20 -0700 |
commit | 7439924891496025edf60c9da6a782f362a50c70 (patch) | |
tree | 90476984d2c267f89232577a2ea40eb172387475 /tests/test_functional.py |
Initial commit
Diffstat (limited to 'tests/test_functional.py')
-rw-r--r-- | tests/test_functional.py | 213 |
1 files changed, 213 insertions, 0 deletions
diff --git a/tests/test_functional.py b/tests/test_functional.py new file mode 100644 index 0000000..2a7d308 --- /dev/null +++ b/tests/test_functional.py @@ -0,0 +1,213 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import pytest +import torch +import bitsandbytes as bnb + +from itertools import product + +from bitsandbytes import functional as F + +def setup(): + pass + +def teardown(): + pass + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['float', 'half']) +def test_estimate_quantiles(dtype): + A = torch.rand(1024, 1024, device='cuda') + A = A.to(dtype) + code = F.estimate_quantiles(A) + + percs = torch.linspace(1/512, 511/512, 256, device=A.device) + torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2) + + A = torch.randn(1024, 1024, device='cuda') + A = A.to(dtype) + code = F.estimate_quantiles(A) + + quantiles = torch.quantile(A.float(), percs) + diff = torch.abs(code-quantiles) + assert (diff > 5e-02).sum().item() == 0 + + +def test_quantile_quantization(): + for i in range(100): + A1 = torch.randn(1024, 1024, device='cuda') + code = F.estimate_quantiles(A1) + C = F.quantize_no_absmax(A1, code) + A2 = F.dequantize_no_absmax(C, code) + diff = torch.abs(A1-A2).mean().item() + assert diff < 0.0075 + + A1 = torch.rand(1024, 1024, device='cuda') + code = F.estimate_quantiles(A1) + C = F.quantize_no_absmax(A1, code) + A2 = F.dequantize_no_absmax(C, code) + diff = torch.abs(A1-A2).mean().item() + torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0) + assert diff < 0.001 + + +def test_dynamic_quantization(): + diffs = [] + reldiffs = [] + for i in range(100): + A1 = torch.randn(1024, 1024, device='cuda') + C, S = F.quantize(A1) + A2 = F.dequantize(C, S) + diff = torch.abs(A1-A2) + reldiff = diff/torch.abs(A1+1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + assert diff.mean().item() < 0.0135 + print(sum(diffs)/len(diffs)) + print(sum(reldiffs)/len(reldiffs)) + + for i in range(100): + A1 = torch.rand(1024, 1024, device='cuda') + C, S = F.quantize(A1) + A2 = F.dequantize(C, S) + diff = torch.abs(A1-A2).mean().item() + torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) + assert diff < 0.004 + + +def test_dynamic_blockwise_quantization(): + diffs = [] + reldiffs = [] + for i in range(100): + A1 = torch.randn(1024, 1024, device='cuda') + C, S = F.quantize_blockwise(A1) + A2 = F.dequantize_blockwise(C, S) + diff = torch.abs(A1-A2) + reldiff = diff/torch.abs(A1+1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + assert diffs[-1] < 0.011 + print(sum(diffs)/len(diffs)) + print(sum(reldiffs)/len(reldiffs)) + + diffs = [] + for i in range(100): + A1 = torch.rand(1024, 1024, device='cuda') + C, S = F.quantize_blockwise(A1) + A2 = F.dequantize_blockwise(C, S) + diff = torch.abs(A1-A2).mean().item() + assert diff < 0.0033 + diffs.append(diff) + torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) + #print(sum(diffs)/len(diffs)) + +def test_dynamic_blockwise_stochastic_quantization(): + diffs = [] + reldiffs = [] + rand = torch.rand(1024).cuda() + for i in range(100): + A1 = torch.randn(1024, 1024, device='cuda') + C1, S1 = F.quantize_blockwise(A1, rand=rand) + C2, S2 = F.quantize_blockwise(A1) + # a maximunm distance of quantized values of 1 + torch.testing.assert_allclose(C1, C2, atol=1, rtol=0) + fraction_smaller = (C1<C2).float().sum()/C1.numel() + fraction_larger = (C1>C2).float().sum()/C1.numel() + torch.testing.assert_allclose(fraction_larger, fraction_smaller, atol=0.01, rtol=0) + + + +@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=['float', 'half']) +def test_percentile_clipping(gtype): + gnorm_vec1 = torch.zeros(100, device='cuda') + gnorm_vec2 = torch.zeros(100, device='cuda') + n = 4 + step = 0 + percentile=5 + for i in range(1000): + step += 1 + g = torch.randn(n, n, dtype=gtype, device='cuda') + gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile) + assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2/gnorm1 + + gnorm2 = torch.norm(g.float()) + if step == 1: + gnorm_vec1[:] = gnorm2 + else: + gnorm_vec1[step % 100] = gnorm2 + + vals, idx = torch.sort(gnorm_vec1) + clip1 = vals[percentile] + + torch.testing.assert_allclose(gnorm_vec1, torch.sqrt(gnorm_vec2)) + torch.testing.assert_allclose(clip1, clip2) + torch.testing.assert_allclose(gnorm1, gnorm2) + + +def test_stable_embedding(): + layer = bnb.nn.StableEmbedding(1024, 1024) + layer.reset_parameters() + + +def test_dynamic_blockwise_quantization_cpu(): + #A1 = torch.randn(1024, 1024, device='cpu') + #code = F.create_dynamic_map() + #for i in range(1000): + # C, S = F.quantize_blockwise(A1, code=code) + # A2 = F.dequantize_blockwise(C, S) + + for i in range(10): + # equivalence with GPU blockwise quantization + A1 = torch.randn(1024, 1024, device='cpu') + C1, S1 = F.quantize_blockwise(A1) + C2, S2 = F.quantize_blockwise(A1.cuda()) + torch.testing.assert_allclose(S1[0], S2[0].cpu()) + # there seems to be some issues with precision in CUDA vs CPU + # not all elements are usually close, with couple off elements in a million + idx = torch.isclose(C1, C2.cpu()) + assert (idx==0).sum().item() < 15 + + + diffs = [] + reldiffs = [] + for i in range(10): + A1 = torch.randn(1024, 1024, device='cpu') + C, S = F.quantize_blockwise(A1) + A2 = F.dequantize_blockwise(C, S) + diff = torch.abs(A1-A2) + reldiff = diff/torch.abs(A1+1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + assert diffs[-1] < 0.011 + #print(sum(diffs)/len(diffs)) + #print(sum(reldiffs)/len(reldiffs)) + + diffs = [] + for i in range(10): + A1 = torch.rand(1024, 1024, device='cpu') + C, S = F.quantize_blockwise(A1) + A2 = F.dequantize_blockwise(C, S) + diff = torch.abs(A1-A2).mean().item() + assert diff < 0.0033 + diffs.append(diff) + torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) + #print(sum(diffs)/len(diffs)) + + +def test_histogram(): + dim1, dim2 = 32, 32 + source = torch.rand(dim1, dim2, device='cuda') + idx1 = torch.randint(0, 255, size=(dim1, dim2), device='cuda').int() + idx2 = torch.randint(0, 255, size=(dim1, dim2), device='cuda').int() + histogram1 = torch.zeros((256, 256)).cuda() + histogram2 = torch.zeros((256, 256)).cuda() + + F.histogram_scatter_add_2d(histogram2, idx1, idx2, source) + + for i in range(dim1): + for j in range(dim2): + histogram1[idx1[i, j].item(), idx2[i, j].item()] += source[i, j] + + torch.testing.assert_allclose(histogram1, histogram2) + torch.testing.assert_allclose(histogram1.sum(), source.sum()) |