From 08fa2e7b01dda8959a930295de9829516f8c77bc Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Mon, 7 Nov 2022 18:06:18 -0800 Subject: Fixed bug in cpu quant; faster GPU dequant. --- csrc/kernels.cu | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) (limited to 'csrc/kernels.cu') diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 9d9653c..4c750d1 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -510,7 +510,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float } template -__global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n) +__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n) { const int n_full = gridDim.x * BLOCK_SIZE; @@ -526,10 +526,11 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c __shared__ typename LoadChar::TempStorage loadchar; __shared__ typename StoreT::TempStorage storet; - __shared__ float smem_code[256]; + //__shared__ float smem_code[256]; + //float local_code[16]; - if(threadIdx.x < 256) - smem_code[threadIdx.x] = code[threadIdx.x]; + //if(threadIdx.x < 256) + //smem_code[threadIdx.x] = code[threadIdx.x]; for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) { @@ -539,9 +540,10 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c __syncthreads(); LoadChar(loadchar).Load(&(A[i]), qvals, valid_items, 128); + // load code through read-only cache via __ldg #pragma unroll NUM_PER_TH for(int j = 0; j < NUM_PER_TH; j++) - vals[j] = smem_code[qvals[j]]*local_abs_max; + vals[j] = __ldg(&code[qvals[j]])*local_abs_max; __syncthreads(); StoreT(storet).Store(&(out[i]), vals, valid_items); @@ -2798,14 +2800,14 @@ template __global__ void kQuantizeBlockwise(float * code, flo template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -- cgit v1.2.3