From c771b3a75a6ebbfbfc398a028a477246b0799cf0 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Fri, 22 Jul 2022 14:41:05 -0700 Subject: Most tests passing. --- csrc/kernels.cu | 874 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 874 insertions(+) (limited to 'csrc/kernels.cu') diff --git a/csrc/kernels.cu b/csrc/kernels.cu index d0aabff..1c3e723 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1737,10 +1737,884 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char } } +template __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols) +{ + // 0. reset stats to -FLT_MAX + // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD) + // 2. compute col max (per thread); store in smem due to register pressure + // 3. compute row max (per block); store in smem to accumulate full global mem transation + // 4. store data via atomicMax + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (blockIdx.x*TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + const int items_per_load = ITEMS_PER_THREAD*THREADS; + + typedef cub::BlockLoad LoadT; + typedef cub::BlockReduce BlockRowReduce; + typedef cub::BlockReduce BlockRowSum; + typedef cub::BlockExchange BlockExchange; + + __shared__ union { + typename BlockExchange::TempStorage exchange; + typename BlockRowReduce::TempStorage rowreduce; + typename BlockRowSum::TempStorage rowsum; + typename LoadT::TempStorage loadt; + } temp_storage; + + __shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS]; + __shared__ int smem_row_nnz_values[TILE_ROWS]; + //__shared__ float smem_col_absmax_values[ITEMS_PER_THREAD*THREADS]; + + half local_data[ITEMS_PER_THREAD]; + float local_data_fp32[ITEMS_PER_THREAD]; + float local_col_absmax_values[ITEMS_PER_THREAD]; + int local_row_nnz_count = 0; + float row_absmax = -FLT_MAX; + + // 0. reset stats to -FLT_MAX + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + //smem_col_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX; + smem_row_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX; + smem_row_nnz_values[threadIdx.x + (j*THREADS)] = 0; + } + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_col_absmax_values[j] = -FLT_MAX; + + __syncthreads(); + + int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col; + int i = base_idx; + // we load row after row from the base_position + // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD) + for(int row = 0; row < TILE_ROWS; row++) + { + if(base_row+row >= rows){ break; } + local_row_nnz_count = 0; + i = base_idx + ((row)*cols); + // each thread gets data from the same column + __syncthreads(); + LoadT(temp_storage.loadt).Load(&(A[i]), local_data, valid_items, __float2half(0.0f)); + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data[j] = fabsf(local_data[j]); + + + if(SPARSE_DECOMP) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + if((float)local_data[j] >= nnz_threshold) + { + local_row_nnz_count += 1; + local_data[j] = 0.0f; + } + } + + // 2. compute col max (per thread); store in smem due to register pressure + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + // take the col max for this row + // we use shared memory because register pressure is too high if we do this locally + //smem_col_absmax_values[threadIdx.x + (j*THREADS)] = fmaxf(smem_col_absmax_values[threadIdx.x + (j*THREADS)], __half2float(local_data[j])); + local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j])); + + // 3. compute row max (per block); store in smem to accumulate full global mem transation + __syncthreads(); + + // this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data_fp32[j] = local_data[j]; + + row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, cub::Max()); + if(SPARSE_DECOMP) + { + __syncthreads(); + local_row_nnz_count = BlockRowSum(temp_storage.rowsum).Sum(local_row_nnz_count); + } + // we store the data temporarily in shared memory so we + // can execute a full atomic block transaction into global memory later + // we use a striped arrangement [0, 8, 16, 24, ..] for t0 for faster stores + if(threadIdx.x == 0) + { + smem_row_absmax_values[(row % ITEMS_PER_THREAD) + ((row/ITEMS_PER_THREAD)*ITEMS_PER_THREAD)] = row_absmax; + // each blockIdx.x process 16 rows and 64*4=256 columns -> we sum nnz over 256 columns and have 16 values per block + smem_row_nnz_values[row] = local_row_nnz_count; + } + + __syncthreads(); + + } + + // 4. store data via atomicMax + // to store col data efficienctly we need to rewrite the smem blocked data [0, 1, 2, 3...] for t0 + // into a striped arangement: [0, 8, 16, 24, ..] for t0 + __syncthreads(); + BlockExchange(temp_storage.exchange).BlockedToStriped(local_col_absmax_values); + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + if(base_col+threadIdx.x+(j*THREADS) < cols) + { + float val = colStats[base_col+(threadIdx.x+(j*THREADS))]; + if(val < local_col_absmax_values[j]) + atomicMax(&colStats[base_col+(threadIdx.x+(j*THREADS))], local_col_absmax_values[j]); + } + + for(int j = 0; j < ITEMS_PER_THREAD; j++) + if(base_row+threadIdx.x+(j*THREADS) < rows) + { + float val = rowStats[base_row+(threadIdx.x+(j*THREADS))]; + if(val < smem_row_absmax_values[threadIdx.x+(j*THREADS)]) + atomicMax(&rowStats[base_row+(threadIdx.x+(j*THREADS))], smem_row_absmax_values[threadIdx.x+(j*THREADS)]); + } + + if(SPARSE_DECOMP) + if(threadIdx.x < TILE_ROWS) + nnz_count_row[blockIdx.x*TILE_ROWS+threadIdx.x+1] = smem_row_nnz_values[threadIdx.x]; + +} + +template __global__ void kgetColRowStats(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); +template __global__ void kgetColRowStats(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); + +#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f) + +template __global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n) +{ + + // Strategy: To dequantize we need to load col/row statistics. This can be very expensive + // since different row/col stats need to be loaded with each thread. + // (1, bad algorithm) Loading 32 items per thread would only occur 1 row load, but this increases register pressure + // and would lead to low global load utilization. + // (2, bad algorithm) If each thread loads some columns and multiple rows one needs to do lot of row loads + // for each thread and this is duplicated by a factor of 32/num-cols-per-thread. + // (3, good algorithm) Combining (1) and (2) we use sub-tiles of size 32xk in shared memory per threadblock. + // This allows for efficient row/col loading from shared memory within the tile. + // We can run for example 32x128 sub-tiles and warp-strided loads of 4 elements so that each thread has + // the same col statistic but needs to load 4 row stats from shared memory. To prevent bank conflicts + // we use a block-striped shared memory config [1, 31, 63, 95] so no bank conflicts happen during the + // shared memory loads. + + // data is in 32 column-tile major with tile width 32 columns and numRows rows + // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. + // L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + // C1. Compute val(row_stat*col_stat)/(127*127) (load 1/(127*127 into register)) + // C2. Compute normalization values and store col values in register + // S1. Store C1 into 16-bit output + // S2. Store col/row statistics of new buffer in shared memory + + // We allow for sub-tiles to span multiple col32 tiles. This is okay + // since the items per thread only rely on a single column statistic. + + + const int n_out = numRows*numCols; + + int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1); + // we have tiles of size numRows*32, thus col only increases every numRows + // num_row_tiles is the tiles after which the column increases by 32 + // blockIdx.x is the index of the current tile + int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32)); + // base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached + int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS); + + // SUBTILE_ROWS is independent from ITEMS_PER_THREAD is independent from THREADS + // subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD + // Total subtiles should be n/(32*SUBTILE_ROWS) where each subtile has SUBTILE_ROW*32/4 threads. + // For example for a 1024x1024 matrix with 128 SUBTILE_ROWS and 4 ITEMS_PER_THREAD we have + // 1024*1024/(128*32) = 256 tiles + // 256 tiles are 256*128*32/4 = 256*1024 threads + + // 1. Figure out how index relates to the start of the sub-tile + // 2. Each thread < SUBTILE_ROWS calculates row index + // 3. Load striped and store in shared memory + + int local_values[ITEMS_PER_THREAD]; + half local_output[ITEMS_PER_THREAD]; + float local_rowStats[ITEMS_PER_THREAD]; + __shared__ float smem_rowStats[SUBTILE_ROWS]; + + typedef cub::BlockLoad LoadInt32; + typedef cub::BlockExchange ExchangeInt32; + __shared__ typename LoadInt32::TempStorage loadint32; + __shared__ typename ExchangeInt32::TempStorage exchangeint32; + + + // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. + float colStat = col >= numCols ? 0.0f : colStats[col]; + // no block loads for rows for now -- keep it simple + for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x) + { + // todo: is this global mem access slow due to overlaps or does the L1 cache work well here? + int row = (base_row+j) % numRows; // wrap around + // each warp accesses the same element, for four consequitive elements + // todo: update description about striped shared memory, it is not needed + // rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consequitive elements + smem_rowStats[j] = rowStats[row]; + } + __syncthreads(); + + + // each block processes SUBTILE_ROWS*32 elements + const int items_per_load = THREADS*ITEMS_PER_THREAD; + const int rows_per_load = items_per_load/32; + + int subtile_base_row = (threadIdx.x / 32)*ITEMS_PER_THREAD; // row within the tile + int row_offset = 0; + // subtile_idx starts at the base_row*32 + the total offset for a full numRow*32 tile is passed + int subtile_start = (blockIdx.x/num_row_tiles)*(numRows*32) + (base_row*32); + for(int subtile_idx = subtile_start; subtile_idx < subtile_start + (SUBTILE_ROWS*32); subtile_idx+=items_per_load) + { + int valid_rows = numRows - (base_row+row_offset) > rows_per_load ? rows_per_load : numRows - (base_row+row_offset); + int valid_items = valid_rows*32; + if(valid_items <= 0) // the sub-tile might have more elements than the tile itself + break; + + // L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0); + ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values); + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j]; + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_output[j] = __float2half(local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat); + //absmax_col = fmax(fabsf(local_output[j]), absmax_col); + + // we store data in row major + // to store data efficiently, we want to use block exchange: [0, 32, 64, 92] -> [0, 1, 2, 3] + // so that each thread holds ITEMS_PER_THREAD consecutive items for each row + // this way throughput into storage is increased by a factor of ~2x + // for now we use a simple store + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int outIdx = col + ((base_row+subtile_base_row+row_offset+j)*numCols); + if(outIdx< n_out && col < numCols) + out[outIdx] = local_output[j]; + } + + row_offset += rows_per_load; + } +} + + +template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols) +{ + // assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD + // Each thread reads the same column but multiple rows + // Rows are loaded in shared memory and access is shared across the threadblock (broadcast) + + // 0. Load row stats data into shared memory; load col stat (1 fixed per thread) + // 1. Load data row by row (should be at least with TILE_SIZE = 512) + // 2. quantize data with row/col stats + // 3. Store data (TILE_SIZE = 512 is a bit slow, but should still be close enough to good performance) + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (blockIdx.x*TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + const int items_per_load = ITEMS_PER_THREAD*THREADS; + + typedef cub::BlockLoad LoadHalf; + __shared__ typename LoadHalf::TempStorage loadhalf; + typedef cub::BlockStore StoreInt8; + __shared__ typename StoreInt8::TempStorage storeint8; + + __shared__ float smem_row_stats[TILE_ROWS]; + __shared__ unsigned int smem_nnz_row_idx[TILE_ROWS]; + + half local_data[ITEMS_PER_THREAD]; + float local_col_stats[ITEMS_PER_THREAD]; + char local_quantized_data[ITEMS_PER_THREAD]; + + // 0. Load row stats data into shared memory; load col stat (1 fixed per thread) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + if(base_col+(threadIdx.x*ITEMS_PER_THREAD) + j < cols) + local_col_stats[j] = __fdividef(127.0f, colStats[base_col+(threadIdx.x*ITEMS_PER_THREAD)+j]); + + for(int i = threadIdx.x; i < TILE_ROWS; i+=blockDim.x) + { + if(base_row + i < rows) + smem_row_stats[i] = rowStats[base_row+i]; + + if(SPARSE_DECOMP) + smem_nnz_row_idx[i] = nnz_block_ptr[(TILE_ROWS*blockIdx.x) + i]; + } + __syncthreads(); + + // we load row after row from the base_position + // 1. Load data row by row (should be at least with TILE_SIZE = 512) + for(int row = 0; row < TILE_ROWS; row++) + { + if(base_row + row >= rows){ break; } + int i = base_idx + (row*cols); + int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col; + + + LoadHalf(loadhalf).Load(&(A[i]), local_data, valid_items, 0.0f); + float row_stat = __fdividef(127.0f, smem_row_stats[row]); + + // 2. quantize data with row/col stats + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + // we already pre-normalized the col/row stat: + // what this does is float/absmax*127 = int8 + if(SPARSE_DECOMP) + { + if(fabsf((float)local_data[j]) >= threshold) + { + local_quantized_data[j] = 0; + + int old_idx = atomicInc(&smem_nnz_row_idx[row], UINT_MAX); + + rowidx[old_idx] = base_row+row; + colidx[old_idx] = base_col+(threadIdx.x*ITEMS_PER_THREAD)+j; + val[old_idx] = local_data[j]; + } + else + { + local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat)); + } + } + else + local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat)); + } + + StoreInt8(storeint8).Store(&(out_row_normed[i]), local_quantized_data, valid_items); + + // 2. quantize data with row/col stats + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + // we already pre-normalized the col/row stat: + // what this does is float/absmax*127 = int8 + local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*local_col_stats[j])); + } + + __syncthreads(); + StoreInt8(storeint8).Store(&(out_col_normed[i]), local_quantized_data, valid_items); + + } +} + +template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols) +{ + + // 0. Load data into 32*32 shared memory tiles + // 1. transpose / reorder in shared memory + // 2. store + + // COL32 FORMAT: + // rows*32 tiles + + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + + // AMPERE FORMAT: + // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows: + // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32] + + + // To have efficient loads and stores if we transpose we need 128 consequitive bytes which at 1 byte are 128 values + // As such we need: + // at least 32*4 shared memory tiles for col32; preferably 32*32 + // at least 32*6 shared memory tiles for col32_ampere: preferably 32*32 + // at least 32*8 shared memory tiles for col4_turing: preferably 32*32 + // for efficient loading of row major we need to load 128 elements and repeat this 32 items + // this would imply a 32x128 shared memory tile -> 4kb + // It is more efficient to have more than 1 warp, so with 64 threads we need 32x128 -> 8 kb + // we have 64k sharded mem per SM in Turing which is 8 blocks per SM which is 2*8 = 32 warps = 100% occupancy + // for turing and 50% for A100 and 75% for RTX 30s / A40 which is probably good enough + // register pressure should be low with: 8 registers from local memoryh per block and 64 registers per SM + // + // to make the shared memory work with that occupancy we might need to union the block loads/stores + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (blockIdx.x*TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + + // we load 128 bytes per warp with + // 32 rows for transposes that fill col32 types + // so that we can have contiguous stores + __shared__ char smem_data[32*33*ITEMS_PER_THREAD]; + char local_data[ITEMS_PER_THREAD]; + typedef cub::BlockExchange BlockExchange; + __shared__ typename BlockExchange::TempStorage temp_storage; + + // we load row after row from the base_position + // Load data row by row + int warps = blockDim.x/32; + int warp_id = threadIdx.x/32; + int warp_lane = threadIdx.x % 32; + int offset = 0; + + int smem_row = 0; + // each warp loads one row of 128 bytes + for(int row = warp_id; row < TILE_ROWS; row+=warps) + { + int i = base_idx + (row*cols); + // we load up to 128 bytes/items per load + int valid_items = cols - base_col > 32*ITEMS_PER_THREAD ? 32*ITEMS_PER_THREAD : cols - base_col; + + // 0. Load data into 32*32 shared memory tiles + if(base_row + row < rows) + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int col_idx = warp_lane+(j*32); + if(col_idx < valid_items) + local_data[j] = A[i+col_idx]; + else + local_data[j] = 0; + } + } + else + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data[j] = 0; + } + + if(TRANSPOSE) + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int local_col = (32*j)+warp_lane; + //int local_row = row; + // store as 256x32 + smem_data[(local_col*33) + row] = local_data[j]; + } + } + else + { + // treat smem as 32x256, that is 32 rows and 256 columns + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + smem_data[row*32*ITEMS_PER_THREAD + (warp_lane) + (j*32)] = local_data[j]; + } + + + + smem_row += warps; + + // 1. transpose / reorder in shared memory + if(smem_row % 32 == 0) + { + smem_row = 0; + __syncthreads(); + + for(int subrow = warp_id; subrow < 32; subrow+=warps) + { + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + + switch(FORMAT) + { + case COL32: + if(TRANSPOSE) + { + // data lies in shared memory in the following way: + // row0 [col0 col1 ... col31] + // row1 [col0 col1 ... col31] + // ... + // + // As such we read consequtive entries with 256 threads (8rows x 32 columns) + // as j increase, the row increase by a factor of 8 + // We load 8 rows per subrow loop, and subrow increase by 8 per loop + // so we have an offset of 8 rows every loop or (subrow/warps)*8 = (subrow/8)*8 + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size outRows*32 and base_row is done in increments of 32 + offset = base_row*outRows; + out[offset + (base_col + jrow + subrow_loop_row)*32 + threadIdx.x] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + offset = (base_col/32)*(32*rows); + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + out[offset+(base_row+subrow)*32 + ((j)*rows*32)+warp_lane] = data; + } + } + break; + case COL_TURING: + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + // + // [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...] + if(TRANSPOSE) + { + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size 8*32 = 256 elements offset + // for each row offset of 8 we increaes the tile first + // after all rows are exhausted, we increase the col + int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/8)*256; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows + + // we increase by row_tile_column every 32 columns + // base_row increase in increments of 32 + //int row_tile_column = 256*outRows/8; // there are outRows/8 row tiles, and each tile is 256 elements + //int col_offset = (base_row/32)*row_tile_column; + // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8 + // 256*outRows/8*base_row/32 = outRows*base_row + int col_offset = outRows*base_row; + + offset = row_offset+col_offset; + + // since we process even number of rows with each j (8) and with each subrow (8j) we can determine + // odd or even rows with the warp_id (each warp processes one row) + // the col is warp_lane (max 32 columns per row) and the row warp_id + if(warp_id % 2 == 1) + // odd + offset += 128 + (warp_lane/4)*16 + (warp_lane%4) + (((warp_id%8)-1)*2); + else + // even + offset += 0 + (warp_lane/4)*16 + (warp_lane%4) + ((warp_id%8)*2); + + out[offset] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + // set offset designates the tile offset among the 8*32 tiles + // we first increase rows and then columns. Since we load 128 columns at once + // we increase the offset by outRows*32 every 32 columns + // additionally, we increase the offset by 8*32=256 every 8 rows + offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/8)*256); // global offset (8x32 tile) + // first 4 rows are reserved for even rows, [0, 2, 4, 6], the next 4 for odd + // each of these has 32 values in total for 32*4 = 128 as offset if odd + // every set of 4 columns increases the total offset by 16 + // each even row increase the offset by 4, for example row 2 is offset by 4, 4 by 6 etc so: subrow/2*4 = subrow*2 + // this happends every 8 rows anew (subrow % 8) + // one writes 4 columns at once that is (col % 4) for the particular index in the subtile + int subcol = warp_lane; + + // add local offset (4x4 sub-tile) + if(subrow % 2 == 1) + // odd + offset += 128 + (subcol/4)*16 + (subcol%4) + (((subrow%8)-1)*2); + else + // even + offset += 0 + (subcol/4)*16 + (subcol%4) + ((subrow%8)*2); + + out[offset] = data; + } + } + break; + case COL_AMPERE: + // AMPERE FORMAT: + // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows: + // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32] + if(TRANSPOSE) + { + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size 32*32 = 1024 elements offset + // for each row offset of 32 we increaes the tile first + // after all rows are exhausted, we increase the col + int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/32)*1024; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows + + // we increase by row_tile_column every 32 columns + // base_row increase in increments of 32 + //int row_tile_column = 1024*outRows/32; // there are outRows/32 row tiles, and each tile is 1024 elements + //int col_offset = (base_row/32)*row_tile_column; + // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8 + // 1024*outRows/32*base_row/32 = outRows*base_row + int col_offset = outRows*base_row; + + offset = row_offset+col_offset; + + + // same as in the non-transpose case (see below) + // the difference is that now rows = cols + // in this case warp_id = subrow + + // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc + // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row + // every 2 rows, the offset increases by two [0, 1, 8, 9...] + // every 2 rows, the row index increase by 8 [0, 1, 8, 9...] + int local_row = (jrow + warp_id) % 32; // offset for row > 32 is already calculated into row_offset + int ampere_row = ((local_row % 8)/2)*8 + (local_row/8)*2 + (local_row % 2); + + // global offset + row with 32 cols each + 32 cols per j + col_idx=warp_lane + out[offset + (ampere_row*32) + warp_lane] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + + // set offset designates the tile offset among the 32*32 tiles + // we first increase rows and then columns. Since we load 128 columns at once + // we increase the offset by outRows*32 every 32 columns + // additionally, we increase the offset by 32*32=1024 every 32 rows + offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/32)*1024); // global offset (32x32 tile) + + // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc + // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row + // every 2 rows, the offset increases by two [0, 1, 8, 9...] + // every 2 rows, the row index increase by 8 [0, 1, 8, 9...] + int local_row = ((subrow % 8)/2)*8 + (subrow/8)*2 + (subrow % 2); + + // global offset + row with 32 cols each + 32 cols per j + col_idx + out[offset + (local_row*32) + warp_lane] = data; + } + } + break; + } + } + } + } + } +} + +#define C 1.0f/127.0f +#define MAX_SPARSE_COUNT 32 +#define SMEM_SIZE 8*256 +template +__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB) +{ + + // 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block + // If a block finishes, the next one is scheduled. Since the last blocks like have fewer + // elements they finish faster "fillin up" the gaps left by larger blocks + + // without tensor cores + // 1. use rowidx_length to find what to load (as many blocks as there are rows) + // 2. Load A into registers + // 3. each warp loads all required rows of B but each warp is offset by k + // 4. Do mma operations that accumulate into registers + // 5. Each warp stores its output row into matrix C + + const int count = max_count[blockIdx.x]; + const int local_max_idx = max_idx[blockIdx.x]; + const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1]; + const int local_row_idx = rowidx[offset]; + + const int warp_id = threadIdx.x / 32; + const int warp_idx = threadIdx.x % 32; + const int warp_offset = (warp_id*32)*SPMM_ITEMS; + const int num_items = BITS == 8 ? 8 : 8; + int idx_col_B = warp_offset; + int local_idx_col_B_offset = 0; + + half local_valA[MAX_SPARSE_COUNT]; + int local_colidxA[MAX_SPARSE_COUNT]; + half local_valC[SPMM_ITEMS]; + T local_valsB[num_items]; + half local_valOut[num_items]; + // 128 byte loads per warp == 4 bytes per thread + + // 2. Load A into registers + for(int j = 0; j < MAX_SPARSE_COUNT; j++) + { + local_valA[j] = j < count ? values[offset+j] : __float2half(0.0f); + local_colidxA[j] = j < count ? colidx[offset+j] : 0; + } + + // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=x192 + // we expect each warp to be SPMM_ITEMS*32 apart + // we have a total of 128 bytes for the bank with a bank size of 4 bytes + // added 3 bytes = 6 values between warps should reduce bank conflicts + __shared__ half smem_dequant_stats[SMEM_SIZE]; + + + while(idx_col_B < colsB) + { + + if(dequant_stats != NULL) + { + for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x) + if((idx_col_B+i-local_idx_col_B_offset) < colsB) + smem_dequant_stats[i] = __ldg(&dequant_stats[idx_col_B+i-local_idx_col_B_offset]); + + __syncthreads(); + } + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j++) + local_valC[j] = 0.0f; + + #pragma unroll + for(int i = 0; i < count; i++) + { + // 3. each warp loads all required rows of B but each warp is offset by k + int row_offset = colsB*local_colidxA[i]; + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j+=num_items) + { + // 4. Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached + int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j; + if(idx >= colsB){ break; } + //printf("%i %i\n", (row_offset+idx) % num_items, row_offset+idx); + if((idx+num_items < colsB)) + { + if(BITS == 8) + reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; + else + reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; + } + else + { + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + if(idx+k < colsB) + local_valsB[k] = B[row_offset+idx+k]; + else + local_valsB[k] = 0.0f; + } + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + { + //if((float)local_valsB[k] != 0.0) + // printf("%f %i %i %i\n", (float)local_valsB[k], k, idx, colsB); + if(BITS == 8 && dequant_stats != NULL) + // we do texture cache reads (__ldg) on dequant_stats which should be super fast + { + float valB = local_valsB[k]; + float valA = local_valA[i]; + if(valB != 0.0 && valA != 0.0) + local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*C*valB*valA; + } + else + local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i]; + } + } + } + + int idx_row_C = (colsB*local_row_idx); + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j+=num_items) + { + //int idx_col_C = idx_col_B + (32*j) + warp_idx; + int idx_col_C = idx_col_B + warp_idx*SPMM_ITEMS + j; + int idx_val = idx_col_C + idx_row_C; + + if(idx_col_C +num_items < colsB) + { + + // load outputs to do inplace addition + reinterpret_cast(local_valOut)[0] = reinterpret_cast(out)[idx_val/num_items]; + + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + local_valC[(j/num_items) + k] = (float)local_valC[(j/num_items) + k] + (float)local_valOut[k]; + + reinterpret_cast(out)[idx_val/num_items] = reinterpret_cast(local_valC)[j/num_items]; + } + else + { + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + if(idx_col_C + k < colsB) + out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k]; + } + } + + idx_col_B += blockDim.x*SPMM_ITEMS; + local_idx_col_B_offset += blockDim.x*SPMM_ITEMS; + } +} + //============================================================== // TEMPLATE DEFINITIONS //============================================================== +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); + +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); + +template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n); + +template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); +template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); + template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x); template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x); -- cgit v1.2.3