summary refs log tree commit diff
diff options
context:
space:
mode:
author Tim Dettmers <tim.dettmers@gmail.com> 2022-11-06 11:59:37 -0800
committer Tim Dettmers <tim.dettmers@gmail.com> 2022-11-06 11:59:37 -0800
commit 98cbc4bc4f15f5c094cd8575ddb0380a19516099 (patch)
tree 11c11ae987ce81241f4a23d8c841a546c11d556f
parent caf1832526e4ad54ae8fe8e947f19ed690f35a40 (diff)
Added k-bit fp8 map.
-rw-r--r-- bitsandbytes/functional.py 16
-rw-r--r-- tests/test_functional.py 88
2 files changed, 52 insertions, 52 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 65eccf2..ff48b7f 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -143,14 +143,15 @@ def create_linear_map(signed=True, bits=8):
return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())
-def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
+def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
e = exponent_bits
p = precision_bits
- assert e+p == 7
+ has_sign = 1 if signed else 0
+ assert e+p == total_bits-has_sign
# the exponent is biased to 2^(e-1) -1 == 0
evalues = []
pvalues = []
- for i, val in enumerate(range(-((2**(exponent_bits-1))), 2**(exponent_bits-1), 1)):
+ for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)):
evalues.append(2**val)
@@ -161,12 +162,17 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
value += pval*(2**-(i+1))
pvalues.append(value)
- assert len(evalues)*len(pvalues) == 128
+ assert len(evalues)*len(pvalues) == 2**(total_bits-has_sign)
values = []
for ev in evalues:
for pv in pvalues:
- values.append(-ev*pv)
+ if signed:
+ values.append(-ev*pv)
values.append(ev*pv)
+ if total_bits < 8:
+ gap = 256 - len(values)
+ for i in range(gap):
+ values.append(0)
values.sort()
code = torch.Tensor(values)
code /= code.max()
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 494bf51..bd4dafe 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -11,7 +11,7 @@ import bitsandbytes as bnb
from bitsandbytes import functional as F
torch.set_printoptions(
- precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
+ precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
)
k = 20
@@ -2095,49 +2095,43 @@ def test_fp8_quant():
def test_few_bit_quant():
for bits in range(2, 9):
- code = F.create_linear_map(True, bits=bits).cuda()
- assert code.numel() == 256
- print(bits)
- for i in range(100):
-
- values = torch.randn(1, 24, device='cuda')
- values /= values.abs().max()
- #values[values.abs() < 1e-6] += 1e-5
-
- q1 = []
- v1 = []
- for v in values[0]:
- idx = torch.abs(v-code).argmin()
- q1.append(idx.item())
- v1.append(code[idx].item())
-
- q1 = torch.Tensor(q1).cuda()
- v1 = torch.Tensor(v1).cuda()
-
- q2, S2 = F.quantize(values, code=code)
- v2 = F.dequantize(q2, S2)
-
- idx = torch.isclose(q1.int(), q2.int())
- if idx.sum():
- # some weird cases
- err1 = torch.abs(v1-values).mean()
- err2 = torch.abs(v2-values).mean()
- assert err2 <= err1
-
- else:
- torch.testing.assert_allclose(q1, q2)
-
- #print(e_bits, p_bits)
- #abserr = []
- #relerr = []
- #for i in range(100):
- # A1 = torch.randn(1024, 1024, device="cuda")
- # C, SC = F.quantize_blockwise(A1, code=code)
- # A2 = F.dequantize_blockwise(C, SC)
- # diff = torch.abs(A1 - A2)
- # reldiff = diff/torch.abs(A1+1e-8)
- # abserr.append(diff.mean().item())
- # relerr.append(reldiff.mean().item())
- # #assert diff < 0.0075
- #print(sum(abserr)/len(abserr))
- #print(sum(relerr)/len(relerr))
+ for method in ['linear', 'fp8']:
+ code = None
+ if method == 'linear':
+ code = F.create_linear_map(True, bits=bits).cuda()
+ elif method == 'fp8':
+ ebits = math.ceil(bits/2)
+ pbits = bits-ebits-1
+ code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
+ print(ebits, pbits, bits)
+ print(code)
+ assert code.numel() == 256
+ print(bits)
+ for i in range(10):
+
+ values = torch.randn(1, 32, device='cuda')
+ values /= values.abs().max()
+ #values[values.abs() < 1e-6] += 1e-5
+
+ q1 = []
+ v1 = []
+ for v in values[0]:
+ idx = torch.abs(v-code).argmin()
+ q1.append(idx.item())
+ v1.append(code[idx].item())
+
+ q1 = torch.Tensor(q1).cuda()
+ v1 = torch.Tensor(v1).cuda()
+
+ q2, S2 = F.quantize(values, code=code)
+ v2 = F.dequantize(q2, S2)
+
+ idx = torch.isclose(q1.int(), q2.int())
+ if idx.sum():
+ # some weird cases
+ err1 = torch.abs(v1-values).mean()
+ err2 = torch.abs(v2-values).mean()
+ assert err2 <= err1
+
+ else:
+ torch.testing.assert_allclose(q1, q2)