summaryrefslogtreecommitdiff
path: root/bitsandbytes
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-11-19 07:24:03 -0800
committerTim Dettmers <tim.dettmers@gmail.com>2022-11-19 07:24:03 -0800
commiteb028e6ebcddc78c7921c2524d361b23b1a1007b (patch)
tree168ea8943ed732b02e6bce171cfa11f8d935b938 /bitsandbytes
parent08fa2e7b01dda8959a930295de9829516f8c77bc (diff)
Fixed k-bit quantization maps.
Diffstat (limited to 'bitsandbytes')
-rw-r--r--bitsandbytes/functional.py62
1 files changed, 46 insertions, 16 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index fffbecf..d9249b1 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -7,6 +7,7 @@ import operator
import random
import torch
import itertools
+import math
from typing import Tuple
from torch import Tensor
@@ -130,10 +131,17 @@ class Cusparse_Context(object):
return cls._instance
-def create_linear_map(signed=True, total_bits=8):
+def create_linear_map(signed=True, total_bits=8, add_zero=True):
sign = (-1.0 if signed else 0.0)
-
- values = torch.linspace(sign, 1.0, 2**total_bits)
+ total_values = 2**total_bits
+ if add_zero or total_bits < 8:
+ # add a zero
+ # since we simulate less bits by having zeros in the data type, we
+ # we need to center the quantization around zero and as such lose
+ # a single value
+ total_values = (2**total_bits if not signed else 2**total_bits-1)
+
+ values = torch.linspace(sign, 1.0, total_values)
gap = 256 - values.numel()
if gap == 0:
return values
@@ -155,20 +163,28 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
evalues.append(2**val)
- lst = list(itertools.product([0, 1], repeat=precision_bits))
- for bit_pattern in lst:
- value = 1
- for i, pval in enumerate(list(bit_pattern)):
- value += pval*(2**-(i+1))
- pvalues.append(value)
-
- assert len(evalues)*len(pvalues) == 2**(total_bits-has_sign)
values = []
- for ev in evalues:
- for pv in pvalues:
+ lst = list(itertools.product([0, 1], repeat=precision_bits))
+ #for ev in evalues:
+ bias = 2**(exponent_bits-1)-1
+ for evalue in range(2**(exponent_bits)):
+ for bit_pattern in lst:
+ value = (1 if evalue != 0 else 0)
+ for i, pval in enumerate(list(bit_pattern)):
+ value += pval*(2**-(i+1))
+ if evalue == 0:
+ # subnormals
+ value = value*2**-(bias-1)
+ else:
+ # normals
+ value = value*2**-(evalue-bias-2)
+ values.append(value)
if signed:
- values.append(-ev*pv)
- values.append(ev*pv)
+ values.append(-value)
+
+
+ assert len(values) == 2**total_bits
+ values.sort()
if total_bits < 8:
gap = 256 - len(values)
for i in range(gap):
@@ -176,7 +192,6 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
values.sort()
code = torch.Tensor(values)
code /= code.max()
- code[127] = 0
return code
@@ -232,6 +247,20 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
data.sort()
return Tensor(data)
+def create_quantile_map(A, total_bits=8):
+ q = estimate_quantiles(A, num_quantiles=2**total_bits-1)
+ q = q.tolist()
+ q.append(0)
+
+ gap = 256 - len(q)
+ for i in range(gap):
+ q.append(0)
+
+ q.sort()
+
+ q = Tensor(q)
+ q = q/q.abs().max()
+ return q
def get_special_format_str():
if not torch.cuda.is_available(): return 'col_turing'
@@ -422,6 +451,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
post_call(device)
if num_quantiles < 256:
+ step = round(256/num_quantiles)
idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
out = out[idx]