diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_functional.py | 23 |
1 files changed, 13 insertions, 10 deletions
diff --git a/tests/test_functional.py b/tests/test_functional.py index b508367..4d06447 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1858,18 +1858,21 @@ def test_zp(): def test_extract_outliers(): - shapeA = (128, 128) - idx = torch.randint(0, shapeA[1], size=(10,)).int() - A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) - outliers1 = A[:, idx.long()] - - CA, SA = F.transform(A, 'col_turing') + for i in range(k): + shapeA = (4096, 4*4096) + idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda() + #idx = torch.Tensor([32]).int().cuda() + A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) + outliers1 = A[:, idx.long()] - outliers2 = F.extract_outliers(CA, SA, idx) + CA, SA = F.transform(A, 'col_turing') - assert outliers2.shape[0] == shapeA[0] - assert outliers2.shape[1] == idx.numel() + outliers2 = F.extract_outliers(CA, SA, idx) + assert outliers2.shape[0] == shapeA[0] + assert outliers2.shape[1] == idx.numel() + #print(outliers1) + #print(outliers2) - torch.testing.assert_allclose(outliers1, outliers2) + torch.testing.assert_allclose(outliers1, outliers2) |