summaryrefslogtreecommitdiff
path: root/csrc/kernels.cuh
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-07-22 14:41:05 -0700
committerTim Dettmers <tim.dettmers@gmail.com>2022-07-22 14:41:05 -0700
commitc771b3a75a6ebbfbfc398a028a477246b0799cf0 (patch)
tree158353d531766ed133be34d3c5085da6e8a4d01e /csrc/kernels.cuh
parent4cd7ea62b2f51c68aacde2f62e7141765e476111 (diff)
Most tests passing.
Diffstat (limited to 'csrc/kernels.cuh')
-rw-r--r--csrc/kernels.cuh12
1 files changed, 12 insertions, 0 deletions
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index 0a3676c..cbfbeba 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -106,6 +106,18 @@ template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileCl
__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n);
+
+template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+
+template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(
+ int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats,
+ half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n);
+
+template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
+template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
+
+template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+
#endif