diff options
Diffstat (limited to 'bitsandbytes')
-rw-r--r-- | bitsandbytes/functional.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index c104ebd..d7e186f 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -6,6 +6,7 @@ import ctypes as ct import operator import random import torch +import itertools from typing import Tuple from torch import Tensor @@ -136,6 +137,39 @@ def create_linear_map(signed=True): return torch.linspace(0.0, 1.0, 256) +def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2): + e = exponent_bits + p = precision_bits + assert e+p == 7 + # the exponent is biased to 2^(e-1) -1 == 0 + evalues = [] + pvalues = [] + for i, val in enumerate(range(-((2**(exponent_bits-1))), 2**(exponent_bits-1), 1)): + evalues.append(2**val) + + + lst = list(itertools.product([0, 1], repeat=precision_bits)) + for bit_pattern in lst: + value = 1 + for i, pval in enumerate(list(bit_pattern)): + value += pval*(2**-(i+1)) + pvalues.append(value) + + assert len(evalues)*len(pvalues) == 128 + values = [] + for ev in evalues: + for pv in pvalues: + values.append(-ev*pv) + values.append(ev*pv) + values.sort() + code = torch.Tensor(values) + code /= code.max() + code[127] = 0 + + return code + + + def create_dynamic_map(signed=True, n=7): """ Creates the dynamic quantiztion map. |