summaryrefslogtreecommitdiff
path: root/csrc
diff options
context:
space:
mode:
Diffstat (limited to 'csrc')
-rw-r--r--csrc/kernels.cu4
-rw-r--r--csrc/kernels.cuh2
-rw-r--r--csrc/ops.cu5
-rw-r--r--csrc/ops.cuh2
-rw-r--r--csrc/pythonInterface.c4
5 files changed, 8 insertions, 9 deletions
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index d4eb56c..0a1bf79 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -1889,7 +1889,7 @@ template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 1>(half * __rest
#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f)
-template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n)
+template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n)
{
// Strategy: To dequantize we need to load col/row statistics. This can be very expensive
@@ -2675,7 +2675,7 @@ template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
-template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n);
+template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n);
template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index 2447494..bdf61b2 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -111,7 +111,7 @@ template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_s
template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(
int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats,
- half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n);
+ half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n);
template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
diff --git a/csrc/ops.cu b/csrc/ops.cu
index c16dd96..ed32828 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -435,7 +435,7 @@ int fill_up_to_nearest_multiple(int value, int multiple)
return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple)));
}
-void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols)
+void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half *bias, int numRows, int numCols)
{
int threads = 512;
int tileCols = fill_up_to_nearest_multiple(numCols, 32);
@@ -447,7 +447,7 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out,
num_blocks = num_blocks*(tileCols/32);
assert(threads <= tilesize);
- kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n);
+ kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
@@ -465,7 +465,6 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r
col_tiles = col_tiles > 0 ? col_tiles : 1;
int num_blocks = row_tiles * col_tiles;
-
if(nnz_threshold == 0.0)
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
else if(nnz_threshold != 0.0)
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index 4b09ecf..acfdb06 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -163,7 +163,7 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2);
void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
-void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols);
+void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols);
void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols);
void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed,
int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols);
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index 7356c11..0707674 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -248,8 +248,8 @@ extern "C"
MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8)
MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32)
- void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols)
- { dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols); }
+ void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols)
+ { dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols); }
void cget_col_row_stats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
{ getColRowStats(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols); }