diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-11-06 11:47:54 -0800 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-11-06 11:47:54 -0800 |
commit | caf1832526e4ad54ae8fe8e947f19ed690f35a40 (patch) | |
tree | 29603912cf04146b75d868b6b905a0f857b7bf62 /bitsandbytes | |
parent | 1efb87d89d1c3fe532eb97847c3b48fd1a8e5d83 (diff) |
Added k-bit linear quantization.
Diffstat (limited to 'bitsandbytes')
-rw-r--r-- | bitsandbytes/functional.py | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index d7e186f..65eccf2 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -130,11 +130,17 @@ class Cusparse_Context(object): return cls._instance -def create_linear_map(signed=True): - if signed: - return torch.linspace(-1.0, 1.0, 256) +def create_linear_map(signed=True, bits=8): + sign = (-1.0 if signed else 0.0) + + values = torch.linspace(sign, 1.0, 2**bits) + gap = 256 - values.numel() + if gap == 0: + return values else: - return torch.linspace(0.0, 1.0, 256) + l = values.numel()//2 + #return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist()) + return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist()) def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2): |