diff options
-rw-r--r-- | bitsandbytes/cextension.py | 1 | ||||
-rw-r--r-- | bitsandbytes/functional.py | 22 | ||||
-rw-r--r-- | csrc/kernels.cu | 28 | ||||
-rw-r--r-- | csrc/kernels.cuh | 2 | ||||
-rw-r--r-- | tests/test_functional.py | 16 |
5 files changed, 44 insertions, 25 deletions
diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index ead8502..264e899 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -94,7 +94,6 @@ class CUDASetup(object): else: self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...") self.lib = ct.cdll.LoadLibrary(binary_path) - print(self.lib) except Exception as ex: self.add_log_entry(str(ex)) self.print_log_stack() diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 6278db9..fffbecf 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -458,16 +458,13 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra """ - prev_device = pre_call(A.device) if code is None: if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) code = name2qmap["dynamic"] - code = code.to(A.device) if absmax is None: n = A.numel() - blocksize = (blocksize if A.device.type == 'cuda' else 4096) blocks = n // blocksize blocks += 1 if n % blocksize > 0 else 0 absmax = torch.zeros((blocks,), device=A.device) @@ -477,8 +474,9 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra if A.device.type != 'cpu': assert blocksize in [4096, 2048, 1024, 512] - is_on_gpu([code, A, absmax, out, rand]) cblocksize = ct.c_int32(blocksize) + prev_device = pre_call(A.device) + code = code.to(A.device) if rand is not None: is_on_gpu([code, A, out, absmax, rand]) assert blocksize==4096 @@ -498,11 +496,12 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + post_call(A.device) else: # cpu + code = code.cpu() assert rand is None lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) - post_call(A.device) return out, (absmax, code) @@ -541,32 +540,35 @@ def dequantize_blockwise( Dequantized tensor (default: float32) """ assert quant_state is not None or absmax is not None - device = pre_call(A.device) if code is None and quant_state is None: if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) code = name2qmap["dynamic"] - code = code.to(A.device) if out is None: out = torch.zeros_like(A, dtype=torch.float32) if quant_state is None: quant_state = (absmax, code) + else: + absmax, code = quant_state if A.device.type != 'cpu': + device = pre_call(A.device) + code = code.to(A.device) if blocksize not in [2048, 4096, 1024, 512]: raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512]") is_on_gpu([A, out]) if out.dtype == torch.float32: - lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) elif out.dtype == torch.float16: - lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + post_call(A.device) else: + code = code.cpu() lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) - post_call(A.device) return out diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 9d9653c..4c750d1 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -510,7 +510,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float } template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> -__global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n) +__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n) { const int n_full = gridDim.x * BLOCK_SIZE; @@ -526,10 +526,11 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c __shared__ typename LoadChar::TempStorage loadchar; __shared__ typename StoreT::TempStorage storet; - __shared__ float smem_code[256]; + //__shared__ float smem_code[256]; + //float local_code[16]; - if(threadIdx.x < 256) - smem_code[threadIdx.x] = code[threadIdx.x]; + //if(threadIdx.x < 256) + //smem_code[threadIdx.x] = code[threadIdx.x]; for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) { @@ -539,9 +540,10 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c __syncthreads(); LoadChar(loadchar).Load(&(A[i]), qvals, valid_items, 128); + // load code through read-only cache via __ldg #pragma unroll NUM_PER_TH for(int j = 0; j < NUM_PER_TH; j++) - vals[j] = smem_code[qvals[j]]*local_abs_max; + vals[j] = __ldg(&code[qvals[j]])*local_abs_max; __syncthreads(); StoreT(storet).Store(&(out[i]), vals, valid_items); @@ -2798,14 +2800,14 @@ template __global__ void kQuantizeBlockwise<float, 1024, 4, 0>(float * code, flo template __global__ void kQuantizeBlockwise<half, 512, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kQuantizeBlockwise<float, 512, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise<half, 1024, 256, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise<float, 1024, 256, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise<half, 512, 256, 2>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise<float, 512, 256, 2>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); +template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n); +template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n); +template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n); +template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n); +template __global__ void kDequantizeBlockwise<half, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n); +template __global__ void kDequantizeBlockwise<float, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n); +template __global__ void kDequantizeBlockwise<half, 512, 256, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n); +template __global__ void kDequantizeBlockwise<float, 512, 256, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index bdf61b2..cca983b 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -15,7 +15,7 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c __global__ void kDequantize(float *code, unsigned char *A, float *out, const int n); template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n); +template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n); template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS> __global__ void kPreconditionOptimizer32bit2State(T* g, T* p, diff --git a/tests/test_functional.py b/tests/test_functional.py index 4642b16..d36dfc1 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2166,3 +2166,19 @@ def test_kbit_quantile_estimation(): val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits) err = torch.abs(val1-val2).mean() assert err < 0.035 + + +def test_bench_dequantization(): + a = torch.rand(1024, 1024, device='cuda').half() + qa, SA = F.quantize_blockwise(a) + + max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000 + #print(max_theoretical_mu) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + F.dequantize_blockwise(qa, SA, blocksize=2048) + torch.cuda.synchronize() + #print((time.time()-t0)/1e6) + |