summaryrefslogtreecommitdiff
path: root/csrc/kernels.cu
diff options
context:
space:
mode:
Diffstat (limited to 'csrc/kernels.cu')
-rw-r--r--csrc/kernels.cu78
1 files changed, 72 insertions, 6 deletions
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 6eca3aa..d4eb56c 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -2591,16 +2591,82 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
}
}
+template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA)
+{
+ int local_colidx = idx[blockIdx.x];
+
+ if(FORMAT==COL_TURING)
+ {
+ // TURING FORMAT:
+ // 8*32 tiles with 4*4 subtiles
+ // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*8 = 128 elements)
+ // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero
+ // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
+ // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
+ // index increases by 32
+ // columns are grouped in increments of 4, meaning that one has the following rows and columns
+ // rows: [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...]
+ // cols: [0 1 2 3, 0 1 2 4, 0 1 2 3, 0 1 2 3, 4 5 6 7 ...]
+
+ // each thread reads 1 element = 1 row
+ for(int row = threadIdx.x; row < rowsA; row+= blockDim.x)
+ {
+ int offset_per_col_tile = ((rowsA+7)/8)*32*8;
+ int tile_offset_rows = (row/8)*32*8;
+ int tile_offset_cols = (local_colidx/32)*offset_per_col_tile;
+ int offset = 0;
+ int subtile_col_idx = local_colidx%32;
+ int subtile_row_idx = row % 8;
+ if(row % 2 == 1)
+ offset += 128 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + ((subtile_row_idx-1)*2);
+ else
+ // even
+ offset += 0 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + (subtile_row_idx*2);
+
+ offset += tile_offset_rows + tile_offset_cols;
+
+ char val = A[offset];
+
+ int out_idx = (row*idx_size) + blockIdx.x;
+ out[out_idx] = val;
+ }
+ }
+ else if(FORMAT == COL_AMPERE)
+ {
+
+ for(int row = threadIdx.x; row < rowsA; row+= blockDim.x)
+ {
+ // we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element
+ // within each tile.
+ int offset_per_col_tile = ((rowsA+31)/32)*32*32;
+ int tile_offset_rows = (row/32)*32*32;
+ int tile_offset_cols = (local_colidx/32)*offset_per_col_tile;
+ int subtile_col_idx = local_colidx%32;
+ int subtile_row_idx = row % 32;
+ // this magic is taken from the cublasLt doc (search for COL32)
+ int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx;
+ offset += tile_offset_cols + tile_offset_rows;
+
+ char val = A[offset];
+ int out_idx = (row*idx_size) + blockIdx.x;
+ out[out_idx] = val;
+ }
+ }
+}
+
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
-template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
+template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
+
+template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);