summary refs log tree commit diff
path: root/csrc/kernels.cu
diff options
context:
space:
mode:
Diffstat (limited to 'csrc/kernels.cu')
-rw-r--r-- csrc/kernels.cu 874
1 files changed, 874 insertions, 0 deletions
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index d0aabff..1c3e723 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -1737,10 +1737,884 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
}
}
+// Computes per-row and per-column absolute-maximum statistics of the row-major matrix A
+// (indexed as A[row*cols + col]) and, when SPARSE_DECOMP != 0, counts per row the entries whose
+// magnitude is >= nnz_threshold. Such entries are zeroed before they enter the max statistics.
+//
+// Work split: 1D grid; block b processes a tile of TILE_ROWS rows starting at
+//   base_row = ((b*TILE_COLS)/tiledCols)*TILE_ROWS,  base_col = (b*TILE_COLS) % tiledCols
+// and loads THREADS*ITEMS_PER_THREAD elements per row. The CUB collectives are specialised for
+// THREADS threads, so the kernel must be launched with blockDim.x == THREADS.
+//
+// Parameters:
+//   A              input matrix
+//   rowStats       per-row maxima; entries are only ever raised (atomicMax), so the caller is
+//                  expected to initialise the buffer beforehand -- TODO confirm initial value
+//   colStats       per-column maxima; same update rule as rowStats
+//   nnz_count_row  SPARSE_DECOMP only: entry [b*TILE_ROWS + r + 1] receives the outlier count of
+//                  tile-row r of block b (the +1 shift presumably leaves slot 0 free for a later
+//                  prefix sum -- TODO confirm against the caller)
+//   nnz_threshold  magnitude at or above which a value is counted as an outlier
+//   rows, cols     matrix dimensions
+//   tiledRows      not used by this kernel
+//   tiledCols      column extent used for the block -> tile mapping (presumably cols rounded up
+//                  to a multiple of TILE_COLS -- TODO confirm)
+template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols)
+{
+  // 0. reset stats to -FLT_MAX
+  // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD)
+  // 2. compute col max (per thread); kept in registers
+  // 3. compute row max (per block); store in smem to accumulate full global mem transaction
+  // 4. store data via atomicMax
+
+  // each block loads TILE_COLs columns and TILE_ROW rows
+  // after reading a tile the row counter increase by TILE_ROWS
+  // the col counter reset after reading TILE_COL elements
+  const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
+  // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
+  const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
+  const int base_idx = (base_row*cols) + base_col;
+  const int items_per_load = ITEMS_PER_THREAD*THREADS;
+
+  typedef cub::BlockLoad<T, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadT;
+  typedef cub::BlockReduce<float, THREADS> BlockRowReduce;
+  typedef cub::BlockReduce<int, THREADS> BlockRowSum;
+  typedef cub::BlockExchange<float, THREADS, ITEMS_PER_THREAD> BlockExchange;
+
+  // the collectives are never live at the same time, so their scratch space is shared
+  __shared__ union {
+    typename BlockExchange::TempStorage exchange;
+    typename BlockRowReduce::TempStorage rowreduce;
+    typename BlockRowSum::TempStorage rowsum;
+    typename LoadT::TempStorage loadt;
+  } temp_storage;
+
+  __shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS];
+  __shared__ int smem_row_nnz_values[TILE_ROWS];
+
+  half local_data[ITEMS_PER_THREAD];
+  float local_data_fp32[ITEMS_PER_THREAD];
+  float local_col_absmax_values[ITEMS_PER_THREAD];
+  int local_row_nnz_count = 0;
+  float row_absmax = -FLT_MAX;
+
+  // 0. reset stats to -FLT_MAX
+  for(int j = 0; j < ITEMS_PER_THREAD; j++)
+    smem_row_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX;
+
+  // smem_row_nnz_values only holds TILE_ROWS entries; zero exactly those.
+  // (Indexing it like the absmax buffer above would write past the end of the array.)
+  for(int j = threadIdx.x; j < TILE_ROWS; j += THREADS)
+    smem_row_nnz_values[j] = 0;
+
+  #pragma unroll ITEMS_PER_THREAD
+  for(int j = 0; j < ITEMS_PER_THREAD; j++)
+    local_col_absmax_values[j] = -FLT_MAX;
+
+  __syncthreads();
+
+  int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col;
+  int i = base_idx;
+  // we load row after row from the base_position
+  // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD)
+  for(int row = 0; row < TILE_ROWS; row++)
+  {
+    // base_row and row are identical for every thread of the block, so this exit is
+    // block-uniform and the barriers below are never reached by only part of the block
+    if(base_row+row >= rows){ break; }
+    local_row_nnz_count = 0;
+    i = base_idx + ((row)*cols);
+    // each thread gets data from the same column
+    __syncthreads();
+    LoadT(temp_storage.loadt).Load(&(A[i]), local_data, valid_items, __float2half(0.0f));
+
+    #pragma unroll ITEMS_PER_THREAD
+    for(int j = 0; j < ITEMS_PER_THREAD; j++)
+      local_data[j] = fabsf(local_data[j]);
+
+    if(SPARSE_DECOMP)
+    {
+      #pragma unroll ITEMS_PER_THREAD
+      for(int j = 0; j < ITEMS_PER_THREAD; j++)
+      {
+        if((float)local_data[j] >= nnz_threshold)
+        {
+          // outliers are counted and excluded from the absmax statistics
+          local_row_nnz_count += 1;
+          local_data[j] = 0.0f;
+        }
+      }
+    }
+
+    // 2. compute col max (per thread)
+    #pragma unroll ITEMS_PER_THREAD
+    for(int j = 0; j < ITEMS_PER_THREAD; j++)
+      // take the col max for this row
+      local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j]));
+
+    // 3. compute row max (per block); store in smem to accumulate full global mem transaction
+    __syncthreads();
+
+    // this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units)
+    #pragma unroll ITEMS_PER_THREAD
+    for(int j = 0; j < ITEMS_PER_THREAD; j++)
+      local_data_fp32[j] = local_data[j];
+
+    row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, cub::Max());
+    if(SPARSE_DECOMP)
+    {
+      // temp_storage is a union: barrier before the next collective reuses it
+      __syncthreads();
+      local_row_nnz_count = BlockRowSum(temp_storage.rowsum).Sum(local_row_nnz_count);
+    }
+    // we store the data temporarily in shared memory so we
+    // can execute a full atomic block transaction into global memory later
+    // (the reduction results are only valid in thread 0)
+    if(threadIdx.x == 0)
+    {
+      smem_row_absmax_values[(row % ITEMS_PER_THREAD) + ((row/ITEMS_PER_THREAD)*ITEMS_PER_THREAD)] = row_absmax;
+      // each blockIdx.x process 16 rows and 64*4=256 columns -> we sum nnz over 256 columns and have 16 values per block
+      smem_row_nnz_values[row] = local_row_nnz_count;
+    }
+
+    __syncthreads();
+
+  }
+
+  // 4. store data via atomicMax
+  // to store col data efficiently we need to rewrite the blocked data [0, 1, 2, 3...] for t0
+  // into a striped arrangement: [0, 8, 16, 24, ..] for t0
+  __syncthreads();
+  BlockExchange(temp_storage.exchange).BlockedToStriped(local_col_absmax_values);
+
+  #pragma unroll ITEMS_PER_THREAD
+  for(int j = 0; j < ITEMS_PER_THREAD; j++)
+    if(base_col+threadIdx.x+(j*THREADS) < cols)
+    {
+      // plain read first so the atomic is only issued when it can change the value
+      float val = colStats[base_col+(threadIdx.x+(j*THREADS))];
+      if(val < local_col_absmax_values[j])
+        atomicMax(&colStats[base_col+(threadIdx.x+(j*THREADS))], local_col_absmax_values[j]);
+    }
+
+  for(int j = 0; j < ITEMS_PER_THREAD; j++)
+    if(base_row+threadIdx.x+(j*THREADS) < rows)
+    {
+      // slots that no row of this tile wrote still hold -FLT_MAX and therefore never update rowStats
+      float val = rowStats[base_row+(threadIdx.x+(j*THREADS))];
+      if(val < smem_row_absmax_values[threadIdx.x+(j*THREADS)])
+        atomicMax(&rowStats[base_row+(threadIdx.x+(j*THREADS))], smem_row_absmax_values[threadIdx.x+(j*THREADS)]);
+    }
+
+  if(SPARSE_DECOMP)
+    if(threadIdx.x < TILE_ROWS)
+      nnz_count_row[blockIdx.x*TILE_ROWS+threadIdx.x+1] = smem_row_nnz_values[threadIdx.x];
+
+}
+
+template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 0>(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); // THREADS=64, ITEMS_PER_THREAD=4, TILE_ROWS=16, TILE_COLS=256, SPARSE_DECOMP off
+template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 1>(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); // same tiling with SPARSE_DECOMP on (outliers >= nnz_threshold are counted per row)
+
+#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f); scale applied to int32 values in kdequant_mm_int32_fp16
+
+// Dequantizes an int32 matrix A into a row-major half matrix:
+//   out[row*numCols + col] = A_value * MM_DEQUANT_CONST * rowStats[row] * colStats[col]
+//
+// Layout / launch expectations:
+//   * A is read in 32-column tiles of numRows rows each (tile t starts at t*numRows*32).
+//   * 1D grid; block b handles SUBTILE_ROWS rows of column tile b/num_row_tiles, where
+//     num_row_tiles = ceil(numRows / SUBTILE_ROWS).
+//   * The CUB collectives are specialised for THREADS threads, so blockDim.x must equal THREADS.
+//   * smem_rowStats has SUBTILE_ROWS entries and is indexed by row offsets that advance by
+//     THREADS*ITEMS_PER_THREAD/32 per iteration -- presumably SUBTILE_ROWS*32 is a multiple of
+//     THREADS*ITEMS_PER_THREAD (TODO confirm for every instantiation).
+//
+// Parameters:
+//   A                        int32 input
+//   rowStats, colStats       per-row / per-column scaling statistics
+//   out                      half output, row-major with numCols columns
+//   newRowStats, newcolStats not used by this kernel
+//   numRows, numCols         logical matrix dimensions
+//   tileCols, n              not used by this kernel
+template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n)
+{
+
+  // Strategy: To dequantize we need to load col/row statistics. This can be very expensive
+  // since different row/col stats need to be loaded with each thread.
+  // (1, bad algorithm) Loading 32 items per thread would only incur 1 row load, but this increases register pressure
+  // and would lead to low global load utilization.
+  // (2, bad algorithm) If each thread loads some columns and multiple rows one needs to do lot of row loads
+  // for each thread and this is duplicated by a factor of 32/num-cols-per-thread.
+  // (3, good algorithm) Combining (1) and (2) we use sub-tiles of size 32xk in shared memory per threadblock.
+  // This allows for efficient row/col loading from shared memory within the tile.
+  // We can run for example 32x128 sub-tiles and warp-strided loads of 4 elements so that each thread has
+  // the same col statistic but needs to load 4 row stats from shared memory.
+
+  // data is in 32 column-tile major with tile width 32 columns and numRows rows
+  // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
+  // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3])
+  // C1. Compute val(row_stat*col_stat)/(127*127) (load 1/(127*127) into register)
+  // S1. Store C1 into 16-bit output
+
+  // We allow for sub-tiles to span multiple col32 tiles. This is okay
+  // since the items per thread only rely on a single column statistic.
+
+
+  const int n_out = numRows*numCols;
+
+  int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1);
+  // we have tiles of size numRows*32, thus col only increases every numRows
+  // num_row_tiles is the tiles after which the column increases by 32
+  // blockIdx.x is the index of the current tile
+  int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32));
+  // base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached
+  int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS);
+
+  // SUBTILE_ROWS is independent from ITEMS_PER_THREAD is independent from THREADS
+  // subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD
+  // Total subtiles should be n/(32*SUBTILE_ROWS) where each subtile has SUBTILE_ROW*32/4 threads.
+  // For example for a 1024x1024 matrix with 128 SUBTILE_ROWS and 4 ITEMS_PER_THREAD we have
+  // 1024*1024/(128*32) = 256 tiles
+  // 256 tiles are 256*128*32/4 = 256*1024 threads
+
+  // 1. Figure out how index relates to the start of the sub-tile
+  // 2. Each thread < SUBTILE_ROWS calculates row index
+  // 3. Load striped and store in shared memory
+
+  int local_values[ITEMS_PER_THREAD];
+  half local_output[ITEMS_PER_THREAD];
+  float local_rowStats[ITEMS_PER_THREAD];
+  __shared__ float smem_rowStats[SUBTILE_ROWS];
+
+  typedef cub::BlockLoad<int, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_DIRECT> LoadInt32;
+  typedef cub::BlockExchange<int, THREADS, ITEMS_PER_THREAD> ExchangeInt32;
+  __shared__ typename LoadInt32::TempStorage loadint32;
+  __shared__ typename ExchangeInt32::TempStorage exchangeint32;
+
+
+  // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
+  float colStat = col >= numCols ? 0.0f : colStats[col];
+  // no block loads for rows for now -- keep it simple
+  for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x)
+  {
+    // rows past the end wrap around so the read stays in bounds; the matching
+    // outputs are never stored because of the outIdx < n_out guard below
+    int row = (base_row+j) % numRows;
+    // rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consecutive elements
+    smem_rowStats[j] = rowStats[row];
+  }
+  __syncthreads();
+
+
+  // each block processes SUBTILE_ROWS*32 elements
+  const int items_per_load = THREADS*ITEMS_PER_THREAD;
+  const int rows_per_load = items_per_load/32;
+
+  int subtile_base_row = (threadIdx.x / 32)*ITEMS_PER_THREAD; // row within the tile
+  int row_offset = 0;
+  // subtile_idx starts at the base_row*32 + the total offset for a full numRow*32 tile is passed
+  int subtile_start = (blockIdx.x/num_row_tiles)*(numRows*32) + (base_row*32);
+  for(int subtile_idx = subtile_start; subtile_idx < subtile_start + (SUBTILE_ROWS*32); subtile_idx+=items_per_load)
+  {
+    int valid_rows = numRows - (base_row+row_offset) > rows_per_load ? rows_per_load : numRows - (base_row+row_offset);
+    int valid_items = valid_rows*32;
+    // the sub-tile might have more elements than the tile itself.
+    // valid_items depends only on block-wide values, so this exit is uniform across the block.
+    if(valid_items <= 0)
+      break;
+
+    // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3])
+    LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0);
+    ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values);
+
+    #pragma unroll ITEMS_PER_THREAD
+    for(int j = 0; j < ITEMS_PER_THREAD; j++)
+      local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j];
+
+    #pragma unroll ITEMS_PER_THREAD
+    for(int j = 0; j < ITEMS_PER_THREAD; j++)
+      local_output[j] = __float2half(local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat);
+
+    // we store data in row major
+    // to store data efficiently, we would want to use block exchange: [0, 32, 64, 96] -> [0, 1, 2, 3]
+    // so that each thread holds ITEMS_PER_THREAD consecutive items for each row
+    // for now we use a simple store
+    #pragma unroll ITEMS_PER_THREAD
+    for(int j = 0; j < ITEMS_PER_THREAD; j++)
+    {
+      int outIdx = col + ((base_row+subtile_base_row+row_offset+j)*numCols);
+      if(outIdx< n_out && col < numCols)
+        out[outIdx] = local_output[j];
+    }
+
+    row_offset += rows_per_load;
+
+    // exchangeint32 is reused by the next iteration's BlockedToWarpStriped; CUB requires a
+    // barrier before a collective's temporary storage is reused.
+    __syncthreads();
+  }
+}
+
+
+// Quantizes the row-major half matrix A (indexed as A[row*cols + col]) to int8 twice:
+//   out_row_normed[idx] = rint(A[idx] * 127 / rowStats[row])
+//   out_col_normed[idx] = rint(A[idx] * 127 / colStats[col])
+// With SPARSE_DECOMP != 0, entries with |A| >= threshold are written as 0 in out_row_normed and
+// appended as (row, col, value) triples to rowidx/colidx/val. out_col_normed is computed from the
+// unmodified value in both modes.
+//
+// Work split: 1D grid; block b processes TILE_ROWS rows starting at
+//   base_row = ((b*TILE_COLS)/tiledCols)*TILE_ROWS,  base_col = (b*TILE_COLS) % tiledCols
+// and THREADS*ITEMS_PER_THREAD columns per row. The CUB collectives are specialised for THREADS
+// threads, so blockDim.x must equal THREADS.
+//
+// Parameters:
+//   A                      half input
+//   rowStats, colStats     per-row / per-column scaling statistics (used as divisors)
+//   out_col_normed         int8 output scaled by column statistics
+//   out_row_normed         int8 output scaled by row statistics
+//   rowidx, colidx, val    SPARSE_DECOMP only: coordinate and value arrays of the outliers
+//   nnz_block_ptr          SPARSE_DECOMP only: entry [TILE_ROWS*b + r] is the first write position
+//                          into rowidx/colidx/val for tile-row r of block b (presumably a prefix
+//                          sum of per-row outlier counts -- TODO confirm against the caller)
+//   threshold              outlier magnitude threshold
+//   rows, cols             matrix dimensions
+//   tiledCols              column extent used for the block -> tile mapping
+template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols)
+{
+  // assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD
+  // Each thread reads the same columns for every row of the tile
+  // Row stats are staged in shared memory and shared across the threadblock (broadcast)
+
+  // 0. Load row stats data into shared memory; load col stats (fixed per thread)
+  // 1. Load data row by row (should be at least with TILE_SIZE = 512)
+  // 2. quantize data with row/col stats
+  // 3. Store data (TILE_SIZE = 512 is a bit slow, but should still be close enough to good performance)
+
+  // each block loads TILE_COLs columns and TILE_ROW rows
+  // after reading a tile the row counter increase by TILE_ROWS
+  // the col counter reset after reading TILE_COL elements
+  const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
+  // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
+  const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
+  const int base_idx = (base_row*cols) + base_col;
+  const int items_per_load = ITEMS_PER_THREAD*THREADS;
+
+  typedef cub::BlockLoad<half, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadHalf;
+  __shared__ typename LoadHalf::TempStorage loadhalf;
+  typedef cub::BlockStore<char, THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_VECTORIZE> StoreInt8;
+  __shared__ typename StoreInt8::TempStorage storeint8;
+
+  __shared__ float smem_row_stats[TILE_ROWS];
+  __shared__ unsigned int smem_nnz_row_idx[TILE_ROWS];
+
+  half local_data[ITEMS_PER_THREAD];
+  float local_col_stats[ITEMS_PER_THREAD];
+  char local_quantized_data[ITEMS_PER_THREAD];
+
+  // 0. Load row stats data into shared memory; load col stats (fixed per thread)
+  #pragma unroll ITEMS_PER_THREAD
+  for(int j = 0; j < ITEMS_PER_THREAD; j++)
+  {
+    const int col = base_col+(threadIdx.x*ITEMS_PER_THREAD)+j;
+    // columns past the matrix edge get a defined scale of 0: the value is still read by the
+    // column quantization below even though the resulting item is never stored
+    local_col_stats[j] = col < cols ? __fdividef(127.0f, colStats[col]) : 0.0f;
+  }
+
+  for(int i = threadIdx.x; i < TILE_ROWS; i+=blockDim.x)
+  {
+    if(base_row + i < rows)
+      smem_row_stats[i] = rowStats[base_row+i];
+
+    if(SPARSE_DECOMP)
+      smem_nnz_row_idx[i] = nnz_block_ptr[(TILE_ROWS*blockIdx.x) + i];
+  }
+  __syncthreads();
+
+  // number of in-bounds columns for this block; identical for every row of the tile
+  const int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col;
+
+  // we load row after row from the base_position
+  // 1. Load data row by row (should be at least with TILE_SIZE = 512)
+  for(int row = 0; row < TILE_ROWS; row++)
+  {
+    // block-uniform exit, so the barrier further down is never reached by only part of the block
+    if(base_row + row >= rows){ break; }
+    int i = base_idx + (row*cols);
+
+    LoadHalf(loadhalf).Load(&(A[i]), local_data, valid_items, 0.0f);
+    float row_stat = __fdividef(127.0f, smem_row_stats[row]);
+
+    // 2. quantize data with row stats
+    #pragma unroll ITEMS_PER_THREAD
+    for(int j = 0; j < ITEMS_PER_THREAD; j++)
+    {
+      // we already pre-normalized the col/row stat:
+      // what this does is float/absmax*127 = int8
+      if(SPARSE_DECOMP)
+      {
+        if(fabsf((float)local_data[j]) >= threshold)
+        {
+          local_quantized_data[j] = 0;
+
+          // claim the next free slot of this row in the sparse output arrays
+          int old_idx = atomicInc(&smem_nnz_row_idx[row], UINT_MAX);
+
+          rowidx[old_idx] = base_row+row;
+          colidx[old_idx] = base_col+(threadIdx.x*ITEMS_PER_THREAD)+j;
+          val[old_idx] = local_data[j];
+        }
+        else
+        {
+          local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat));
+        }
+      }
+      else
+        local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat));
+    }
+
+    StoreInt8(storeint8).Store(&(out_row_normed[i]), local_quantized_data, valid_items);
+
+    // 2. quantize data with col stats
+    #pragma unroll ITEMS_PER_THREAD
+    for(int j = 0; j < ITEMS_PER_THREAD; j++)
+    {
+      // we already pre-normalized the col/row stat:
+      // what this does is float/absmax*127 = int8
+      local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*local_col_stats[j]));
+    }
+
+    __syncthreads();
+    StoreInt8(storeint8).Store(&(out_col_normed[i]), local_quantized_data, valid_items);
+
+  }
+}
+
+template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols)
+{
+
+ // Purpose: copies the row-major char matrix A (rows x cols, indexed as A[row*cols + col]) into `out` using the tiled layout selected by FORMAT (COL32, COL_TURING or COL_AMPERE), writing the transposed matrix when TRANSPOSE != 0. Steps: 0. Load data into 32*32 shared memory tiles
+ // 1. transpose / reorder in shared memory
+ // 2. store
+
+ // COL32 FORMAT:
+ // rows*32 tiles
+
+ // TURING FORMAT:
+ // 8*32 tiles with 4*4 subtiles
+ // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements)
+ // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero
+ // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
+ // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
+ // index increases by 32
+
+ // AMPERE FORMAT:
+ // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows:
+ // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
+ // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32]
+
+
+ // To have efficient loads and stores if we transpose we need 128 consecutive bytes which at 1 byte are 128 values
+ // As such we need:
+ // at least 32*4 shared memory tiles for col32; preferably 32*32
+ // at least 32*6 shared memory tiles for col32_ampere: preferably 32*32
+ // at least 32*8 shared memory tiles for col4_turing: preferably 32*32
+ // for efficient loading of row major we need to load 128 elements and repeat this 32 items
+ // this would imply a 32x128 shared memory tile -> 4kb
+ // It is more efficient to have more than 1 warp, so with 64 threads we need 32x128 -> 8 kb
+ // we have 64k shared mem per SM in Turing which is 8 blocks per SM which is 2*8 = 32 warps = 100% occupancy
+ // for turing and 50% for A100 and 75% for RTX 30s / A40 which is probably good enough
+ // register pressure should be low with: 8 registers from local memory per block and 64 registers per SM
+ //
+ // to make the shared memory work with that occupancy we might need to union the block loads/stores
+
+ // each block loads TILE_COLs columns and TILE_ROW rows
+ // after reading a tile the row counter increase by TILE_ROWS
+ // the col counter reset after reading TILE_COL elements
+ const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
+ // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
+ const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
+ const int base_idx = (base_row*cols) + base_col;
+
+ // we load 128 bytes per warp with
+ // 32 rows for transposes that fill col32 types
+ // so that we can have contiguous stores
+ __shared__ char smem_data[32*33*ITEMS_PER_THREAD]; // the TRANSPOSE path indexes this with a row stride of 33, the non-transposed path with 32*ITEMS_PER_THREAD
+ char local_data[ITEMS_PER_THREAD];
+ typedef cub::BlockExchange<char, THREADS, ITEMS_PER_THREAD> BlockExchange;
+ __shared__ typename BlockExchange::TempStorage temp_storage; // NOTE(review): declared but never referenced in this kernel body
+
+ // we load row after row from the base_position
+ // Load data row by row
+ int warps = blockDim.x/32;
+ int warp_id = threadIdx.x/32;
+ int warp_lane = threadIdx.x % 32;
+ int offset = 0;
+
+ int smem_row = 0; // number of tile rows staged in shared memory since the last flush (advances by `warps` per iteration)
+ // each warp loads one row of 128 bytes
+ for(int row = warp_id; row < TILE_ROWS; row+=warps)
+ {
+ int i = base_idx + (row*cols);
+ // we load up to 128 bytes/items per load
+ int valid_items = cols - base_col > 32*ITEMS_PER_THREAD ? 32*ITEMS_PER_THREAD : cols - base_col;
+
+ // 0. Load data into 32*32 shared memory tiles (rows or columns outside the matrix are staged as zeros)
+ if(base_row + row < rows)
+ {
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ {
+ int col_idx = warp_lane+(j*32);
+ if(col_idx < valid_items)
+ local_data[j] = A[i+col_idx];
+ else
+ local_data[j] = 0;
+ }
+ }
+ else
+ {
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ local_data[j] = 0;
+ }
+
+ if(TRANSPOSE)
+ {
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ {
+ int local_col = (32*j)+warp_lane;
+ //int local_row = row;
+ // store transposed: source column selects the smem row (stride 33), source row selects the smem column
+ smem_data[(local_col*33) + row] = local_data[j];
+ }
+ }
+ else
+ {
+ // treat smem as 32 rows with 32*ITEMS_PER_THREAD columns each
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ smem_data[row*32*ITEMS_PER_THREAD + (warp_lane) + (j*32)] = local_data[j];
+ }
+
+
+
+ smem_row += warps;
+
+ // 1. transpose / reorder in shared memory
+ if(smem_row % 32 == 0) // NOTE(review): the barrier below is only block-uniform if every warp runs the same number of loop iterations (TILE_ROWS a multiple of warps) -- confirm for all instantiations
+ {
+ smem_row = 0;
+ __syncthreads();
+
+ for(int subrow = warp_id; subrow < 32; subrow+=warps)
+ {
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ {
+
+ switch(FORMAT)
+ {
+ case COL32:
+ if(TRANSPOSE)
+ {
+ // data lies in shared memory in the following way:
+ // row0 [col0 col1 ... col31]
+ // row1 [col0 col1 ... col31]
+ // ...
+ //
+ // As such we read consecutive entries with 256 threads (8rows x 32 columns)
+ // as j increase, the row increase by a factor of 8
+ // We load 8 rows per subrow loop, and subrow increase by 8 per loop
+ // so we have an offset of 8 rows every loop or (subrow/warps)*8 = (subrow/8)*8
+ const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
+ const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
+ //const int local_row = warp_id; // each warp_id is one row
+ //const int block_row = base_col; // block offset for row
+ //const int local_col = warp_lane
+ //const int global_col = base_row; // block offset for col
+ if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
+ {
+ // each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem
+ char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
+
+ // each 32 columns we have new tile
+ // each tile has size outRows*32 and base_row is done in increments of 32
+ offset = base_row*outRows;
+ out[offset + (base_col + jrow + subrow_loop_row)*32 + threadIdx.x] = data;
+ }
+ }
+ else
+ {
+ if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
+ {
+ offset = (base_col/32)*(32*rows);
+ char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
+ out[offset+(base_row+subrow)*32 + ((j)*rows*32)+warp_lane] = data;
+ }
+ }
+ break;
+ case COL_TURING:
+ // TURING FORMAT:
+ // 8*32 tiles with 4*4 subtiles
+ // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements)
+ // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero
+ // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
+ // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
+ // index increases by 32
+ //
+ // [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...]
+ if(TRANSPOSE)
+ {
+ const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
+ const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
+ //const int local_row = warp_id; // each warp_id is one row
+ //const int block_row = base_col; // block offset for row
+ //const int local_col = warp_lane
+ //const int global_col = base_row; // block offset for col
+ if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
+ {
+ // each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem
+ char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
+
+ // each 32 columns we have new tile
+ // each tile has size 8*32 = 256 elements offset
+ // for each row offset of 8 we increase the tile first
+ // after all rows are exhausted, we increase the col
+ int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/8)*256; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows
+
+ // we increase by row_tile_column every 32 columns
+ // base_row increase in increments of 32
+ //int row_tile_column = 256*outRows/8; // there are outRows/8 row tiles, and each tile is 256 elements
+ //int col_offset = (base_row/32)*row_tile_column;
+ // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
+ // 256*outRows/8*base_row/32 = outRows*base_row
+ int col_offset = outRows*base_row;
+
+ offset = row_offset+col_offset;
+
+ // since we process even number of rows with each j (8) and with each subrow (8j) we can determine
+ // odd or even rows with the warp_id (each warp processes one row)
+ // the col is warp_lane (max 32 columns per row) and the row warp_id
+ if(warp_id % 2 == 1)
+ // odd
+ offset += 128 + (warp_lane/4)*16 + (warp_lane%4) + (((warp_id%8)-1)*2);
+ else
+ // even
+ offset += 0 + (warp_lane/4)*16 + (warp_lane%4) + ((warp_id%8)*2);
+
+ out[offset] = data;
+ }
+ }
+ else
+ {
+ if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
+ {
+ char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
+ // set offset designates the tile offset among the 8*32 tiles
+ // we first increase rows and then columns. Since we load 128 columns at once
+ // we increase the offset by outRows*32 every 32 columns
+ // additionally, we increase the offset by 8*32=256 every 8 rows
+ offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/8)*256); // global offset (8x32 tile)
+ // first 4 rows are reserved for even rows, [0, 2, 4, 6], the next 4 for odd
+ // each of these has 32 values in total for 32*4 = 128 as offset if odd
+ // every set of 4 columns increases the total offset by 16
+ // each even row increase the offset by 4, for example row 2 is offset by 4, 4 by 6 etc so: subrow/2*4 = subrow*2
+ // this happens every 8 rows anew (subrow % 8)
+ // one writes 4 columns at once that is (col % 4) for the particular index in the subtile
+ int subcol = warp_lane;
+
+ // add local offset (4x4 sub-tile)
+ if(subrow % 2 == 1)
+ // odd
+ offset += 128 + (subcol/4)*16 + (subcol%4) + (((subrow%8)-1)*2);
+ else
+ // even
+ offset += 0 + (subcol/4)*16 + (subcol%4) + ((subrow%8)*2);
+
+ out[offset] = data;
+ }
+ }
+ break;
+ case COL_AMPERE:
+ // AMPERE FORMAT:
+ // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows:
+ // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
+ // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32]
+ if(TRANSPOSE)
+ {
+ const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
+ const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
+ //const int local_row = warp_id; // each warp_id is one row
+ //const int block_row = base_col; // block offset for row
+ //const int local_col = warp_lane
+ //const int global_col = base_row; // block offset for col
+ if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
+ {
+ // each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem
+ char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
+
+ // each 32 columns we have new tile
+ // each tile has size 32*32 = 1024 elements offset
+ // for each row offset of 32 we increase the tile first
+ // after all rows are exhausted, we increase the col
+ int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/32)*1024; // global_row+jrow+subrow_loop_row+local_row, increase tile(=1024) every 32 rows
+
+ // we increase by row_tile_column every 32 columns
+ // base_row increase in increments of 32
+ //int row_tile_column = 1024*outRows/32; // there are outRows/32 row tiles, and each tile is 1024 elements
+ //int col_offset = (base_row/32)*row_tile_column;
+ // -> the divisions are dropped to speed up compute; NOTE(review): this is only exact when outRows and base_row are multiples of 32 -- confirm against the caller
+ // 1024*outRows/32*base_row/32 = outRows*base_row
+ int col_offset = outRows*base_row;
+
+ offset = row_offset+col_offset;
+
+
+ // same as in the non-transpose case (see below)
+ // the difference is that now rows = cols
+ // in this case warp_id = subrow
+
+ // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
+ // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc
+ // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row
+ // every 2 rows, the offset increases by two [0, 1, 8, 9...]
+ // every 2 rows, the row index increase by 8 [0, 1, 8, 9...]
+ int local_row = (jrow + warp_id) % 32; // offset for row > 32 is already calculated into row_offset
+ int ampere_row = ((local_row % 8)/2)*8 + (local_row/8)*2 + (local_row % 2);
+
+ // global offset + row with 32 cols each + 32 cols per j + col_idx=warp_lane
+ out[offset + (ampere_row*32) + warp_lane] = data;
+ }
+ }
+ else
+ {
+ if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
+ {
+ char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
+
+ // set offset designates the tile offset among the 32*32 tiles
+ // we first increase rows and then columns. Since we load 128 columns at once
+ // we increase the offset by outRows*32 every 32 columns
+ // additionally, we increase the offset by 32*32=1024 every 32 rows
+ offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/32)*1024); // global offset (32x32 tile)
+
+ // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
+ // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc
+ // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row
+ // every 2 rows, the offset increases by two [0, 1, 8, 9...]
+ // every 2 rows, the row index increase by 8 [0, 1, 8, 9...]
+ int local_row = ((subrow % 8)/2)*8 + (subrow/8)*2 + (subrow % 2);
+
+ // global offset + row with 32 cols each + 32 cols per j + col_idx
+ out[offset + (local_row*32) + warp_lane] = data;
+ }
+ }
+ break;
+ }
+ }
+ }
+ }
+ }
+}
+
+#define C 1.0f/127.0f // NOTE(review): expansion is unparenthesized, so `x*C` is evaluated as `(x*1.0f)/127.0f` and `x/C` as `(x/1.0f)/127.0f`
+#define MAX_SPARSE_COUNT 32 // size of the per-thread local_valA / local_colidxA arrays in kspmm_coo_very_sparse_naive
+#define SMEM_SIZE 8*256 // element count of smem_dequant_stats in kspmm_coo_very_sparse_naive; NOTE(review): unparenthesized, so e.g. `i % SMEM_SIZE` would parse as `(i % 8)*256`
+// Sparse (COO) x dense product without tensor cores: out[row, :] += A[row, :] * B.
+//
+// Work assignment: block b handles the `max_count[b]` COO entries that start at
+// position `offset` (derived from max_idx[b] and offset_rowidx) and accumulates
+// into output row `rowidx[offset]`. Within a block, each warp covers a contiguous
+// window of 32*SPMM_ITEMS columns of B, and each lane SPMM_ITEMS of them.
+//
+// Template parameters:
+//   T          - element type of B.
+//   SPMM_ITEMS - number of output columns accumulated per thread per iteration.
+//   BITS       - 8 selects the int8 path (optionally dequantized), otherwise 16-bit.
+//
+// Parameters:
+//   max_count, max_idx - per-block entry count and index into offset_rowidx.
+//   offset_rowidx      - read at [max_idx-1] to obtain the first entry of the block
+//                        (presumably an inclusive prefix sum of per-row counts; confirm with caller).
+//   rowidx, colidx, values - COO triplets of A.
+//   B                  - dense matrix indexed as B[colsB*colidx + col].
+//   out                - dense output indexed as out[colsB*row + col], updated in place.
+//   dequant_stats      - optional per-column scales; used only when BITS == 8 and non-NULL.
+//   nnz, rowsA, rowsB  - not read by this kernel.
+//   colsB              - number of columns of B and out.
+//
+// Preconditions (not checked here):
+//   * count <= MAX_SPARSE_COUNT; the compute loop indexes the per-thread arrays up to count-1.
+//   * assumes colsB is a multiple of num_items and B/out are aligned for the
+//     float2/float4 accesses below - TODO confirm against the caller.
+template <typename T, int SPMM_ITEMS, int BITS>
+__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
+{
+
+  // 0. load balancing: we process rows with the most columns first (count_vec) and we process one row per block.
+  //    If a block finishes, the next one is scheduled. Since the last blocks likely have fewer
+  //    elements they finish faster, "filling up" the gaps left by larger blocks.
+
+  // without tensor cores
+  // 1. use rowidx_length to find what to load (as many blocks as there are rows)
+  // 2. Load A into registers
+  // 3. each warp loads all required rows of B but each warp is offset by k
+  // 4. Do mma operations that accumulate into registers
+  // 5. Each warp stores its output row into matrix C
+
+  const int count = max_count[blockIdx.x];
+  const int local_max_idx = max_idx[blockIdx.x];
+  const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1];
+  const int local_row_idx = rowidx[offset];
+
+  const int warp_id = threadIdx.x / 32;
+  const int warp_idx = threadIdx.x % 32;   // lane within the warp
+  const int warp_offset = (warp_id*32)*SPMM_ITEMS;
+  // Elements moved per vectorized access; both branches currently evaluate to 8.
+  const int num_items = BITS == 8 ? 8 : 8;
+  int idx_col_B = warp_offset;
+  int local_idx_col_B_offset = 0;
+
+  half local_valA[MAX_SPARSE_COUNT];
+  int local_colidxA[MAX_SPARSE_COUNT];
+  half local_valC[SPMM_ITEMS];
+  T local_valsB[num_items];
+  half local_valOut[num_items];
+  // 128 byte loads per warp == 4 bytes per thread
+
+  // 2. Load A into registers; unused slots are zero-filled.
+  for(int j = 0; j < MAX_SPARSE_COUNT; j++)
+  {
+    local_valA[j] = j < count ? values[offset+j] : __float2half(0.0f);
+    local_colidxA[j] = j < count ? colidx[offset+j] : 0;
+  }
+
+  // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=8192
+  // we expect each warp to be SPMM_ITEMS*32 apart
+  __shared__ half smem_dequant_stats[SMEM_SIZE];
+
+  // NOTE(review): the loop bound depends on warp_offset, so warps can execute a
+  // different number of iterations while __syncthreads() sits inside the loop body.
+  // Confirm that the launch configuration keeps the trip count uniform per block.
+  while(idx_col_B < colsB)
+  {
+
+    if(dequant_stats != NULL)
+    {
+      // NOTE(review): idx_col_B and local_idx_col_B_offset advance by the same step,
+      // so the source index below is always warp_offset + i, independent of the
+      // iteration. Confirm this is the intended window of dequant_stats.
+      for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x)
+        if((idx_col_B+i-local_idx_col_B_offset) < colsB)
+          smem_dequant_stats[i] = __ldg(&dequant_stats[idx_col_B+i-local_idx_col_B_offset]);
+
+      __syncthreads();
+    }
+
+    // Reset the per-thread accumulators for this column window.
+    #pragma unroll SPMM_ITEMS
+    for(int j = 0; j < SPMM_ITEMS; j++)
+      local_valC[j] = 0.0f;
+
+    #pragma unroll
+    for(int i = 0; i < count; i++)
+    {
+      // 3. each warp loads all required rows of B but each warp is offset by k
+      int row_offset = colsB*local_colidxA[i];
+
+      #pragma unroll SPMM_ITEMS
+      for(int j = 0; j < SPMM_ITEMS; j+=num_items)
+      {
+        // 4. Multiply the tile -> accumulate outputs in registers, num_items columns at a time
+        int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j;
+        if(idx >= colsB){ break; }
+        if((idx+num_items < colsB))
+        {
+          // Vectorized load of num_items elements of B (8 bytes for int8, 16 bytes for half).
+          if(BITS == 8)
+            reinterpret_cast<float2(&)[num_items]>(local_valsB)[0] = reinterpret_cast<float2*>(B)[(row_offset+ idx)/num_items];
+          else
+            reinterpret_cast<float4(&)[num_items]>(local_valsB)[0] = reinterpret_cast<float4*>(B)[(row_offset+ idx)/num_items];
+        }
+        else
+        {
+          // Tail of the row: element-wise load, zero-padding past colsB.
+          #pragma unroll num_items
+          for(int k = 0; k < num_items; k++)
+            if(idx+k < colsB)
+              local_valsB[k] = B[row_offset+idx+k];
+            else
+              local_valsB[k] = 0.0f;
+        }
+        #pragma unroll num_items
+        for(int k = 0; k < num_items; k++)
+        {
+          if(BITS == 8 && dequant_stats != NULL)
+          {
+            // int8 path: scale the product by the staged column statistic and C.
+            // The shared-buffer index is warp_offset + warp_idx*SPMM_ITEMS + j + k and is
+            // not checked against SMEM_SIZE - assumes blockDim.x*SPMM_ITEMS <= SMEM_SIZE, TODO confirm.
+            float valB = local_valsB[k];
+            float valA = local_valA[i];
+            if(valB != 0.0f && valA != 0.0f)
+              local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*C*valB*valA;
+          }
+          else
+            local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i];
+        }
+      }
+    }
+
+    int idx_row_C = (colsB*local_row_idx);
+
+    // 5. Add the accumulators to the output row in place.
+    #pragma unroll SPMM_ITEMS
+    for(int j = 0; j < SPMM_ITEMS; j+=num_items)
+    {
+      int idx_col_C = idx_col_B + warp_idx*SPMM_ITEMS + j;
+      int idx_val = idx_col_C + idx_row_C;
+
+      if(idx_col_C +num_items < colsB)
+      {
+
+        // load outputs to do inplace addition
+        reinterpret_cast<float4(&)[num_items/4]>(local_valOut)[0] = reinterpret_cast<float4*>(out)[idx_val/num_items];
+
+        // The float4 store below writes accumulators j..j+num_items-1, so the existing
+        // output values must be added to exactly those elements.
+        #pragma unroll num_items
+        for(int k = 0; k < num_items; k++)
+          local_valC[j + k] = (float)local_valC[j + k] + (float)local_valOut[k];
+
+        reinterpret_cast<float4*>(out)[idx_val/num_items] = reinterpret_cast<float4(&)[num_items]>(local_valC)[j/num_items];
+      }
+      else
+      {
+        // Tail of the row: element-wise read-modify-write, bounded by colsB.
+        #pragma unroll num_items
+        for(int k = 0; k < num_items; k++)
+          if(idx_col_C + k < colsB)
+            out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k];
+      }
+    }
+
+    // Advance to the next window of blockDim.x*SPMM_ITEMS columns.
+    idx_col_B += blockDim.x*SPMM_ITEMS;
+    local_idx_col_B_offset += blockDim.x*SPMM_ITEMS;
+  }
+}
+
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
+template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); // kspmm_coo_very_sparse_naive: B as half (BITS=16) or signed char (BITS=8), each with SPMM_ITEMS of 8, 16 and 32
+template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+// kTransformRowToFormat: the last template argument takes COL32, COL_TURING and COL_AMPERE; each is instantiated with the fifth argument set to 0 and to 1 (presumably a transpose flag - confirm against the template declaration).
+template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+// kdequant_mm_int32_fp16: a single configuration <4, 128, 512> is instantiated.
+template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n);
+// kDoubleRowColQuant: two configurations that differ only in the last template argument (0 and 1).
+template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
+template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
+
template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x);
template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x);