diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-11-03 19:49:50 -0700 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-11-03 19:49:50 -0700 |
commit | 1efb87d89d1c3fe532eb97847c3b48fd1a8e5d83 (patch) | |
tree | dd6b1ca29464d6c419b5c169f3d5ea946e7fce50 /tests | |
parent | 8d87c0b85214c07756b5dcdb09ceb26b0bb1cb7a (diff) |
Added FP8 quantization map.
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_functional.py | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/tests/test_functional.py b/tests/test_functional.py index cf26714..329b270 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2040,3 +2040,54 @@ def test_blockwise_cpu_large(): assert diffs[-1] < 0.011 # print(sum(diffs)/len(diffs)) # print(sum(reldiffs)/len(reldiffs)) + + + +def test_fp8_quant(): + for e_bits in range(1, 7): + p_bits = 7-e_bits + code = F.create_fp8_map(True, e_bits, p_bits).cuda() + + print(e_bits, p_bits) + abserr = [] + relerr = [] + for i in range(100): + A1 = torch.randn(1024, 1024, device="cuda") + C, SC = F.quantize_blockwise(A1, code=code) + A2 = F.dequantize_blockwise(C, SC) + diff = torch.abs(A1 - A2) + reldiff = diff/torch.abs(A1+1e-8) + abserr.append(diff.mean().item()) + relerr.append(reldiff.mean().item()) + #assert diff < 0.0075 + print(sum(abserr)/len(abserr)) + print(sum(relerr)/len(relerr)) + + abserr = [] + relerr = [] + for i in range(100): + A1 = torch.rand(1024, 1024, device="cuda") + C, SC = F.quantize_blockwise(A1, code=code) + A2 = F.dequantize_blockwise(C, SC) + diff = torch.abs(A1 - A2) + reldiff = diff/torch.abs(A1+1e-8) + abserr.append(diff.mean().item()) + relerr.append(reldiff.mean().item()) + #assert diff < 0.0075 + print(sum(abserr)/len(abserr)) + print(sum(relerr)/len(relerr)) + + abserr = [] + relerr = [] + for i in range(100): + A1 = torch.randn(1024, 1024, device="cuda") + C, SC = F.quantize_blockwise(A1) + A2 = F.dequantize_blockwise(C, SC) + diff = torch.abs(A1 - A2) + reldiff = diff/torch.abs(A1+1e-8) + abserr.append(diff.mean().item()) + relerr.append(reldiff.mean().item()) + #assert diff < 0.0075 + print(3, sum(abserr)/len(abserr)) + print(3, sum(relerr)/len(relerr)) + |