From eb028e6ebcddc78c7921c2524d361b23b1a1007b Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Sat, 19 Nov 2022 07:24:03 -0800
Subject: Fixed k-bit quantization maps.

---
 bitsandbytes/functional.py | 62 ++++++++++++++++++++++++++++++++++------------
 1 file changed, 46 insertions(+), 16 deletions(-)

(limited to 'bitsandbytes/functional.py')

diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index fffbecf..d9249b1 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -7,6 +7,7 @@ import operator
 import random
 import torch
 import itertools
+import math
 
 from typing import Tuple
 from torch import Tensor
@@ -130,10 +131,17 @@ class Cusparse_Context(object):
         return cls._instance
 
 
-def create_linear_map(signed=True, total_bits=8):
+def create_linear_map(signed=True, total_bits=8, add_zero=True):
     sign = (-1.0 if signed else 0.0)
-
-    values = torch.linspace(sign, 1.0, 2**total_bits)
+    total_values = 2**total_bits
+    if add_zero or total_bits < 8:
+        # add a zero
+        # since we simulate fewer bits by having zeros in the data type,
+        # we need to center the quantization around zero and as such lose
+        # a single value
+        total_values = (2**total_bits if not signed else 2**total_bits-1)
+
+    values = torch.linspace(sign, 1.0, total_values)
     gap = 256 - values.numel()
     if gap == 0:
         return values
@@ -155,20 +163,28 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
             evalues.append(2**val)
 
 
-    lst = list(itertools.product([0, 1], repeat=precision_bits))
-    for bit_pattern in lst:
-        value = 1
-        for i, pval in enumerate(list(bit_pattern)):
-            value += pval*(2**-(i+1))
-        pvalues.append(value)
-
-    assert len(evalues)*len(pvalues) == 2**(total_bits-has_sign)
     values = []
-    for ev in evalues:
-        for pv in pvalues:
+    lst = list(itertools.product([0, 1], repeat=precision_bits))
+    #for ev in evalues:
+    bias = 2**(exponent_bits-1)-1
+    for evalue in range(2**(exponent_bits)):
+        for bit_pattern in lst:
+            value = (1 if evalue != 0 else 0)
+            for i, pval in enumerate(list(bit_pattern)):
+                value += pval*(2**-(i+1))
+            if evalue == 0:
+                # subnormals
+                value = value*2**-(bias-1)
+            else:
+                # normals
+                value = value*2**-(evalue-bias-2)
+            values.append(value)
             if signed:
-                values.append(-ev*pv)
-            values.append(ev*pv)
+                values.append(-value)
+
+
+    assert len(values) == 2**total_bits
+    values.sort()
     if total_bits < 8:
         gap = 256 - len(values)
         for i in range(gap):
@@ -176,7 +192,6 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
     values.sort()
     code = torch.Tensor(values)
     code /= code.max()
-    code[127] = 0
 
     return code
 
@@ -232,6 +247,20 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
     data.sort()
     return Tensor(data)
 
+def create_quantile_map(A, total_bits=8):
+    q = estimate_quantiles(A, num_quantiles=2**total_bits-1)
+    q = q.tolist()
+    q.append(0)
+
+    gap = 256 - len(q)
+    for i in range(gap):
+        q.append(0)
+
+    q.sort()
+
+    q = Tensor(q)
+    q = q/q.abs().max()
+    return q
 
 def get_special_format_str():
     if not torch.cuda.is_available(): return 'col_turing'
@@ -422,6 +451,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
     post_call(device)
 
     if num_quantiles < 256:
+        step = round(256/num_quantiles)
        idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
        out = out[idx]
 
--
cgit v1.2.3
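
For reference, a minimal standalone sketch of the decode step the patched
create_fp8_map now performs: every (exponent, mantissa) bit pattern is
expanded to a float, with exponent field 0 treated as subnormal (implicit
leading 0) and all other exponent values as normal (implicit leading 1).
This assumes an e4m3-style layout (signed, exponent_bits=4, precision_bits=3);
the function name fp8_values is hypothetical and mirrors the patch's loop,
it is not part of the library.

import itertools

def fp8_values(signed=True, exponent_bits=4, precision_bits=3):
    bias = 2**(exponent_bits - 1) - 1
    values = []
    for evalue in range(2**exponent_bits):
        for bits in itertools.product([0, 1], repeat=precision_bits):
            # implicit leading 1 for normals, 0 for subnormals
            value = 1.0 if evalue != 0 else 0.0
            for i, b in enumerate(bits):
                value += b * 2**-(i + 1)
            if evalue == 0:
                value *= 2**-(bias - 1)           # subnormal scale
            else:
                value *= 2**-(evalue - bias - 2)  # normal scale, as in the patch
            values.append(value)
            if signed:
                values.append(-value)
    return sorted(values)

vals = fp8_values()
assert len(vals) == 256  # 2**(1 sign + 4 exponent + 3 mantissa bits)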
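
These map functions all return a 256-entry codebook normalized to [-1, 1].
A hedged usage sketch of a quantization round trip with the patched
create_fp8_map, using a pure-PyTorch nearest-neighbor lookup as a stand-in
for the library's CUDA quantization kernels (the lookup is illustrative
only, not the bitsandbytes API):

import torch
from bitsandbytes.functional import create_fp8_map

code = create_fp8_map(signed=True, exponent_bits=4, precision_bits=3, total_bits=8)

x = torch.randn(1024)
absmax = x.abs().max()

# quantize: scale into [-1, 1], then snap each element to the nearest
# codebook entry; the entry's index is the 8-bit representation
idx = (x / absmax).unsqueeze(-1).sub(code).abs().argmin(dim=-1)

# dequantize: look the values back up and undo the scaling
x_hat = code[idx] * absmax
print((x - x_hat).abs().mean())  # mean absolute quantization error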