author    Tim Dettmers <tim.dettmers@gmail.com>  2022-08-16 10:56:17 -0700
committer Tim Dettmers <tim.dettmers@gmail.com>  2022-08-16 10:56:17 -0700
commit    1ed2fa2f218d8dac401f3315420ffec92014c124 (patch)
tree      57863d4d1024689100c1b43caccc1d8739c58d99 /tests/test_functional.py
parent    26efb154c8d77b4ede2cfc0dbd2381dd385f33e7 (diff)
Removed storage() from get_ptr; added boilerplate for bias dequant_mm.
Diffstat (limited to 'tests/test_functional.py')
-rw-r--r--  tests/test_functional.py  20
1 file changed, 11 insertions(+), 9 deletions(-)
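
Note: the get_ptr change named in the commit message lands in bitsandbytes/functional.py and is not part of the diff below, which covers only the test file. A minimal sketch of what dropping storage() from such a pointer helper plausibly looks like (an assumption for illustration, not the commit's exact code) is:

    import ctypes as ct
    from typing import Optional
    import torch

    def get_ptr(A: Optional[torch.Tensor]) -> Optional[ct.c_void_p]:
        # ctypes pointer to the tensor's device memory, for handing to the CUDA library.
        if A is None:
            return None
        # data_ptr() already points at this tensor's first element, so the extra
        # .storage() indirection (storage().data_ptr() ignores the storage offset)
        # is unnecessary here.
        return ct.c_void_p(A.data.data_ptr())
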
diff --git a/tests/test_functional.py b/tests/test_functional.py
index ab7d672..65bf092 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -961,20 +961,24 @@ dim4 = torch.randint(64, 1024, size=(n,)).tolist()
 dims = (2,)
 # ldb = list(range(256, 1*1024, 256))
 formatB = ["col_turing", "col_ampere"]
-values = list(product(dim1, dim4, dims, formatB))
+has_bias = [True, False]
+values = list(product(dim1, dim4, dims, formatB, has_bias))
 names = [
-    "dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}".format(*vals) for vals in values
+    "dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}_has_bias_{4}".format(*vals) for vals in values
 ]
-@pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names)
-def test_dequant_mm(dim1, dim4, dims, formatB):
+@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names)
+def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
     inner = torch.randint(1, 128, size=(1,)).item()
+    bias = None
+    if has_bias: bias = torch.randn(dim4, device='cuda', dtype=torch.float16)
     formatB = F.get_special_format_str()
     for i in range(k):
         A = torch.randn(dim1, inner, device="cuda")
         B = torch.randn(dim4, inner, device="cuda")
         C1 = torch.matmul(A.half(), B.t().half())
+        if has_bias: C1 += bias
         A1, maxA = F.vectorwise_quant(A, dim=1)
         B1, maxB = F.vectorwise_quant(B, dim=1)
@@ -985,17 +989,15 @@ def test_dequant_mm(dim1, dim4, dims, formatB):
         C3, S = F.nvidia_transform(C2, "row", state=SC)
         C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
+        if has_bias: C4 += bias
         count = (torch.isclose(C1, C4, atol=0.01, rtol=0.1) == 0).sum().item()
         n = C1.numel()
         p = 0.06
-        assert (
-            count / n < p
-        ), f"error in more than {p} of elements: {count}/{n}={count/n}"
+        assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"
-        C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten())
+        C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias)
         torch.testing.assert_allclose(C5, C4)
-        # print(C2)
 n = 2
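
The new has_bias parametrization pins down one invariant: passing bias= to F.mm_dequant must give the same result as dequantizing first and adding the bias afterwards (the C4 += bias reference path above). A small pure-PyTorch sketch of that reference computation, assuming the usual 127-level vectorwise scaling and hypothetical shapes (a conceptual model, not the library's CUDA implementation), is:

    import torch

    def dequant_ref(C_int32, maxA, maxB, bias=None):
        # C_int32: (dim1, dim4) int32 accumulator from the int8 matmul
        # maxA: (dim1, 1) per-row absmax of A;  maxB: (dim4, 1) per-row absmax of B
        # Undo the two vectorwise quantization scales (x -> round(x * 127 / absmax)).
        C = C_int32.float() * (maxA * maxB.t()) / (127.0 * 127.0)
        if bias is not None:  # fused-bias path exercised by the test above
            C = C.half() + bias
        return C

    # Equivalence the test asserts: fused bias == dequantize, then add bias.
    dim1, dim4 = 32, 64  # hypothetical small shapes
    C_int32 = torch.randint(-2**20, 2**20, (dim1, dim4), dtype=torch.int32)
    maxA = torch.rand(dim1, 1) + 0.1
    maxB = torch.rand(dim4, 1) + 0.1
    bias = torch.randn(dim4, dtype=torch.float16)

    fused = dequant_ref(C_int32, maxA, maxB, bias=bias)
    split = dequant_ref(C_int32, maxA, maxB).half() + bias
    torch.testing.assert_allclose(fused, split)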