summaryrefslogtreecommitdiff
path: root/tests/test_functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_functional.py')
-rw-r--r--tests/test_functional.py14
1 files changed, 10 insertions, 4 deletions
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 4d06447..2d58fac 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1859,9 +1859,9 @@ def test_zp():
def test_extract_outliers():
for i in range(k):
- shapeA = (4096, 4*4096)
+ shapeA = (4096, 4096*4)
idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda()
- #idx = torch.Tensor([32]).int().cuda()
+ #idx = torch.Tensor([0]).int().cuda()
A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
outliers1 = A[:, idx.long()]
@@ -1872,7 +1872,13 @@ def test_extract_outliers():
assert outliers2.shape[0] == shapeA[0]
assert outliers2.shape[1] == idx.numel()
- #print(outliers1)
- #print(outliers2)
+ torch.testing.assert_allclose(outliers1, outliers2)
+
+ CA, SA = F.transform(A, 'col_ampere')
+
+ outliers2 = F.extract_outliers(CA, SA, idx)
+
+ assert outliers2.shape[0] == shapeA[0]
+ assert outliers2.shape[1] == idx.numel()
torch.testing.assert_allclose(outliers1, outliers2)