summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
author Tim Dettmers <tim.dettmers@gmail.com> 2022-09-13 10:37:53 -0700
committer Tim Dettmers <tim.dettmers@gmail.com> 2022-09-13 10:37:53 -0700
commit c05dd42ddd123a491b4e9840ee0c7a69cf5aa3d8 (patch)
tree bdcac851548262e486fcb877ab992133c7be3dbd /tests
parent d8dbf3a9b587d9b559207feed93578810c9c2aaf (diff)
Fixed cpu blockwise quantization for small input tensors.
Diffstat (limited to 'tests')
-rw-r--r-- tests/test_functional.py 30
1 files changed, 15 insertions, 15 deletions
diff --git a/tests/test_functional.py b/tests/test_functional.py
index d07affe..fcfdc72 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -2133,18 +2133,18 @@ def test_blockwise_cpu_large():
reldiffs = []
batch = 128
seq = 128
- hidden = 14336
- for blocksize in [4096, 16384]:
- for i in range(2):
- A1 = torch.randn(batch, seq, hidden, device='cpu')
- t0 = time.time()
- C, S = F.quantize_blockwise(A1, blocksize=blocksize)
- A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
- print(time.time() - t0)
- diff = torch.abs(A1 - A2)
- reldiff = diff / torch.abs(A1 + 1e-8)
- diffs.append(diff.mean().item())
- reldiffs.append(reldiff.mean().item())
- assert diffs[-1] < 0.011
- # print(sum(diffs)/len(diffs))
- # print(sum(reldiffs)/len(reldiffs))
+ for hidden in [128, 14336]:
+ for blocksize in [4096, 16384]:
+ for i in range(2):
+ A1 = torch.randn(batch, seq, hidden, device='cpu')
+ t0 = time.time()
+ C, S = F.quantize_blockwise(A1, blocksize=blocksize)
+ A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
+ print(time.time() - t0)
+ diff = torch.abs(A1 - A2)
+ reldiff = diff / torch.abs(A1 + 1e-8)
+ diffs.append(diff.mean().item())
+ reldiffs.append(reldiff.mean().item())
+ assert diffs[-1] < 0.011
+ # print(sum(diffs)/len(diffs))
+ # print(sum(reldiffs)/len(reldiffs))