From 98cbc4bc4f15f5c094cd8575ddb0380a19516099 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 6 Nov 2022 11:59:37 -0800 Subject: Added k-bit fp8 map. --- bitsandbytes/functional.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) (limited to 'bitsandbytes') diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 65eccf2..ff48b7f 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -143,14 +143,15 @@ def create_linear_map(signed=True, bits=8): return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist()) -def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2): +def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8): e = exponent_bits p = precision_bits - assert e+p == 7 + has_sign = 1 if signed else 0 + assert e+p == total_bits-has_sign # the exponent is biased to 2^(e-1) -1 == 0 evalues = [] pvalues = [] - for i, val in enumerate(range(-((2**(exponent_bits-1))), 2**(exponent_bits-1), 1)): + for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)): evalues.append(2**val) @@ -161,12 +162,17 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2): value += pval*(2**-(i+1)) pvalues.append(value) - assert len(evalues)*len(pvalues) == 128 + assert len(evalues)*len(pvalues) == 2**(total_bits-has_sign) values = [] for ev in evalues: for pv in pvalues: - values.append(-ev*pv) + if signed: + values.append(-ev*pv) values.append(ev*pv) + if total_bits < 8: + gap = 256 - len(values) + for i in range(gap): + values.append(0) values.sort() code = torch.Tensor(values) code /= code.max() -- cgit v1.2.3