From c771b3a75a6ebbfbfc398a028a477246b0799cf0 Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Fri, 22 Jul 2022 14:41:05 -0700
Subject: Most tests passing.

---
 csrc/kernels.cuh | 12 ++++++++++++
 1 file changed, 12 insertions(+)

(limited to 'csrc/kernels.cuh')

diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index 0a3676c..cbfbeba 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -106,6 +106,18 @@ template __global__ void kPercentileCl
 __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n);
+
+template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+
+template __global__ void kdequant_mm_int32_fp16(
+  int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats,
+  half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n);
+
+template __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
+template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
+
+template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+
 #endif
--
cgit v1.2.3