summaryrefslogtreecommitdiff
path: root/bitsandbytes
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes')
-rw-r--r--bitsandbytes/functional.py34
1 files changed, 34 insertions, 0 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index c104ebd..d7e186f 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -6,6 +6,7 @@ import ctypes as ct
import operator
import random
import torch
+import itertools
from typing import Tuple
from torch import Tensor
@@ -136,6 +137,39 @@ def create_linear_map(signed=True):
return torch.linspace(0.0, 1.0, 256)
+def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
+ e = exponent_bits
+ p = precision_bits
+ assert e+p == 7
+ # the exponent is biased to 2^(e-1) -1 == 0
+ evalues = []
+ pvalues = []
+ for i, val in enumerate(range(-((2**(exponent_bits-1))), 2**(exponent_bits-1), 1)):
+ evalues.append(2**val)
+
+
+ lst = list(itertools.product([0, 1], repeat=precision_bits))
+ for bit_pattern in lst:
+ value = 1
+ for i, pval in enumerate(list(bit_pattern)):
+ value += pval*(2**-(i+1))
+ pvalues.append(value)
+
+ assert len(evalues)*len(pvalues) == 128
+ values = []
+ for ev in evalues:
+ for pv in pvalues:
+ values.append(-ev*pv)
+ values.append(ev*pv)
+ values.sort()
+ code = torch.Tensor(values)
+ code /= code.max()
+ code[127] = 0
+
+ return code
+
+
+
def create_dynamic_map(signed=True, n=7):
"""
Creates the dynamic quantiztion map.