summaryrefslogtreecommitdiff
path: root/tests/test_functional.py
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-07-26 17:39:30 -0700
committerTim Dettmers <tim.dettmers@gmail.com>2022-07-26 17:39:30 -0700
commitbcab99ec877ba063543bd7c03ba1cdd1b06e8078 (patch)
tree3d1dcfec5d3361ad4e5c5ba552a432e045445851 /tests/test_functional.py
parentcbb901ac51bd6c41e4243ffb936ef0e2f7ca8ada (diff)
Working outlier extraction for Turing.
Diffstat (limited to 'tests/test_functional.py')
-rw-r--r--tests/test_functional.py23
1 files changed, 13 insertions, 10 deletions
diff --git a/tests/test_functional.py b/tests/test_functional.py
index b508367..4d06447 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1858,18 +1858,21 @@ def test_zp():
def test_extract_outliers():
- shapeA = (128, 128)
- idx = torch.randint(0, shapeA[1], size=(10,)).int()
- A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
- outliers1 = A[:, idx.long()]
-
- CA, SA = F.transform(A, 'col_turing')
+ for i in range(k):
+ shapeA = (4096, 4*4096)
+ idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
+ #idx = torch.Tensor([32]).int().cuda()
+ A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
+ outliers1 = A[:, idx.long()]
- outliers2 = F.extract_outliers(CA, SA, idx)
+ CA, SA = F.transform(A, 'col_turing')
- assert outliers2.shape[0] == shapeA[0]
- assert outliers2.shape[1] == idx.numel()
+ outliers2 = F.extract_outliers(CA, SA, idx)
+ assert outliers2.shape[0] == shapeA[0]
+ assert outliers2.shape[1] == idx.numel()
+ #print(outliers1)
+ #print(outliers2)
- torch.testing.assert_allclose(outliers1, outliers2)
+ torch.testing.assert_allclose(outliers1, outliers2)