From 1ed2fa2f218d8dac401f3315420ffec92014c124 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 16 Aug 2022 10:56:17 -0700 Subject: Removed storage() from get_ptr; added boilerplate for bias dequant_mm. --- csrc/ops.cu | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'csrc/ops.cu') diff --git a/csrc/ops.cu b/csrc/ops.cu index c16dd96..ed32828 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -435,7 +435,7 @@ int fill_up_to_nearest_multiple(int value, int multiple) return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple))); } -void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols) +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half *bias, int numRows, int numCols) { int threads = 512; int tileCols = fill_up_to_nearest_multiple(numCols, 32); @@ -447,7 +447,7 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, num_blocks = num_blocks*(tileCols/32); assert(threads <= tilesize); - kdequant_mm_int32_fp16<4, 128, 512><<>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n); + kdequant_mm_int32_fp16<4, 128, 512><<>>(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -465,7 +465,6 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r col_tiles = col_tiles > 0 ? col_tiles : 1; int num_blocks = row_tiles * col_tiles; - if(nnz_threshold == 0.0) kgetColRowStats<<>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); else if(nnz_threshold != 0.0) -- cgit v1.2.3