From cbb901ac51bd6c41e4243ffb936ef0e2f7ca8ada Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 26 Jul 2022 12:12:38 -0700 Subject: Boilerplate and test for extract_outliers. --- csrc/kernels.cu | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'csrc/kernels.cu') diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 1c3e723..78170d0 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2592,10 +2592,17 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o } } +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA) +{ +} + //============================================================== // TEMPLATE DEFINITIONS //============================================================== +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA); +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA); + template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -- cgit v1.2.3 From bcab99ec877ba063543bd7c03ba1cdd1b06e8078 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 26 Jul 2022 17:39:30 -0700 Subject: Working outlier extraction for Turing. --- csrc/kernels.cu | 61 ++++++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 58 insertions(+), 3 deletions(-) (limited to 'csrc/kernels.cu') diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 78170d0..bb36d9b 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2592,16 +2592,71 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o } } -template __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA) +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA) { + int local_colidx = idx[blockIdx.x]; + + if(FORMAT==COL_TURING) + { + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*8 = 128 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + // columns are grouped in increments of 4, meaning that one has the following rows and columns + // rows: [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...] + // cols: [0 1 2 3, 0 1 2 4, 0 1 2 3, 0 1 2 3, 4 5 6 7 ...] + + // each thread reads 1 element = 1 row + for(int row = threadIdx.x; row < rowsA; row+= blockDim.x) + { + int offset_per_col_tile = ((rowsA+7)/8)*32*8; + int tile_offset_rows = (row/8)*32*8; + int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; + int offset = 0; + int subtile_col_idx = local_colidx%32; + int subtile_row_idx = row % 8; + if(row % 2 == 1) + offset += 128 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + ((subtile_row_idx-1)*2); + else + // even + offset += 0 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + (subtile_row_idx*2); + + offset += tile_offset_rows + tile_offset_cols; + + + char val = 0; + //printf("(%i (%i %i) (%i %i))\n", offset, tile_offset_rows, tile_offset_cols, row, local_colidx); + if(offset > tiledColsA*tiledRowsA) + printf("(%i (%i %i) (%i %i)\n", offset, tile_offset_rows, tile_offset_cols, row, local_colidx); + else + val = A[offset]; + + int out_idx = (row*idx_size) + blockIdx.x; + + //if(out_idx > colsA*idx_size) + if(val != 0) + { + //printf("(%i %i) = (%i) = %i\n", row, local_colidx, out_idx, (int) val); + out[out_idx] = val; + } + else + { + out[out_idx] = val; + } + } + + } } //============================================================== // TEMPLATE DEFINITIONS //============================================================== -template __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA); -template __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA); +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -- cgit v1.2.3 From 32fa459ed7c812c79e847145004061f21b7ac0d9 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 26 Jul 2022 18:15:51 -0700 Subject: Added col_ampere outlier extraction kernel. --- csrc/kernels.cu | 40 ++++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) (limited to 'csrc/kernels.cu') diff --git a/csrc/kernels.cu b/csrc/kernels.cu index bb36d9b..79ad5de 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2626,28 +2626,32 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * offset += tile_offset_rows + tile_offset_cols; - - char val = 0; - //printf("(%i (%i %i) (%i %i))\n", offset, tile_offset_rows, tile_offset_cols, row, local_colidx); - if(offset > tiledColsA*tiledRowsA) - printf("(%i (%i %i) (%i %i)\n", offset, tile_offset_rows, tile_offset_cols, row, local_colidx); - else - val = A[offset]; + char val = A[offset]; int out_idx = (row*idx_size) + blockIdx.x; - - //if(out_idx > colsA*idx_size) - if(val != 0) - { - //printf("(%i %i) = (%i) = %i\n", row, local_colidx, out_idx, (int) val); - out[out_idx] = val; - } - else - { - out[out_idx] = val; - } + out[out_idx] = val; } + } + else if(FORMAT == COL_AMPERE) + { + for(int row = threadIdx.x; row < rowsA; row+= blockDim.x) + { + // we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element + // within each tile. + int offset_per_col_tile = ((rowsA+31)/32)*32*32; + int tile_offset_rows = (row/32)*32*32; + int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; + int subtile_col_idx = local_colidx%32; + int subtile_row_idx = row % 32; + // this magic is taken from the cublasLt doc (search for COL32) + int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx; + offset += tile_offset_cols + tile_offset_rows; + + char val = A[offset]; + int out_idx = (row*idx_size) + blockIdx.x; + out[out_idx] = val; + } } } -- cgit v1.2.3