author    Tim Dettmers <tim.dettmers@gmail.com>    2022-11-03 19:49:50 -0700
committer Tim Dettmers <tim.dettmers@gmail.com>    2022-11-03 19:49:50 -0700
commit    1efb87d89d1c3fe532eb97847c3b48fd1a8e5d83 (patch)
tree      dd6b1ca29464d6c419b5c169f3d5ea946e7fce50
parent    8d87c0b85214c07756b5dcdb09ceb26b0bb1cb7a (diff)
Added FP8 quantization map.
-rw-r--r--  bitsandbytes/functional.py  | 34
-rw-r--r--  tests/test_functional.py    | 51
2 files changed, 85 insertions(+), 0 deletions(-)
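The patch adds create_fp8_map, which enumerates all values representable in a signed FP8 format with e exponent bits and p mantissa ("precision") bits, e + p == 7, and normalizes the resulting 256 values to [-1, 1] so they can serve as a code for the existing blockwise quantization kernels. As a worked illustration of what a single map entry encodes (the e=4, p=3 split and the bit values below are arbitrary examples, not taken from the patch):

    # one FP8 magnitude decodes as 2^exponent * 1.mantissa
    e, p = 4, 3                                # any split with e + p == 7
    exponent = -2                              # drawn from [-2^(e-1), 2^(e-1) - 1]
    mantissa_bits = (1, 0, 1)                  # the p stored mantissa bits
    mantissa = 1 + sum(b * 2 ** -(i + 1) for i, b in enumerate(mantissa_bits))
    print(2 ** exponent * mantissa)            # 0.25 * 1.625 == 0.40625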
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index c104ebd..d7e186f 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -6,6 +6,7 @@ import ctypes as ct
import operator
import random
import torch
+import itertools
from typing import Tuple
from torch import Tensor
@@ -136,6 +137,39 @@ def create_linear_map(signed=True):
    return torch.linspace(0.0, 1.0, 256)
+
+
+def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
+    e = exponent_bits
+    p = precision_bits
+    assert e + p == 7
+    # the exponent is biased by 2^(e-1): it ranges over [-2^(e-1), 2^(e-1) - 1]
+    evalues = []
+    pvalues = []
+    for val in range(-(2 ** (e - 1)), 2 ** (e - 1)):
+        evalues.append(2 ** val)
+
+    # every p-bit mantissa pattern yields a value 1.b1b2...bp in [1, 2)
+    for bit_pattern in itertools.product([0, 1], repeat=p):
+        value = 1
+        for i, pval in enumerate(bit_pattern):
+            value += pval * (2 ** -(i + 1))
+        pvalues.append(value)
+
+    # 2^e exponents times 2^p mantissas == 2^7 == 128 magnitudes
+    assert len(evalues) * len(pvalues) == 128
+    values = []
+    for ev in evalues:
+        for pv in pvalues:
+            values.append(-ev * pv)
+            values.append(ev * pv)
+    values.sort()
+    code = torch.Tensor(values)
+    code /= code.max()  # normalize the 256-entry map to [-1, 1]
+    code[127] = 0  # overwrite the negative value closest to zero with an exact zero
+    return code
+
+
def create_dynamic_map(signed=True, n=7):
"""
Creates the dynamic quantiztion map.
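For the default split (exponent_bits=5, precision_bits=2) the map crosses 32 exponent magnitudes 2^-16 ... 2^15 with the 4 mantissas {1.0, 1.25, 1.5, 1.75}, mirrored for sign. A quick sanity check of the returned tensor (a sketch, assuming a working bitsandbytes install):

    import bitsandbytes.functional as F

    code = F.create_fp8_map(True, exponent_bits=5, precision_bits=2)
    assert code.numel() == 256                 # 128 magnitudes, both signs
    assert code.max() == 1.0 and code.min() == -1.0
    assert code[127] == 0                      # exact zero at the map's center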
diff --git a/tests/test_functional.py b/tests/test_functional.py
index cf26714..329b270 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -2040,3 +2040,54 @@ def test_blockwise_cpu_large():
    assert diffs[-1] < 0.011
    # print(sum(diffs)/len(diffs))
    # print(sum(reldiffs)/len(reldiffs))
+
+
+def test_fp8_quant():
+    for e_bits in range(1, 7):
+        p_bits = 7 - e_bits
+        code = F.create_fp8_map(True, e_bits, p_bits).cuda()
+
+        print(e_bits, p_bits)
+        # Gaussian inputs quantized with the FP8 map
+        abserr = []
+        relerr = []
+        for i in range(100):
+            A1 = torch.randn(1024, 1024, device="cuda")
+            C, SC = F.quantize_blockwise(A1, code=code)
+            A2 = F.dequantize_blockwise(C, SC)
+            diff = torch.abs(A1 - A2)
+            reldiff = diff / (torch.abs(A1) + 1e-8)
+            abserr.append(diff.mean().item())
+            relerr.append(reldiff.mean().item())
+            # assert diff < 0.0075
+        print(sum(abserr) / len(abserr))
+        print(sum(relerr) / len(relerr))
+
+        # uniform inputs quantized with the FP8 map
+        abserr = []
+        relerr = []
+        for i in range(100):
+            A1 = torch.rand(1024, 1024, device="cuda")
+            C, SC = F.quantize_blockwise(A1, code=code)
+            A2 = F.dequantize_blockwise(C, SC)
+            diff = torch.abs(A1 - A2)
+            reldiff = diff / (torch.abs(A1) + 1e-8)
+            abserr.append(diff.mean().item())
+            relerr.append(reldiff.mean().item())
+            # assert diff < 0.0075
+        print(sum(abserr) / len(abserr))
+        print(sum(relerr) / len(relerr))
+
+        # baseline: Gaussian inputs with the default dynamic map
+        abserr = []
+        relerr = []
+        for i in range(100):
+            A1 = torch.randn(1024, 1024, device="cuda")
+            C, SC = F.quantize_blockwise(A1)
+            A2 = F.dequantize_blockwise(C, SC)
+            diff = torch.abs(A1 - A2)
+            reldiff = diff / (torch.abs(A1) + 1e-8)
+            abserr.append(diff.mean().item())
+            relerr.append(reldiff.mean().item())
+            # assert diff < 0.0075
+        print(3, sum(abserr) / len(abserr))
+        print(3, sum(relerr) / len(relerr))
+
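A minimal end-to-end use of the new map together with the blockwise kernels, mirroring the test above (assumes a CUDA device; the e=4, p=3 split is an arbitrary choice):

    import torch
    import bitsandbytes.functional as F

    code = F.create_fp8_map(True, exponent_bits=4, precision_bits=3).cuda()
    A = torch.randn(1024, 1024, device="cuda")
    C, SC = F.quantize_blockwise(A, code=code)   # 8-bit codes plus quantization state
    A2 = F.dequantize_blockwise(C, SC)           # approximate reconstruction
    print((A - A2).abs().mean().item())          # mean absolute quantization error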