summaryrefslogtreecommitdiff
path: root/csrc
diff options
context:
space:
mode:
Diffstat (limited to 'csrc')
-rw-r--r--csrc/kernels.cu7
-rw-r--r--csrc/kernels.cuh2
-rw-r--r--csrc/ops.cu27
-rw-r--r--csrc/ops.cuh2
-rw-r--r--csrc/pythonInterface.c6
5 files changed, 44 insertions, 0 deletions
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 1c3e723..78170d0 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -2592,10 +2592,17 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
}
}
+template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA)
+{
+}
+
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
+template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
+template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
+
template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index cbfbeba..ec2068e 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -118,6 +118,8 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int S
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
+
#endif
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 8946015..fe2d7fe 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -578,10 +578,37 @@ template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count,
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
+
+template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols)
+{
+ int threads = 256;
+ // we load 128 column values per warp
+ int tiledCols = tiledCols = fill_up_to_nearest_multiple(cols, 32);
+ int tiledRows = 0;
+
+ int elements = idx_size*cols; // matrix A is transposed, so we extract columns
+ int num_blocks = (elements+threads-1)/threads;
+
+ if(FORMAT == COL_TURING)
+ {
+ tiledRows = fill_up_to_nearest_multiple(rows, 8);
+ }
+ else if(FORMAT == COL_AMPERE)
+ {
+ tiledRows = fill_up_to_nearest_multiple(rows, 32);
+ }
+
+ kExtractOutliers<FORMAT><<<num_blocks, threads>>>(A, idx, out, rows, cols, tiledRows, tiledCols);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
+
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
+template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
+template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
+
template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index 4e719df..4b09ecf 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -174,4 +174,6 @@ void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_val
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
+template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols);
+
#endif
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index 03c8d92..2ecbaae 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -106,6 +106,9 @@ void transform_row2turingT(char * A, char *out, int rows, int cols){ transformRo
void transform_row2ampere(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 0>(A, out, rows, cols); }
void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 1>(A, out, rows, cols); }
+void extractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_TURING>(A, idx, out, idx_size, rows, cols); }
+void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_AMPERE>(A, idx, out, idx_size, rows, cols); }
+
int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_TURING, 32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
@@ -280,6 +283,9 @@ extern "C"
void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
+ void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); }
+ void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
+
#endif
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); }
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); }