From 6bc2b992be0bb7511ea881f8ebbbd2ba7f1b5109 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 6 Nov 2022 16:27:48 -0800 Subject: Added blocksizes 2048, 1024, and 512 to blockwise quant. --- csrc/kernels.cu | 22 ++++++++++++++++------ csrc/ops.cu | 33 +++++++++++++++++++++++++-------- csrc/ops.cuh | 2 +- csrc/pythonInterface.c | 12 ++++++------ 4 files changed, 48 insertions(+), 21 deletions(-) (limited to 'csrc') diff --git a/csrc/kernels.cu b/csrc/kernels.cu index f01b4e1..9d9653c 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -428,16 +428,16 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c } template -__launch_bounds__(TH, 4) +//__launch_bounds__(TH, 4) __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n) { const int n_full = gridDim.x * BLOCK_SIZE; int valid_items = 0; const int base_idx = (blockIdx.x * BLOCK_SIZE); - T vals[NUM]; - float rand_vals[NUM]; - unsigned char qvals[NUM]; + T vals[NUM_PER_TH]; + float rand_vals[NUM_PER_TH]; + unsigned char qvals[NUM_PER_TH]; //float local_abs_max = -FLT_MAX; float local_abs_max = 0.0f; int local_rand_idx = 0; @@ -517,8 +517,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c int valid_items = 0; const int base_idx = (blockIdx.x * BLOCK_SIZE); - T vals[NUM]; - unsigned char qvals[NUM]; + T vals[NUM_PER_TH]; + unsigned char qvals[NUM_PER_TH]; float local_abs_max = -FLT_MAX; typedef cub::BlockLoad LoadChar; @@ -2791,11 +2791,21 @@ template __global__ void kQuantizeBlockwise(float * code, half template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n); diff --git a/csrc/ops.cu b/csrc/ops.cu index e49c94b..b121fc2 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -50,11 +50,23 @@ void dequantize(float *code, unsigned char *A, float *out, int n) CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n) +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) { - int num_blocks = n/4096; - num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + int num_blocks = n/blocksize; + num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; + if(STOCHASTIC == 1) + assert(blocksize == 4096); + + if(blocksize == 4096) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 2048) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 1024) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 512) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -66,6 +78,11 @@ template void dequantizeBlockwise(float *code, unsigned char *A, flo kDequantizeBlockwise<<>>(code, A, absmax, out, n); else if(blocksize == 2048) kDequantizeBlockwise<<>>(code, A, absmax, out, n); + else if(blocksize == 1024) + kDequantizeBlockwise<<>>(code, A, absmax, out, n); + else if(blocksize == 512) + kDequantizeBlockwise<<>>(code, A, absmax, out, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -659,10 +676,10 @@ template void transformRowToFormat(char * A, char *out, int rows, template void estimateQuantiles(half *A, float *code, float offset, int n); template void estimateQuantiles(float *A, float *code, float offset, int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index acfdb06..66e3843 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -128,7 +128,7 @@ template void estimateQuantiles(T *A, float *code, float offset, in void quantize(float *code, float *A, unsigned char *out, int n); void dequantize(float *code, unsigned char *A, float *out, int n); -template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); template void optimizer32bit(T* g, T* p, diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 58e26a9..5bac30e 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -75,10 +75,10 @@ MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32) void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } -void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, n); } -void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, n); } -void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, n); } -void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, n); } +void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, 4096, n); } +void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, 4096, n); } void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } \ void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } @@ -140,8 +140,8 @@ extern "C" void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); } void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); } void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); } - void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, n); } - void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, n); } + void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); } void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); } -- cgit v1.2.3