summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-11-03 19:49:50 -0700
committerTim Dettmers <tim.dettmers@gmail.com>2022-11-03 19:49:50 -0700
commit1efb87d89d1c3fe532eb97847c3b48fd1a8e5d83 (patch)
treedd6b1ca29464d6c419b5c169f3d5ea946e7fce50 /tests
parent8d87c0b85214c07756b5dcdb09ceb26b0bb1cb7a (diff)
Added FP8 quantization map.
Diffstat (limited to 'tests')
-rw-r--r--tests/test_functional.py51
1 files changed, 51 insertions, 0 deletions
diff --git a/tests/test_functional.py b/tests/test_functional.py
index cf26714..329b270 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -2040,3 +2040,54 @@ def test_blockwise_cpu_large():
assert diffs[-1] < 0.011
# print(sum(diffs)/len(diffs))
# print(sum(reldiffs)/len(reldiffs))
+
+
+
+def test_fp8_quant():
+ for e_bits in range(1, 7):
+ p_bits = 7-e_bits
+ code = F.create_fp8_map(True, e_bits, p_bits).cuda()
+
+ print(e_bits, p_bits)
+ abserr = []
+ relerr = []
+ for i in range(100):
+ A1 = torch.randn(1024, 1024, device="cuda")
+ C, SC = F.quantize_blockwise(A1, code=code)
+ A2 = F.dequantize_blockwise(C, SC)
+ diff = torch.abs(A1 - A2)
+ reldiff = diff/torch.abs(A1+1e-8)
+ abserr.append(diff.mean().item())
+ relerr.append(reldiff.mean().item())
+ #assert diff < 0.0075
+ print(sum(abserr)/len(abserr))
+ print(sum(relerr)/len(relerr))
+
+ abserr = []
+ relerr = []
+ for i in range(100):
+ A1 = torch.rand(1024, 1024, device="cuda")
+ C, SC = F.quantize_blockwise(A1, code=code)
+ A2 = F.dequantize_blockwise(C, SC)
+ diff = torch.abs(A1 - A2)
+ reldiff = diff/torch.abs(A1+1e-8)
+ abserr.append(diff.mean().item())
+ relerr.append(reldiff.mean().item())
+ #assert diff < 0.0075
+ print(sum(abserr)/len(abserr))
+ print(sum(relerr)/len(relerr))
+
+ abserr = []
+ relerr = []
+ for i in range(100):
+ A1 = torch.randn(1024, 1024, device="cuda")
+ C, SC = F.quantize_blockwise(A1)
+ A2 = F.dequantize_blockwise(C, SC)
+ diff = torch.abs(A1 - A2)
+ reldiff = diff/torch.abs(A1+1e-8)
+ abserr.append(diff.mean().item())
+ relerr.append(reldiff.mean().item())
+ #assert diff < 0.0075
+ print(3, sum(abserr)/len(abserr))
+ print(3, sum(relerr)/len(relerr))
+