summaryrefslogtreecommitdiff
path: root/csrc/ops.cu
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-08-16 10:56:17 -0700
committerTim Dettmers <tim.dettmers@gmail.com>2022-08-16 10:56:17 -0700
commit1ed2fa2f218d8dac401f3315420ffec92014c124 (patch)
tree57863d4d1024689100c1b43caccc1d8739c58d99 /csrc/ops.cu
parent26efb154c8d77b4ede2cfc0dbd2381dd385f33e7 (diff)
Removed storage() from get_ptr; added boilerplate for bias dequant_mm.
Diffstat (limited to 'csrc/ops.cu')
-rw-r--r--csrc/ops.cu5
1 files changed, 2 insertions, 3 deletions
diff --git a/csrc/ops.cu b/csrc/ops.cu
index c16dd96..ed32828 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -435,7 +435,7 @@ int fill_up_to_nearest_multiple(int value, int multiple)
return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple)));
}
-void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols)
+void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half *bias, int numRows, int numCols)
{
int threads = 512;
int tileCols = fill_up_to_nearest_multiple(numCols, 32);
@@ -447,7 +447,7 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out,
num_blocks = num_blocks*(tileCols/32);
assert(threads <= tilesize);
- kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n);
+ kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
@@ -465,7 +465,6 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r
col_tiles = col_tiles > 0 ? col_tiles : 1;
int num_blocks = row_tiles * col_tiles;
-
if(nnz_threshold == 0.0)
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
else if(nnz_threshold != 0.0)