summaryrefslogtreecommitdiff
path: root/csrc/kernels.cu
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-11-06 16:27:48 -0800
committerTim Dettmers <tim.dettmers@gmail.com>2022-11-06 16:27:48 -0800
commit6bc2b992be0bb7511ea881f8ebbbd2ba7f1b5109 (patch)
treeceded1197e0e1335b7e8cd7ed7134c362115573e /csrc/kernels.cu
parent2f2063bac212bcd6a515a88a12a9530b5730dabe (diff)
Added blocksizes 2048, 1024, and 512 to blockwise quant.
Diffstat (limited to 'csrc/kernels.cu')
-rw-r--r--csrc/kernels.cu22
1 files changed, 16 insertions, 6 deletions
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index f01b4e1..9d9653c 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -428,16 +428,16 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
}
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC>
-__launch_bounds__(TH, 4)
+//__launch_bounds__(TH, 4)
__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
{
const int n_full = gridDim.x * BLOCK_SIZE;
int valid_items = 0;
const int base_idx = (blockIdx.x * BLOCK_SIZE);
- T vals[NUM];
- float rand_vals[NUM];
- unsigned char qvals[NUM];
+ T vals[NUM_PER_TH];
+ float rand_vals[NUM_PER_TH];
+ unsigned char qvals[NUM_PER_TH];
//float local_abs_max = -FLT_MAX;
float local_abs_max = 0.0f;
int local_rand_idx = 0;
@@ -517,8 +517,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c
int valid_items = 0;
const int base_idx = (blockIdx.x * BLOCK_SIZE);
- T vals[NUM];
- unsigned char qvals[NUM];
+ T vals[NUM_PER_TH];
+ unsigned char qvals[NUM_PER_TH];
float local_abs_max = -FLT_MAX;
typedef cub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
@@ -2791,11 +2791,21 @@ template __global__ void kQuantizeBlockwise<half, 4096, 4, 0>(float * code, half
template __global__ void kQuantizeBlockwise<float, 4096, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<half, 4096, 4, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kQuantizeBlockwise<float, 4096, 4, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 2048, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 2048, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 1024, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 1024, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<half, 512, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
+template __global__ void kQuantizeBlockwise<float, 512, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
+template __global__ void kDequantizeBlockwise<half, 1024, 256, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
+template __global__ void kDequantizeBlockwise<float, 1024, 256, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
+template __global__ void kDequantizeBlockwise<half, 512, 256, 2>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
+template __global__ void kDequantizeBlockwise<float, 512, 256, 2>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);