summaryrefslogtreecommitdiff
path: root/tests/test_functional.py
diff options
context:
space:
mode:
authorjustheuristic <justheuristic@gmail.com>2022-09-17 18:42:22 +0300
committerGitHub <noreply@github.com>2022-09-17 18:42:22 +0300
commit3634fc738bc20e4041c75544d3f678f61ce2348c (patch)
tree36bc3394748ce4141fa9ab9d1104ca6441ade65c /tests/test_functional.py
parente2a75769f22bdc5465240c3f6701a1b002e8ab59 (diff)
parent9b5f2eda8fbd3f042c4af7ed1b870525d4668f2a (diff)
Merge branch 'TimDettmers:main' into memory-efficient-backward
Diffstat (limited to 'tests/test_functional.py')
-rw-r--r--tests/test_functional.py27
1 files changed, 25 insertions, 2 deletions
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 14cc21e..fcfdc72 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1815,14 +1815,14 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
batch_size = 1
seqdim = 1
values = []
-#values.append((batch_size, seqdim, 768, 4 * 768))
+values.append((batch_size, seqdim, 768, 4 * 768))
# values.append((batch_size, seqdim, 1024, 4*1024))
# values.append((batch_size, seqdim, 1536, 4*1536))
# values.append((batch_size, seqdim, 2048, 4*2048))
# values.append((batch_size, seqdim, 2560, 4*2560))
# values.append((batch_size, seqdim, 4096, 4*4096))
# values.append((batch_size, seqdim, 5140, 4*5140))
-values.append((batch_size, seqdim, 12288, 4*12288))
+#values.append((batch_size, seqdim, 12288, 4*12288))
names = [
"batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
]
@@ -2125,3 +2125,26 @@ def test_extract_outliers():
assert outliers2.shape[1] == idx.numel()
torch.testing.assert_allclose(outliers1, outliers2)
+
+
+
+def test_blockwise_cpu_large():
+ diffs = []
+ reldiffs = []
+ batch = 128
+ seq = 128
+ for hidden in [128, 14336]:
+ for blocksize in [4096, 16384]:
+ for i in range(2):
+ A1 = torch.randn(batch, seq, hidden, device='cpu')
+ t0 = time.time()
+ C, S = F.quantize_blockwise(A1, blocksize=blocksize)
+ A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
+ print(time.time() - t0)
+ diff = torch.abs(A1 - A2)
+ reldiff = diff / torch.abs(A1 + 1e-8)
+ diffs.append(diff.mean().item())
+ reldiffs.append(reldiff.mean().item())
+ assert diffs[-1] < 0.011
+ # print(sum(diffs)/len(diffs))
+ # print(sum(reldiffs)/len(reldiffs))