From c059bd284832d09bc51cf82c377642b26a48ef28 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 20 Nov 2022 14:18:15 -0800 Subject: Added additional blocksizes: {64, 128, 256}. --- csrc/kernels.cu | 16 ++++++++++++++-- csrc/ops.cu | 12 ++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) (limited to 'csrc') diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 4c750d1..29f266a 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -454,8 +454,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float __shared__ float smem_code[256]; __shared__ float smem_absmax_value[1]; - if(threadIdx.x < 256) - smem_code[threadIdx.x] = code[threadIdx.x]; + for(int i = threadIdx.x; i < 256; i+=blockDim.x) + smem_code[i] = code[i]; for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) { @@ -2799,6 +2799,12 @@ template __global__ void kQuantizeBlockwise(float * code, half template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); @@ -2808,6 +2814,12 @@ template __global__ void kDequantizeBlockwise(float *code, u template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); diff --git a/csrc/ops.cu b/csrc/ops.cu index b121fc2..30079e6 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -65,6 +65,12 @@ template void quantizeBlockwise(float * code, T *A, kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 512) kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 256) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 128) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 64) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); @@ -82,6 +88,12 @@ template void dequantizeBlockwise(float *code, unsigned char *A, flo kDequantizeBlockwise<<>>(code, A, absmax, out, n); else if(blocksize == 512) kDequantizeBlockwise<<>>(code, A, absmax, out, n); + else if(blocksize == 256) + kDequantizeBlockwise<<>>(code, A, absmax, out, n); + else if(blocksize == 128) + kDequantizeBlockwise<<>>(code, A, absmax, out, n); + else if(blocksize == 64) + kDequantizeBlockwise<<>>(code, A, absmax, out, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -- cgit v1.2.3