From 1ed2fa2f218d8dac401f3315420ffec92014c124 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 16 Aug 2022 10:56:17 -0700 Subject: Removed storage() from get_ptr; added boilerplate for bias dequant_mm. --- csrc/ops.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'csrc/ops.cuh') diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 4b09ecf..acfdb06 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -163,7 +163,7 @@ template int igemmlt(cublasLtHandle template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2); void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); -void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols); +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols); void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols); void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols); -- cgit v1.2.3