summaryrefslogtreecommitdiff
path: root/csrc
diff options
context:
space:
mode:
Diffstat (limited to 'csrc')
-rw-r--r--csrc/kernels.cu61
-rw-r--r--csrc/kernels.cuh2
-rw-r--r--csrc/ops.cu5
3 files changed, 61 insertions, 7 deletions
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 78170d0..bb36d9b 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -2592,16 +2592,71 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
}
}
-template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA)
+template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA)
{
+ int local_colidx = idx[blockIdx.x];
+
+ if(FORMAT==COL_TURING)
+ {
+ // TURING FORMAT:
+ // 8*32 tiles with 4*4 subtiles
+ // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*8 = 128 elements)
+ // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero
+ // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
+ // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
+ // index increases by 32
+ // columns are grouped in increments of 4, meaning that one has the following rows and columns
+ // rows: [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...]
+ // cols: [0 1 2 3, 0 1 2 4, 0 1 2 3, 0 1 2 3, 4 5 6 7 ...]
+
+ // each thread reads 1 element = 1 row
+ for(int row = threadIdx.x; row < rowsA; row+= blockDim.x)
+ {
+ int offset_per_col_tile = ((rowsA+7)/8)*32*8;
+ int tile_offset_rows = (row/8)*32*8;
+ int tile_offset_cols = (local_colidx/32)*offset_per_col_tile;
+ int offset = 0;
+ int subtile_col_idx = local_colidx%32;
+ int subtile_row_idx = row % 8;
+ if(row % 2 == 1)
+ offset += 128 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + ((subtile_row_idx-1)*2);
+ else
+ // even
+ offset += 0 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + (subtile_row_idx*2);
+
+ offset += tile_offset_rows + tile_offset_cols;
+
+
+ char val = 0;
+ //printf("(%i (%i %i) (%i %i))\n", offset, tile_offset_rows, tile_offset_cols, row, local_colidx);
+ if(offset > tiledColsA*tiledRowsA)
+ printf("(%i (%i %i) (%i %i)\n", offset, tile_offset_rows, tile_offset_cols, row, local_colidx);
+ else
+ val = A[offset];
+
+ int out_idx = (row*idx_size) + blockIdx.x;
+
+ //if(out_idx > colsA*idx_size)
+ if(val != 0)
+ {
+ //printf("(%i %i) = (%i) = %i\n", row, local_colidx, out_idx, (int) val);
+ out[out_idx] = val;
+ }
+ else
+ {
+ out[out_idx] = val;
+ }
+ }
+
+ }
}
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
-template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
-template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
+template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
+template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index ec2068e..eda5ba0 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -118,7 +118,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int S
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
-template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
+template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
#endif
diff --git a/csrc/ops.cu b/csrc/ops.cu
index fe2d7fe..e6227ae 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -586,8 +586,7 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
int tiledCols = tiledCols = fill_up_to_nearest_multiple(cols, 32);
int tiledRows = 0;
- int elements = idx_size*cols; // matrix A is transposed, so we extract columns
- int num_blocks = (elements+threads-1)/threads;
+ int num_blocks = idx_size;
if(FORMAT == COL_TURING)
{
@@ -598,7 +597,7 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
tiledRows = fill_up_to_nearest_multiple(rows, 32);
}
- kExtractOutliers<FORMAT><<<num_blocks, threads>>>(A, idx, out, rows, cols, tiledRows, tiledCols);
+ kExtractOutliers<FORMAT><<<num_blocks, threads>>>(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}