summaryrefslogtreecommitdiff
path: root/tests/test_functional.py
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-07-25 17:27:57 -0700
committerTim Dettmers <tim.dettmers@gmail.com>2022-07-25 17:27:57 -0700
commit1e88edd8c096bde5202dd61411d3c8d7eda56645 (patch)
tree84e514d4538d113fe6f78985808d26cc5c677b62 /tests/test_functional.py
parent8b1fd32e3e4f5073fd055cb5f9261ec585f8cc2c (diff)
Removed rowscale (segfaults on ampere).
Diffstat (limited to 'tests/test_functional.py')
-rw-r--r--tests/test_functional.py2
1 files changed, 2 insertions, 0 deletions
diff --git a/tests/test_functional.py b/tests/test_functional.py
index dcb0255..d80a4f9 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -992,6 +992,7 @@ inner = torch.randint(1,4*1024, size=(n,)).tolist()
values = list(zip(dim1, dim4, inner))
names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
+@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
formatB = F.get_special_format_str()
err1, err2, err3 = [], [], []
@@ -1064,6 +1065,7 @@ dim4 = [12288, 4096]
values = list(zip(dim1, dim4, inner))
names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
+@pytest.mark.skip("Row scale has some bugs for ampere")
def test_row_scale_bench(dim1, dim4, inner):
err1, err2, err3 = [], [], []
relerr1, relerr2 = [], []