summaryrefslogtreecommitdiff
path: root/csrc/ops.cu
diff options
context:
space:
mode:
Diffstat (limited to 'csrc/ops.cu')
-rw-r--r--csrc/ops.cu26
1 files changed, 26 insertions, 0 deletions
diff --git a/csrc/ops.cu b/csrc/ops.cu
index cfc9605..9b75e69 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -618,10 +618,36 @@ template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count,
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
+
+/**
+ * Host-side launcher for kExtractOutliers<FORMAT>.
+ *
+ * Computes the padded ("tiled") row and column counts for the requested
+ * layout and launches the kernel with one block of 256 threads per unit of
+ * idx_size. The launch uses the default stream and is not followed by an
+ * explicit synchronization; only launch errors are checked here.
+ *
+ * @tparam FORMAT  Layout selector. COL_TURING pads rows to a multiple of 8,
+ *                 COL_AMPERE pads rows to a multiple of 32. For any other
+ *                 value tiledRows is passed to the kernel as 0.
+ * @param A        Forwarded unchanged to the kernel (presumably the source
+ *                 matrix stored in the FORMAT layout - confirm against the kernel).
+ * @param idx      Forwarded unchanged to the kernel (presumably the outlier
+ *                 indices to extract - confirm against the kernel).
+ * @param out      Forwarded unchanged to the kernel (presumably the
+ *                 destination buffer - confirm against the kernel).
+ * @param idx_size Also used as the grid size: one block per unit. When it is
+ *                 zero or negative the function returns without launching.
+ * @param rows     Row count before padding.
+ * @param cols     Column count before padding.
+ */
+template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols)
+{
+    // A grid dimension of 0 is an invalid launch configuration, so an empty
+    // index list is handled as a no-op instead of a launch error.
+    if(idx_size <= 0)
+        return;
+
+    const int threads = 256;
+    const int num_blocks = idx_size;
+
+    // Pad the column count up to the next multiple of 32.
+    // NOTE(review): the previous comment here said "we load 128 column values
+    // per warp", which does not match the factor of 32 - confirm against kExtractOutliers.
+    const int tiledCols = fill_up_to_nearest_multiple(cols, 32);
+
+    // Row padding depends on the layout; it stays 0 for unrecognised formats.
+    int tiledRows = 0;
+    if(FORMAT == COL_TURING)
+    {
+        tiledRows = fill_up_to_nearest_multiple(rows, 8);
+    }
+    else if(FORMAT == COL_AMPERE)
+    {
+        tiledRows = fill_up_to_nearest_multiple(rows, 32);
+    }
+
+    kExtractOutliers<FORMAT><<<num_blocks, threads>>>(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols);
+    // Surfaces launch-configuration errors without clearing the error state.
+    CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
+
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
+template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols); // explicit instantiation: FORMAT = COL_TURING (rows padded to a multiple of 8)
+template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rows, int cols); // explicit instantiation: FORMAT = COL_AMPERE (rows padded to a multiple of 32)
+
// Explicit instantiations of spmm_coo_very_sparse_naive: B as half with BITS = 16, and B as signed char with BITS = 8.
template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);