From 1ed2fa2f218d8dac401f3315420ffec92014c124 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 16 Aug 2022 10:56:17 -0700 Subject: Removed storage() from get_ptr; added boilerplate for bias dequant_mm. --- tests/test_functional.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) (limited to 'tests/test_functional.py') diff --git a/tests/test_functional.py b/tests/test_functional.py index ab7d672..65bf092 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -961,20 +961,24 @@ dim4 = torch.randint(64, 1024, size=(n,)).tolist() dims = (2,) # ldb = list(range(256, 1*1024, 256)) formatB = ["col_turing", "col_ampere"] -values = list(product(dim1, dim4, dims, formatB)) +has_bias = [True, False] +values = list(product(dim1, dim4, dims, formatB, has_bias)) names = [ - "dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}".format(*vals) for vals in values + "dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}_has_bias_{4}".format(*vals) for vals in values ] -@pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names) -def test_dequant_mm(dim1, dim4, dims, formatB): +@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names) +def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): inner = torch.randint(1, 128, size=(1,)).item() + bias = None + if has_bias: bias = torch.randn(dim4, device='cuda', dtype=torch.float16) formatB = F.get_special_format_str() for i in range(k): A = torch.randn(dim1, inner, device="cuda") B = torch.randn(dim4, inner, device="cuda") C1 = torch.matmul(A.half(), B.t().half()) + if has_bias: C1 += bias A1, maxA = F.vectorwise_quant(A, dim=1) B1, maxB = F.vectorwise_quant(B, dim=1) @@ -985,17 +989,15 @@ def test_dequant_mm(dim1, dim4, dims, formatB): C3, S = F.nvidia_transform(C2, "row", state=SC) C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t()) + if has_bias: C4 += bias count = (torch.isclose(C1, C4, atol=0.01, rtol=0.1) == 0).sum().item() n = C1.numel() p = 0.06 - assert ( - count / n < p - ), f"error in more than {p} of elements: {count}/{n}={count/n}" + assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}" - C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten()) + C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias) torch.testing.assert_allclose(C5, C4) - # print(C2) n = 2 -- cgit v1.2.3