diff options
author | justheuristic <justheuristic@gmail.com> | 2022-09-17 18:42:22 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-17 18:42:22 +0300 |
commit | 3634fc738bc20e4041c75544d3f678f61ce2348c (patch) | |
tree | 36bc3394748ce4141fa9ab9d1104ca6441ade65c /tests/test_functional.py | |
parent | e2a75769f22bdc5465240c3f6701a1b002e8ab59 (diff) | |
parent | 9b5f2eda8fbd3f042c4af7ed1b870525d4668f2a (diff) |
Merge branch 'TimDettmers:main' into memory-efficient-backward
Diffstat (limited to 'tests/test_functional.py')
-rw-r--r-- | tests/test_functional.py | 27 |
1 files changed, 25 insertions, 2 deletions
diff --git a/tests/test_functional.py b/tests/test_functional.py index 14cc21e..fcfdc72 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1815,14 +1815,14 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): batch_size = 1 seqdim = 1 values = [] -#values.append((batch_size, seqdim, 768, 4 * 768)) +values.append((batch_size, seqdim, 768, 4 * 768)) # values.append((batch_size, seqdim, 1024, 4*1024)) # values.append((batch_size, seqdim, 1536, 4*1536)) # values.append((batch_size, seqdim, 2048, 4*2048)) # values.append((batch_size, seqdim, 2560, 4*2560)) # values.append((batch_size, seqdim, 4096, 4*4096)) # values.append((batch_size, seqdim, 5140, 4*5140)) -values.append((batch_size, seqdim, 12288, 4*12288)) +#values.append((batch_size, seqdim, 12288, 4*12288)) names = [ "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values ] @@ -2125,3 +2125,26 @@ def test_extract_outliers(): assert outliers2.shape[1] == idx.numel() torch.testing.assert_allclose(outliers1, outliers2) + + + +def test_blockwise_cpu_large(): + diffs = [] + reldiffs = [] + batch = 128 + seq = 128 + for hidden in [128, 14336]: + for blocksize in [4096, 16384]: + for i in range(2): + A1 = torch.randn(batch, seq, hidden, device='cpu') + t0 = time.time() + C, S = F.quantize_blockwise(A1, blocksize=blocksize) + A2 = F.dequantize_blockwise(C, S, blocksize=blocksize) + print(time.time() - t0) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + assert diffs[-1] < 0.011 + # print(sum(diffs)/len(diffs)) + # print(sum(reldiffs)/len(reldiffs)) |