summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bitsandbytes/functional.py14
-rw-r--r--tests/test_functional.py50
2 files changed, 60 insertions, 4 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index d7e186f..65eccf2 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -130,11 +130,17 @@ class Cusparse_Context(object):
return cls._instance
-def create_linear_map(signed=True):
- if signed:
- return torch.linspace(-1.0, 1.0, 256)
+def create_linear_map(signed=True, bits=8):
+ sign = (-1.0 if signed else 0.0)
+
+ values = torch.linspace(sign, 1.0, 2**bits)
+ gap = 256 - values.numel()
+ if gap == 0:
+ return values
else:
- return torch.linspace(0.0, 1.0, 256)
+ l = values.numel()//2
+ #return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist())
+ return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())
def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 329b270..494bf51 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -2091,3 +2091,53 @@ def test_fp8_quant():
print(3, sum(abserr)/len(abserr))
print(3, sum(relerr)/len(relerr))
+
+def test_few_bit_quant():
+
+ for bits in range(2, 9):
+ code = F.create_linear_map(True, bits=bits).cuda()
+ assert code.numel() == 256
+ print(bits)
+ for i in range(100):
+
+ values = torch.randn(1, 24, device='cuda')
+ values /= values.abs().max()
+ #values[values.abs() < 1e-6] += 1e-5
+
+ q1 = []
+ v1 = []
+ for v in values[0]:
+ idx = torch.abs(v-code).argmin()
+ q1.append(idx.item())
+ v1.append(code[idx].item())
+
+ q1 = torch.Tensor(q1).cuda()
+ v1 = torch.Tensor(v1).cuda()
+
+ q2, S2 = F.quantize(values, code=code)
+ v2 = F.dequantize(q2, S2)
+
+ idx = torch.isclose(q1.int(), q2.int())
+ if idx.sum():
+ # some weird cases
+ err1 = torch.abs(v1-values).mean()
+ err2 = torch.abs(v2-values).mean()
+ assert err2 <= err1
+
+ else:
+ torch.testing.assert_allclose(q1, q2)
+
+ #print(e_bits, p_bits)
+ #abserr = []
+ #relerr = []
+ #for i in range(100):
+ # A1 = torch.randn(1024, 1024, device="cuda")
+ # C, SC = F.quantize_blockwise(A1, code=code)
+ # A2 = F.dequantize_blockwise(C, SC)
+ # diff = torch.abs(A1 - A2)
+ # reldiff = diff/torch.abs(A1+1e-8)
+ # abserr.append(diff.mean().item())
+ # relerr.append(reldiff.mean().item())
+ # #assert diff < 0.0075
+ #print(sum(abserr)/len(abserr))
+ #print(sum(relerr)/len(relerr))