From 1ed2fa2f218d8dac401f3315420ffec92014c124 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 16 Aug 2022 10:56:17 -0700 Subject: Removed storage() from get_ptr; added boilerplate for bias dequant_mm. --- csrc/kernels.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'csrc/kernels.cuh') diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 2447494..bdf61b2 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -111,7 +111,7 @@ template __global__ void kspmm_coo_very_s template __global__ void kdequant_mm_int32_fp16( int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, - half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n); + half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n); template __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); -- cgit v1.2.3