summaryrefslogtreecommitdiff
path: root/bitsandbytes
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes')
-rw-r--r--bitsandbytes/functional.py16
1 files changed, 11 insertions, 5 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 65eccf2..ff48b7f 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -143,14 +143,15 @@ def create_linear_map(signed=True, bits=8):
return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())
-def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
+def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
e = exponent_bits
p = precision_bits
- assert e+p == 7
+ has_sign = 1 if signed else 0
+ assert e+p == total_bits-has_sign
# the exponent is biased to 2^(e-1) -1 == 0
evalues = []
pvalues = []
- for i, val in enumerate(range(-((2**(exponent_bits-1))), 2**(exponent_bits-1), 1)):
+ for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)):
evalues.append(2**val)
@@ -161,12 +162,17 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
value += pval*(2**-(i+1))
pvalues.append(value)
- assert len(evalues)*len(pvalues) == 128
+ assert len(evalues)*len(pvalues) == 2**(total_bits-has_sign)
values = []
for ev in evalues:
for pv in pvalues:
- values.append(-ev*pv)
+ if signed:
+ values.append(-ev*pv)
values.append(ev*pv)
+ if total_bits < 8:
+ gap = 256 - len(values)
+ for i in range(gap):
+ values.append(0)
values.sort()
code = torch.Tensor(values)
code /= code.max()