summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-11-06 16:27:48 -0800
committerTim Dettmers <tim.dettmers@gmail.com>2022-11-06 16:27:48 -0800
commit6bc2b992be0bb7511ea881f8ebbbd2ba7f1b5109 (patch)
treeceded1197e0e1335b7e8cd7ed7134c362115573e
parent2f2063bac212bcd6a515a88a12a9530b5730dabe (diff)
Added blocksizes 2048, 1024, and 512 to blockwise quant.
-rw-r--r--bitsandbytes/cextension.py11
-rw-r--r--bitsandbytes/functional.py20
-rw-r--r--csrc/kernels.cu22
-rw-r--r--csrc/ops.cu33
-rw-r--r--csrc/ops.cuh2
-rw-r--r--csrc/pythonInterface.c12
-rw-r--r--tests/test_functional.py92
7 files changed, 113 insertions, 79 deletions
diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index 8125202..ead8502 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -52,8 +52,13 @@ class CUDASetup(object):
self.add_log_entry('python setup.py install')
def initialize(self):
- self.cuda_setup_log = []
+ self.has_printed = False
self.lib = None
+ self.run_cuda_setup()
+
+ def run_cuda_setup(self):
+ self.initialized = True
+ self.cuda_setup_log = []
from .cuda_setup.main import evaluate_cuda_setup
binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup()
@@ -89,7 +94,9 @@ class CUDASetup(object):
else:
self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...")
self.lib = ct.cdll.LoadLibrary(binary_path)
- except:
+ print(self.lib)
+ except Exception as ex:
+ self.add_log_entry(str(ex))
self.print_log_stack()
def add_log_entry(self, msg, is_warning=False):
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 076414d..49d4db1 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -130,10 +130,10 @@ class Cusparse_Context(object):
return cls._instance
-def create_linear_map(signed=True, bits=8):
+def create_linear_map(signed=True, total_bits=8):
sign = (-1.0 if signed else 0.0)
- values = torch.linspace(sign, 1.0, 2**bits)
+ values = torch.linspace(sign, 1.0, 2**total_bits)
gap = 256 - values.numel()
if gap == 0:
return values
@@ -457,6 +457,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
The quantization state to undo the quantization.
"""
+
if code is None:
if "dynamic" not in name2qmap:
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
@@ -474,8 +475,11 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
out = torch.zeros_like(A, dtype=torch.uint8)
if A.device.type != 'cpu':
+ assert blocksize in [4096, 2048, 1024, 512]
is_on_gpu([code, A, absmax, out, rand])
+ cblocksize = ct.c_int32(blocksize)
if rand is not None:
+ assert blocksize==4096
assert rand.numel() >= 1024
rand_offset = random.randint(0, 1023)
if A.dtype == torch.float32:
@@ -483,18 +487,14 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
elif A.dtype == torch.float16:
lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
else:
- raise ValueError(
- f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
- )
+ raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
else:
if A.dtype == torch.float32:
- lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
+ lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
elif A.dtype == torch.float16:
- lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
+ lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
else:
- raise ValueError(
- f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
- )
+ raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
else:
# cpu
assert rand is None
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index f01b4e1..9d9653c 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -428,16 +428,16 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
}
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC>
-__launch_bounds__(TH, 4)
+//__launch_bounds__(TH, 4)
__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
{
const int n_full = gridDim.x * BLOCK_SIZE;
int valid_items = 0;
const int base_idx = (blockIdx.x * BLOCK_SIZE);
- T vals[NUM];
- float rand_vals[NUM];
- unsigned char qvals[NUM];
+ T vals[NUM_PER_TH];
+ float rand_vals[NUM_PER_TH];
+ unsigned char qvals[NUM_PER_TH];
//float local_abs_max = -FLT_MAX;
float local_abs_max = 0.0f;
int local_rand_idx = 0;
@@ -517,8 +517,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c
int valid_items = 0;
const int base_idx = (blockIdx.x * BLOCK_SIZE);
- T vals[NUM];
- unsigned char qvals[NUM];
+ T vals[NUM_PER_TH];
+ unsigned char qvals[NUM_PER_TH];
float local_abs_max = -FLT_MAX;
typedef cub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
@@ -2791,11 +2791,21 @@ template __global__ void kQuantizeBlockwise<half, 4096, 4, 0>(float * code, half
template __global__ void kQuantizeBlockwise<float, 4096, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 4096, 4, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 4096, 4, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 2048, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 2048, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 1024, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 1024, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 512, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 512, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
+template __global__ void kDequantizeBlockwise<half, 1024, 256, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
+template __global__ void kDequantizeBlockwise<float, 1024, 256, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
+template __global__ void kDequantizeBlockwise<half, 512, 256, 2>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
+template __global__ void kDequantizeBlockwise<float, 512, 256, 2>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
diff --git a/csrc/ops.cu b/csrc/ops.cu
index e49c94b..b121fc2 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -50,11 +50,23 @@ void dequantize(float *code, unsigned char *A, float *out, int n)
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
-template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n)
+template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
{
- int num_blocks = n/4096;
- num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
- kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
+ int num_blocks = n/blocksize;
+ num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
+ if(STOCHASTIC == 1)
+ assert(blocksize == 4096);
+
+ if(blocksize == 4096)
+ kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
+ else if(blocksize == 2048)
+ kQuantizeBlockwise<T, 2048, 4, 0><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
+ else if(blocksize == 1024)
+ kQuantizeBlockwise<T, 1024, 4, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
+ else if(blocksize == 512)
+ kQuantizeBlockwise<T, 512, 2, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
+
+
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
@@ -66,6 +78,11 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo
kDequantizeBlockwise<T, 4096, 1024, 4><<<num_blocks, 4096/4>>>(code, A, absmax, out, n);
else if(blocksize == 2048)
kDequantizeBlockwise<T, 2048, 512, 4><<<num_blocks, 2048/4>>>(code, A, absmax, out, n);
+ else if(blocksize == 1024)
+ kDequantizeBlockwise<T, 1024, 256, 4><<<num_blocks, 1024/4>>>(code, A, absmax, out, n);
+ else if(blocksize == 512)
+ kDequantizeBlockwise<T, 512, 256, 2><<<num_blocks, 512/2>>>(code, A, absmax, out, n);
+
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
@@ -659,10 +676,10 @@ template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows,
template void estimateQuantiles(half *A, float *code, float offset, int n);
template void estimateQuantiles(float *A, float *code, float offset, int n);
-template void quantizeBlockwise<half, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
-template void quantizeBlockwise<float, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
-template void quantizeBlockwise<half, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
-template void quantizeBlockwise<float, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
+template void quantizeBlockwise<half, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
+template void quantizeBlockwise<float, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
+template void quantizeBlockwise<half, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
+template void quantizeBlockwise<float, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void dequantizeBlockwise<half>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
template void dequantizeBlockwise<float>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index acfdb06..66e3843 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -128,7 +128,7 @@ template <typename T> void estimateQuantiles(T *A, float *code, float offset, in
void quantize(float *code, float *A, unsigned char *out, int n);
void dequantize(float *code, unsigned char *A, float *out, int n);
-template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
+template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index 58e26a9..5bac30e 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -75,10 +75,10 @@ MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32)
void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping<half>(g, gnorm_vec, step, n); }
-void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise<half, 0>(code, A, absmax, out, NULL, 0, n); }
-void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise<float, 0>(code, A, absmax, out, NULL, 0, n); }
-void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1>(code, A, absmax, out, rand, rand_offset, n); }
-void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1>(code, A, absmax, out, rand, rand_offset, n); }
+void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
+void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
+void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
+void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
@@ -140,8 +140,8 @@ extern "C"
void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); }
void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); }
- void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, n); }
- void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, n); }
+ void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
+ void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); }
void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); }
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 99885da..b525dff 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -151,30 +151,41 @@ def test_dynamic_quantization():
def test_dynamic_blockwise_quantization():
- diffs = []
- reldiffs = []
- for i in range(100):
- A1 = torch.randn(1024, 1024, device="cuda")
- C, S = F.quantize_blockwise(A1)
- A2 = F.dequantize_blockwise(C, S)
- diff = torch.abs(A1 - A2)
- reldiff = diff / torch.abs(A1 + 1e-8)
- diffs.append(diff.mean().item())
- reldiffs.append(reldiff.mean().item())
- assert diffs[-1] < 0.011
- # print(sum(diffs)/len(diffs))
- # print(sum(reldiffs)/len(reldiffs))
-
- diffs = []
- for i in range(100):
- A1 = torch.rand(1024, 1024, device="cuda")
- C, S = F.quantize_blockwise(A1)
- A2 = F.dequantize_blockwise(C, S)
- diff = torch.abs(A1 - A2).mean().item()
- assert diff < 0.0033
- diffs.append(diff)
- torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
- # print(sum(diffs)/len(diffs))
+ #print('')
+ for blocksize in [4096, 2048, 1024, 512]:
+ diffs = []
+ reldiffs = []
+ for i in range(100):
+ A1 = torch.randn(1024, 1024, device="cuda")
+ C, S = F.quantize_blockwise(A1)
+ A2 = F.dequantize_blockwise(C, S)
+ diff = torch.abs(A1 - A2)
+ reldiff = diff / torch.abs(A1 + 1e-8)
+ diffs.append(diff.mean().item())
+ reldiffs.append(reldiff.mean().item())
+ abserr = sum(diffs)/len(diffs)
+ relerr = sum(reldiffs)/len(reldiffs)
+ assert abserr < 0.011
+ assert relerr < 0.018
+ #print('randn', blocksize, sum(diffs)/len(diffs))
+ #print('randn', blocksize, sum(reldiffs)/len(reldiffs))
+
+ diffs = []
+ for i in range(100):
+ A1 = torch.rand(1024, 1024, device="cuda")
+ C, S = F.quantize_blockwise(A1)
+ A2 = F.dequantize_blockwise(C, S)
+ diff = torch.abs(A1 - A2)
+ reldiff = diff / torch.abs(A1 + 1e-8)
+ diffs.append(diff.mean().item())
+ reldiffs.append(reldiff.mean().item())
+ torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
+ abserr = sum(diffs)/len(diffs)
+ relerr = sum(reldiffs)/len(reldiffs)
+ assert abserr < 0.0035
+ assert relerr < 0.015
+ #print('rand', blocksize, sum(diffs)/len(diffs))
+ #print('rand', blocksize, sum(reldiffs)/len(reldiffs))
def test_dynamic_blockwise_stochastic_quantization():
@@ -1618,17 +1629,6 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
# print(time.time() - t0)
-def test_layout():
- a1 = torch.rand(16, 64, device="cuda", dtype=torch.float16)
- a1 = torch.arange(16 * 64, device="cuda").reshape(16, 64).byte()
- a2, s2 = F.transform(a1, "col_turing")
- print(a2.shape)
-
- print(a1.flatten()[8 * 64 : 8 * 64 + 32])
- for i in range(4):
- print(a2.flatten()[i * 8 * 32 : i * 8 * 32 + 32], 0)
-
-
def test_coo2csr():
threshold = 1
A = torch.randn(128, 128).half().cuda()
@@ -2062,8 +2062,8 @@ def test_fp8_quant():
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
#assert diff < 0.0075
- print(sum(abserr)/len(abserr))
- print(sum(relerr)/len(relerr))
+ #print(sum(abserr)/len(abserr))
+ #print(sum(relerr)/len(relerr))
abserr = []
relerr = []
@@ -2076,8 +2076,8 @@ def test_fp8_quant():
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
#assert diff < 0.0075
- print(sum(abserr)/len(abserr))
- print(sum(relerr)/len(relerr))
+ #print(sum(abserr)/len(abserr))
+ #print(sum(relerr)/len(relerr))
abserr = []
relerr = []
@@ -2090,21 +2090,21 @@ def test_fp8_quant():
abserr.append(diff.mean().item())
relerr.append(reldiff.mean().item())
#assert diff < 0.0075
- print(3, sum(abserr)/len(abserr))
- print(3, sum(relerr)/len(relerr))
+ #print(3, sum(abserr)/len(abserr))
+ #print(3, sum(relerr)/len(relerr))
def test_few_bit_quant():
- print('')
+ #print('')
for bits in range(2, 9):
- print('='*30, bits, '='*30)
+ #print('='*30, bits, '='*30)
for method in ['linear', 'fp8', 'dynamic', 'quantile']:
abserrs = []
relerrs = []
code = None
if method == 'linear':
- code = F.create_linear_map(True, bits=bits).cuda()
+ code = F.create_linear_map(True, total_bits=bits).cuda()
elif method == 'fp8':
ebits = math.ceil(bits/2)
pbits = bits-ebits-1
@@ -2122,7 +2122,7 @@ def test_few_bit_quant():
q /= q.abs().max()
code, idx = torch.sort(q)
- print(method, (code==0).sum())
+ #print(method, (code==0).sum())
assert code.numel() == 256
for i in range(10):
@@ -2154,7 +2154,7 @@ def test_few_bit_quant():
else:
torch.testing.assert_allclose(q1, q2)
- print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
+ #print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
def test_kbit_quantile_estimation():