summaryrefslogtreecommitdiff
path: root/csrc
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-07-22 15:21:37 -0700
committerTim Dettmers <tim.dettmers@gmail.com>2022-07-22 15:21:37 -0700
commit7d2ecd30c044840ba5f161ec73e5eaf30ac8131d (patch)
treefa76a8513df9c088478870226048fbf32e9e0d5d /csrc
parentc771b3a75a6ebbfbfc398a028a477246b0799cf0 (diff)
Fixed rowcol synchronization bug.
Diffstat (limited to 'csrc')
-rw-r--r--csrc/kernels.cu4
1 files changed, 2 insertions, 2 deletions
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 1c3e723..4e744fb 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -1768,7 +1768,6 @@ template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_
__shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS];
__shared__ int smem_row_nnz_values[TILE_ROWS];
- //__shared__ float smem_col_absmax_values[ITEMS_PER_THREAD*THREADS];
half local_data[ITEMS_PER_THREAD];
float local_data_fp32[ITEMS_PER_THREAD];
@@ -1828,13 +1827,14 @@ template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_
local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j]));
// 3. compute row max (per block); store in smem to accumulate full global mem transation
- __syncthreads();
// this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units)
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
local_data_fp32[j] = local_data[j];
+ __syncthreads();
+
row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, cub::Max());
if(SPARSE_DECOMP)
{