author     Tim Dettmers <tim.dettmers@gmail.com>  2022-11-19 07:24:03 -0800
committer  Tim Dettmers <tim.dettmers@gmail.com>  2022-11-19 07:24:03 -0800
commit     eb028e6ebcddc78c7921c2524d361b23b1a1007b (patch)
tree       168ea8943ed732b02e6bce171cfa11f8d935b938 /bitsandbytes
parent     08fa2e7b01dda8959a930295de9829516f8c77bc (diff)
Fixed k-bit quantization maps.
Diffstat (limited to 'bitsandbytes')
 -rw-r--r--  bitsandbytes/functional.py | 62
 1 file changed, 46 insertions(+), 16 deletions(-)
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index fffbecf..d9249b1 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -7,6 +7,7 @@ import operator
 import random
 import torch
 import itertools
+import math
 
 from typing import Tuple
 from torch import Tensor
@@ -130,10 +131,17 @@ class Cusparse_Context(object):
         return cls._instance
 
-def create_linear_map(signed=True, total_bits=8):
+def create_linear_map(signed=True, total_bits=8, add_zero=True):
     sign = (-1.0 if signed else 0.0)
-
-    values = torch.linspace(sign, 1.0, 2**total_bits)
+    total_values = 2**total_bits
+    if add_zero or total_bits < 8:
+        # add a zero
+        # since we simulate fewer bits by having zeros in the data type, we
+        # need to center the quantization around zero and as such lose
+        # a single value
+        total_values = (2**total_bits if not signed else 2**total_bits-1)
+
+    values = torch.linspace(sign, 1.0, total_values)
     gap = 256 - values.numel()
     if gap == 0:
         return values
@@ -155,20 +163,28 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
             evalues.append(2**val)
 
-    lst = list(itertools.product([0, 1], repeat=precision_bits))
-    for bit_pattern in lst:
-        value = 1
-        for i, pval in enumerate(list(bit_pattern)):
-            value += pval*(2**-(i+1))
-        pvalues.append(value)
-
-    assert len(evalues)*len(pvalues) == 2**(total_bits-has_sign)
     values = []
-    for ev in evalues:
-        for pv in pvalues:
+    lst = list(itertools.product([0, 1], repeat=precision_bits))
+    #for ev in evalues:
+    bias = 2**(exponent_bits-1)-1
+    for evalue in range(2**(exponent_bits)):
+        for bit_pattern in lst:
+            value = (1 if evalue != 0 else 0)
+            for i, pval in enumerate(list(bit_pattern)):
+                value += pval*(2**-(i+1))
+            if evalue == 0:
+                # subnormals
+                value = value*2**-(bias-1)
+            else:
+                # normals
+                value = value*2**-(evalue-bias-2)
+            values.append(value)
             if signed:
-                values.append(-ev*pv)
-                values.append(ev*pv)
+                values.append(-value)
+
+
+    assert len(values) == 2**total_bits
+    values.sort()
     if total_bits < 8:
         gap = 256 - len(values)
         for i in range(gap):
@@ -176,7 +192,6 @@
     values.sort()
     code = torch.Tensor(values)
     code /= code.max()
-    code[127] = 0
 
     return code
@@ -232,6 +247,20 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
     data.sort()
     return Tensor(data)
 
+def create_quantile_map(A, total_bits=8):
+    q = estimate_quantiles(A, num_quantiles=2**total_bits-1)
+    q = q.tolist()
+    q.append(0)
+
+    gap = 256 - len(q)
+    for i in range(gap):
+        q.append(0)
+
+    q.sort()
+
+    q = Tensor(q)
+    q = q/q.abs().max()
+    return q
 
 def get_special_format_str():
     if not torch.cuda.is_available(): return 'col_turing'
@@ -422,6 +451,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
     post_call(device)
 
     if num_quantiles < 256:
+        step = round(256/num_quantiles)
        idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
        out = out[idx]
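
For readers tracing the rewritten create_fp8_map loop: each code value is decoded from an (exponent, mantissa-bit-pattern) pair, with evalue == 0 treated as a subnormal (no implicit leading 1). A minimal standalone sketch of that per-value decode, following the diff's formulas exactly; the helper name fp8_decode is ours, not library API:

import itertools

def fp8_decode(evalue, mantissa_bits, exponent_bits):
    # Same arithmetic as the inner loop of create_fp8_map above.
    bias = 2**(exponent_bits - 1) - 1
    value = 1.0 if evalue != 0 else 0.0        # implicit leading bit; absent for subnormals
    for i, b in enumerate(mantissa_bits):      # fractional mantissa contribution
        value += b * 2**-(i + 1)
    if evalue == 0:
        return value * 2**-(bias - 1)          # subnormal scale
    return value * 2**-(evalue - bias - 2)     # normal scale (note the -2, as in the diff)

# Enumerating all pairs plus signs reproduces the diff's assert:
# 2**exponent_bits * 2**precision_bits * 2 == 2**total_bits for a signed map.
values = []
for evalue in range(2**5):                               # exponent_bits = 5
    for bits in itertools.product([0, 1], repeat=2):     # precision_bits = 2
        v = fp8_decode(evalue, bits, 5)
        values.extend([v, -v])
assert len(values) == 256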
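The finished maps are 256-entry codebooks normalized to [-1, 1]. As a rough illustration of how such a code is consumed, here is a pure-PyTorch round-to-nearest sketch; the library itself quantizes with CUDA kernels, and quantize_with_map/dequantize_with_map are made-up names for illustration:

import torch

def quantize_with_map(A, code):
    # Round-to-nearest against the 256-entry codebook. Uses an
    # O(n * 256) distance matrix: fine for a sketch, not how the
    # CUDA kernels do it.
    absmax = A.abs().max()
    scaled = (A / absmax).clamp(-1.0, 1.0)
    idx = (scaled.reshape(-1, 1) - code.reshape(1, -1)).abs().argmin(dim=1)
    return idx.reshape(A.shape).to(torch.uint8), absmax

def dequantize_with_map(idx, absmax, code):
    # Look up each stored index in the codebook and undo the scaling.
    return code[idx.long()] * absmax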
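A note on the k-bit simulation running through this commit: maps for total_bits < 8 are padded with zeros up to 256 entries (explicit in create_quantile_map, and the motivation for the new add_zero comment in create_linear_map), so a k-bit data type rides inside the 8-bit quantization machinery with the grid centered on zero. Unlike the linear, dynamic, and FP8 maps, create_quantile_map builds its codebook from estimate_quantiles(A, num_quantiles=2**total_bits-1), so it adapts to the input distribution.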