diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-07-26 12:12:38 -0700 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-07-26 12:12:38 -0700 |
commit | cbb901ac51bd6c41e4243ffb936ef0e2f7ca8ada (patch) | |
tree | f02615b5588aa6ed94a51c1e66297595b802c0a1 /tests | |
parent | c771b3a75a6ebbfbfc398a028a477246b0799cf0 (diff) |
Boilerplate and test for extract_outliers.
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_functional.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/tests/test_functional.py b/tests/test_functional.py index 6cbe58f..b508367 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1856,3 +1856,20 @@ def test_zp(): print(err1, err2, err3, err4, err5, err6) + +def test_extract_outliers(): + shapeA = (128, 128) + idx = torch.randint(0, shapeA[1], size=(10,)).int() + A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) + outliers1 = A[:, idx.long()] + + CA, SA = F.transform(A, 'col_turing') + + outliers2 = F.extract_outliers(CA, SA, idx) + + assert outliers2.shape[0] == shapeA[0] + assert outliers2.shape[1] == idx.numel() + + + + torch.testing.assert_allclose(outliers1, outliers2) |