summaryrefslogtreecommitdiff
path: root/csrc/kernels.cu
diff options
context:
space:
mode:
Diffstat (limited to 'csrc/kernels.cu')
-rw-r--r--csrc/kernels.cu17
1 files changed, 8 insertions, 9 deletions
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 4e744fb..6eca3aa 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -2166,7 +2166,6 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
__shared__ char smem_data[32*33*ITEMS_PER_THREAD];
char local_data[ITEMS_PER_THREAD];
typedef cub::BlockExchange<char, THREADS, ITEMS_PER_THREAD> BlockExchange;
- __shared__ typename BlockExchange::TempStorage temp_storage;
// we load row after row from the base_position
// Load data row by row
@@ -2446,7 +2445,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
#define MAX_SPARSE_COUNT 32
#define SMEM_SIZE 8*256
template <typename T, int SPMM_ITEMS, int BITS>
-__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
+__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
{
// 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block
@@ -2500,7 +2499,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
{
for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x)
if((idx_col_B+i-local_idx_col_B_offset) < colsB)
- smem_dequant_stats[i] = __ldg(&dequant_stats[idx_col_B+i-local_idx_col_B_offset]);
+ smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset];
__syncthreads();
}
@@ -2596,12 +2595,12 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
// TEMPLATE DEFINITIONS
//==============================================================
-template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);