diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-08-04 07:40:48 -0700 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-08-04 07:40:48 -0700 |
commit | cc5b323876392658b1d91655f30840d24be6d821 (patch) | |
tree | 8e23e961709a3cc082a707ebc8ea0f52baee6923 /tests | |
parent | 6101a8fb9f76c2cc4018452b4420dd52e946d52b (diff) | |
parent | bd515328d70f344f935075f359c5aefc616878d5 (diff) |
Merge branch 'extract_outliers' into debug
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_functional.py | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/tests/test_functional.py b/tests/test_functional.py index d80a4f9..bfc3e28 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1859,3 +1859,29 @@ def test_zp(): print(err1, err2, err3, err4, err5, err6) + +def test_extract_outliers(): + for i in range(k): + shapeA = (4096, 4096*4) + idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda() + #idx = torch.Tensor([0]).int().cuda() + A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) + outliers1 = A[:, idx.long()] + + CA, SA = F.transform(A, 'col_turing') + + outliers2 = F.extract_outliers(CA, SA, idx) + + assert outliers2.shape[0] == shapeA[0] + assert outliers2.shape[1] == idx.numel() + + torch.testing.assert_allclose(outliers1, outliers2) + + CA, SA = F.transform(A, 'col_ampere') + + outliers2 = F.extract_outliers(CA, SA, idx) + + assert outliers2.shape[0] == shapeA[0] + assert outliers2.shape[1] == idx.numel() + + torch.testing.assert_allclose(outliers1, outliers2) |