From 1ed2fa2f218d8dac401f3315420ffec92014c124 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 16 Aug 2022 10:56:17 -0700 Subject: Removed storage() from get_ptr; added boilerplate for bias dequant_mm. --- csrc/kernels.cu | 4 ++-- csrc/kernels.cuh | 2 +- csrc/ops.cu | 5 ++--- csrc/ops.cuh | 2 +- csrc/pythonInterface.c | 4 ++-- 5 files changed, 8 insertions(+), 9 deletions(-) (limited to 'csrc') diff --git a/csrc/kernels.cu b/csrc/kernels.cu index d4eb56c..0a1bf79 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1889,7 +1889,7 @@ template __global__ void kgetColRowStats(half * __rest #define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f) -template __global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n) +template __global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n) { // Strategy: To dequantize we need to load col/row statistics. This can be very expensive @@ -2675,7 +2675,7 @@ template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>( template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); -template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n); +template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n); template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 2447494..bdf61b2 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -111,7 +111,7 @@ template __global__ void kspmm_coo_very_s template __global__ void kdequant_mm_int32_fp16( int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, - half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n); + half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n); template __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); diff --git a/csrc/ops.cu b/csrc/ops.cu index c16dd96..ed32828 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -435,7 +435,7 @@ int fill_up_to_nearest_multiple(int value, int multiple) return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple))); } -void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols) +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half *bias, int numRows, int numCols) { int threads = 512; int tileCols = fill_up_to_nearest_multiple(numCols, 32); @@ -447,7 +447,7 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, num_blocks = num_blocks*(tileCols/32); assert(threads <= tilesize); - kdequant_mm_int32_fp16<4, 128, 512><<>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n); + kdequant_mm_int32_fp16<4, 128, 512><<>>(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -465,7 +465,6 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r col_tiles = col_tiles > 0 ? col_tiles : 1; int num_blocks = row_tiles * col_tiles; - if(nnz_threshold == 0.0) kgetColRowStats<<>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); else if(nnz_threshold != 0.0) diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 4b09ecf..acfdb06 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -163,7 +163,7 @@ template int igemmlt(cublasLtHandle template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2); void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); -void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols); +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols); void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols); void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 7356c11..0707674 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -248,8 +248,8 @@ extern "C" MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8) MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32) - void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols) - { dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols); } + void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols) + { dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols); } void cget_col_row_stats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols) { getColRowStats(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols); } -- cgit v1.2.3