author    Tim Dettmers <tim.dettmers@gmail.com>  2022-07-26 19:38:01 -0700
committer Tim Dettmers <tim.dettmers@gmail.com>  2022-07-26 19:38:01 -0700
commit    5737f2b027a1e0ec8540a3aa914632d44ad9c62d (patch)
tree      b288c905eaba75dc6b43a8bcebc82720c16e4816 /csrc/kernels.cu
parent    47a73d94c3d3284f6073b0ff189ed5bc9e3a8762 (diff)
parent    dc8c9efdb33130f960adc864916b67d0cb744dbb (diff)
Merge branch 'patch_merge' into extract_outliers
Diffstat (limited to 'csrc/kernels.cu')
-rw-r--r--  csrc/kernels.cu | 9
1 file changed, 4 insertions, 5 deletions
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 79ad5de..d4eb56c 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -1768,7 +1768,6 @@ template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_
__shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS];
__shared__ int smem_row_nnz_values[TILE_ROWS];
- //__shared__ float smem_col_absmax_values[ITEMS_PER_THREAD*THREADS];
half local_data[ITEMS_PER_THREAD];
float local_data_fp32[ITEMS_PER_THREAD];
@@ -1828,13 +1827,14 @@ template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_
local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j]));
// 3. compute row max (per block); store in smem to accumulate full global mem transaction
- __syncthreads();
// this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units)
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
local_data_fp32[j] = local_data[j];
+ __syncthreads();
+
row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, cub::Max());
if(SPARSE_DECOMP)
{
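The hunk above moves __syncthreads() from before the half-to-float widening loop to after it. The loop only touches per-thread registers, so delaying the barrier is safe; the barrier is still needed ahead of the CUB block reduction, which in the real kernel reuses block-shared temp storage. A minimal sketch of that pattern, assuming a hypothetical kernel with the same THREADS/ITEMS_PER_THREAD layout (names are illustrative, not the actual kernel):

#include <cub/cub.cuh>
#include <cuda_fp16.h>

template <int THREADS, int ITEMS_PER_THREAD>
__global__ void row_absmax_sketch(const half *in, float *row_absmax_out)
{
    // same collective the kernel above uses for the per-block row reduction
    typedef cub::BlockReduce<float, THREADS> BlockRowReduce;
    __shared__ typename BlockRowReduce::TempStorage temp_storage;

    half  local_data[ITEMS_PER_THREAD];
    float local_data_fp32[ITEMS_PER_THREAD];

    // each thread loads its items (assumes in holds gridDim.x*THREADS*ITEMS_PER_THREAD halves)
    #pragma unroll ITEMS_PER_THREAD
    for (int j = 0; j < ITEMS_PER_THREAD; j++)
        local_data[j] = in[(blockIdx.x * THREADS + threadIdx.x) * ITEMS_PER_THREAD + j];

    // register-only work: widen to fp32 so the reduction also runs on pre-fp16
    // hardware (Kepler/Maxwell), as the original comment notes; no barrier needed here
    #pragma unroll ITEMS_PER_THREAD
    for (int j = 0; j < ITEMS_PER_THREAD; j++)
        local_data_fp32[j] = fabsf(__half2float(local_data[j]));

    // in the real kernel this barrier protects reuse of the shared temp storage;
    // placing it after the register-only loop matches the change above
    __syncthreads();

    float row_absmax = BlockRowReduce(temp_storage).Reduce(local_data_fp32, cub::Max());
    if (threadIdx.x == 0)
        row_absmax_out[blockIdx.x] = row_absmax;
}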
@@ -2166,7 +2166,6 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
__shared__ char smem_data[32*33*ITEMS_PER_THREAD];
char local_data[ITEMS_PER_THREAD];
typedef cub::BlockExchange<char, THREADS, ITEMS_PER_THREAD> BlockExchange;
- __shared__ typename BlockExchange::TempStorage temp_storage;
// we load row after row from the base_position
// Load data row by row
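This hunk drops a BlockExchange::TempStorage that is declared but never used. Depending on the toolchain, an unreferenced __shared__ declaration may still count against the block's static shared-memory budget, so removing it is at worst a cleanup and at best frees shared memory for occupancy. A minimal host-side check (hypothetical, not part of this repository) that reports a kernel's static shared-memory footprint:

#include <cstdio>
#include <cuda_runtime.h>

// stand-in for a tile kernel like the one above; the array size mirrors
// smem_data[32*33*ITEMS_PER_THREAD] with ITEMS_PER_THREAD = 8
__global__ void tile_kernel_sketch(char *out)
{
    __shared__ char smem_data[32 * 33 * 8];
    smem_data[threadIdx.x] = (char)threadIdx.x;
    __syncthreads();
    out[threadIdx.x] = smem_data[threadIdx.x];
}

int main()
{
    cudaFuncAttributes attr;
    cudaFuncGetAttributes(&attr, tile_kernel_sketch);
    // sharedSizeBytes reports statically declared __shared__ storage,
    // which is where an extra cub TempStorage would show up
    printf("static shared memory: %zu bytes\n", attr.sharedSizeBytes);
    return 0;
}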
@@ -2446,7 +2445,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
#define MAX_SPARSE_COUNT 32
#define SMEM_SIZE 8*256
template <typename T, int SPMM_ITEMS, int BITS>
-__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
+__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
{
// 0. load balancing: We process rows with most columns first (count_vec) and we process one row per block
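The signature change above qualifies dequant_stats with __restrict__ (it does not alias the other pointer arguments) and makes the pointer itself const. Presumably this is what lets the explicit __ldg() be dropped in the next hunk; a short sketch contrasting the two forms follows the final hunk below.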
@@ -2500,7 +2499,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
{
for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x)
if((idx_col_B+i-local_idx_col_B_offset) < colsB)
- smem_dequant_stats[i] = __ldg(&dequant_stats[idx_col_B+i-local_idx_col_B_offset]);
+ smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset];
__syncthreads();
}
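Taken together, the last two hunks trade an explicit __ldg() for pointer qualifiers on dequant_stats: __ldg() forces a load through the read-only data cache, while a const __restrict__ parameter gives the compiler enough information to choose that path itself. Note that the merged signature uses float * __restrict__ const (the pointer itself is const); the sketch below (hypothetical kernels, not from this repository) uses the fully const-qualified form, which is the textbook condition for the automatic variant.

__global__ void gather_stats_explicit(const float *stats, float *out, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        out[i] = __ldg(&stats[i]);   // explicit read-only-cache load
}

__global__ void gather_stats_qualified(const float * __restrict__ stats,
                                       float * __restrict__ out, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        out[i] = stats[i];           // compiler may emit an LDG here because the
                                     // pointer is const and known not to alias
}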