author    Tim Dettmers <tim.dettmers@gmail.com>    2022-07-22 14:41:05 -0700
committer Tim Dettmers <tim.dettmers@gmail.com>    2022-07-22 14:41:05 -0700
commit    c771b3a75a6ebbfbfc398a028a477246b0799cf0 (patch)
tree      158353d531766ed133be34d3c5085da6e8a4d01e /csrc
parent    4cd7ea62b2f51c68aacde2f62e7141765e476111 (diff)
Most tests passing.
Diffstat (limited to 'csrc')
-rw-r--r--    csrc/kernels.cu           874
-rw-r--r--    csrc/kernels.cuh           12
-rw-r--r--    csrc/ops.cu               406
-rw-r--r--    csrc/ops.cuh              104
-rw-r--r--    csrc/pythonInterface.c    127
5 files changed, 1522 insertions(+), 1 deletion(-)
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index d0aabff..1c3e723 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -1737,10 +1737,884 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
}
}
+template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols)
+{
+ // 0. reset stats to -FLT_MAX
+ // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD)
+ // 2. compute col max (per thread); store in smem due to register pressure
+ // 3. compute row max (per block); store in smem to accumulate a full global mem transaction
+ // 4. store data via atomicMax
+
+ // each block loads TILE_COLS columns and TILE_ROWS rows
+ // after reading a tile the row counter increases by TILE_ROWS
+ // the col counter resets after reading TILE_COLS elements
+ const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
+ // col increases by TILE_COLS for each block and wraps back to 0 after tiledCols is reached
+ const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
+ const int base_idx = (base_row*cols) + base_col;
+ const int items_per_load = ITEMS_PER_THREAD*THREADS;
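+
+ // worked example (illustrative, not part of the original code): with TILE_ROWS=16 and
+ // TILE_COLS=256 (as instantiated below) and tiledCols=512:
+ // block 0 -> (base_row 0, base_col 0), block 1 -> (0, 256), block 2 -> (16, 0), block 3 -> (16, 256), ...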
+
+ typedef cub::BlockLoad<T, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadT;
+ typedef cub::BlockReduce<float, THREADS> BlockRowReduce;
+ typedef cub::BlockReduce<int, THREADS> BlockRowSum;
+ typedef cub::BlockExchange<float, THREADS, ITEMS_PER_THREAD> BlockExchange;
+
+ __shared__ union {
+ typename BlockExchange::TempStorage exchange;
+ typename BlockRowReduce::TempStorage rowreduce;
+ typename BlockRowSum::TempStorage rowsum;
+ typename LoadT::TempStorage loadt;
+ } temp_storage;
+
+ __shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS];
+ __shared__ int smem_row_nnz_values[TILE_ROWS];
+ //__shared__ float smem_col_absmax_values[ITEMS_PER_THREAD*THREADS];
+
+ half local_data[ITEMS_PER_THREAD];
+ float local_data_fp32[ITEMS_PER_THREAD];
+ float local_col_absmax_values[ITEMS_PER_THREAD];
+ int local_row_nnz_count = 0;
+ float row_absmax = -FLT_MAX;
+
+ // 0. reset stats to -FLT_MAX
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ {
+ //smem_col_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX;
+ smem_row_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX;
+ }
+
+ // smem_row_nnz_values only has TILE_ROWS entries, so guard the reset to stay in bounds
+ if(threadIdx.x < TILE_ROWS)
+ smem_row_nnz_values[threadIdx.x] = 0;
+
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ local_col_absmax_values[j] = -FLT_MAX;
+
+ __syncthreads();
+
+ int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col;
+ int i = base_idx;
+ // we load row after row from the base_position
+ // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD)
+ for(int row = 0; row < TILE_ROWS; row++)
+ {
+ if(base_row+row >= rows){ break; }
+ local_row_nnz_count = 0;
+ i = base_idx + ((row)*cols);
+ // each thread gets data from the same column
+ __syncthreads();
+ LoadT(temp_storage.loadt).Load(&(A[i]), local_data, valid_items, __float2half(0.0f));
+
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ local_data[j] = fabsf(local_data[j]);
+
+
+ if(SPARSE_DECOMP)
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ {
+ if((float)local_data[j] >= nnz_threshold)
+ {
+ local_row_nnz_count += 1;
+ local_data[j] = 0.0f;
+ }
+ }
+
+ // 2. compute col max (per thread); store in smem due to register pressure
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ // take the col max for this row
+ // we use shared memory because register pressure is too high if we do this locally
+ //smem_col_absmax_values[threadIdx.x + (j*THREADS)] = fmaxf(smem_col_absmax_values[threadIdx.x + (j*THREADS)], __half2float(local_data[j]));
+ local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j]));
+
+ // 3. compute row max (per block); store in smem to accumulate a full global mem transaction
+ __syncthreads();
+
+ // this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units)
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ local_data_fp32[j] = local_data[j];
+
+ row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, cub::Max());
+ if(SPARSE_DECOMP)
+ {
+ __syncthreads();
+ local_row_nnz_count = BlockRowSum(temp_storage.rowsum).Sum(local_row_nnz_count);
+ }
+ // we store the data temporarily in shared memory so we
+ // can execute a full atomic block transaction into global memory later
+ // we use a striped arrangement [0, 8, 16, 24, ..] for t0 for faster stores
+ if(threadIdx.x == 0)
+ {
+ smem_row_absmax_values[(row % ITEMS_PER_THREAD) + ((row/ITEMS_PER_THREAD)*ITEMS_PER_THREAD)] = row_absmax;
+ // each blockIdx.x processes 16 rows and 64*4=256 columns -> we sum nnz over 256 columns and have 16 values per block
+ smem_row_nnz_values[row] = local_row_nnz_count;
+ }
+
+ __syncthreads();
+
+ }
+
+ // 4. store data via atomicMax
+ // to store col data efficiently we need to rewrite the smem blocked data [0, 1, 2, 3...] for t0
+ // into a striped arrangement: [0, 8, 16, 24, ..] for t0
+ __syncthreads();
+ BlockExchange(temp_storage.exchange).BlockedToStriped(local_col_absmax_values);
+
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ if(base_col+threadIdx.x+(j*THREADS) < cols)
+ {
+ float val = colStats[base_col+(threadIdx.x+(j*THREADS))];
+ if(val < local_col_absmax_values[j])
+ atomicMax(&colStats[base_col+(threadIdx.x+(j*THREADS))], local_col_absmax_values[j]);
+ }
+
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ if(base_row+threadIdx.x+(j*THREADS) < rows)
+ {
+ float val = rowStats[base_row+(threadIdx.x+(j*THREADS))];
+ if(val < smem_row_absmax_values[threadIdx.x+(j*THREADS)])
+ atomicMax(&rowStats[base_row+(threadIdx.x+(j*THREADS))], smem_row_absmax_values[threadIdx.x+(j*THREADS)]);
+ }
+
+ if(SPARSE_DECOMP)
+ if(threadIdx.x < TILE_ROWS)
+ nnz_count_row[blockIdx.x*TILE_ROWS+threadIdx.x+1] = smem_row_nnz_values[threadIdx.x];
+
+}
+
+template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 0>(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
+template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 1>(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
+
+#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f)
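+// sanity check (illustrative): 1.0f/(127.0f*127.0f) = 1.0f/16129.0f ~ 6.2000124e-05f, matching the constant above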
+
+template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n)
+{
+
+ // Strategy: To dequantize we need to load col/row statistics. This can be very expensive
+ // since different row/col stats need to be loaded with each thread.
+ // (1, bad algorithm) Loading 32 items per thread would incur only 1 row load, but this increases register pressure
+ // and would lead to low global load utilization.
+ // (2, bad algorithm) If each thread loads some columns and multiple rows, one needs to do a lot of row loads
+ // for each thread, and this is duplicated by a factor of 32/num-cols-per-thread.
+ // (3, good algorithm) Combining (1) and (2) we use sub-tiles of size 32xk in shared memory per threadblock.
+ // This allows for efficient row/col loading from shared memory within the tile.
+ // We can run for example 32x128 sub-tiles and warp-strided loads of 4 elements so that each thread has
+ // the same col statistic but needs to load 4 row stats from shared memory. To prevent bank conflicts
+ // we use a block-striped shared memory config [1, 31, 63, 95] so no bank conflicts happen during the
+ // shared memory loads.
+
+ // data is in 32 column-tile major with tile width 32 columns and numRows rows
+ // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
+ // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3])
+ // C1. Compute val(row_stat*col_stat)/(127*127) (load 1/(127*127) into a register)
+ // C2. Compute normalization values and store col values in register
+ // S1. Store C1 into 16-bit output
+ // S2. Store col/row statistics of new buffer in shared memory
+
+ // We allow for sub-tiles to span multiple col32 tiles. This is okay
+ // since the items per thread only rely on a single column statistic.
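+
+ // worked example (illustrative, assumed values): with rowStat=2.0 and colStat=3.0 (the absmax values),
+ // two inputs quantized at absmax multiply to 127*127 = 16129 in int32, and
+ // 16129 * (1/(127*127)) * 2.0 * 3.0 = 6.0 recovers the fp product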
+
+
+ const int n_out = numRows*numCols;
+
+ int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1);
+ // we have tiles of size numRows*32, thus col only increases every numRows
+ // num_row_tiles is the tiles after which the column increases by 32
+ // blockIdx.x is the index of the current tile
+ int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32));
+ // base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached
+ int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS);
+
+ // SUBTILE_ROWS, ITEMS_PER_THREAD, and THREADS are independent of each other
+ // subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD
+ // Total subtiles should be n/(32*SUBTILE_ROWS) where each subtile has SUBTILE_ROWS*32/4 threads.
+ // For example for a 1024x1024 matrix with 128 SUBTILE_ROWS and 4 ITEMS_PER_THREAD we have
+ // 1024*1024/(128*32) = 256 tiles
+ // 256 tiles are 256*128*32/4 = 256*1024 threads
+
+ // 1. Figure out how index relates to the start of the sub-tile
+ // 2. Each thread < SUBTILE_ROWS calculates row index
+ // 3. Load striped and store in shared memory
+
+ int local_values[ITEMS_PER_THREAD];
+ half local_output[ITEMS_PER_THREAD];
+ float local_rowStats[ITEMS_PER_THREAD];
+ __shared__ float smem_rowStats[SUBTILE_ROWS];
+
+ typedef cub::BlockLoad<int, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_DIRECT> LoadInt32;
+ typedef cub::BlockExchange<int, THREADS, ITEMS_PER_THREAD> ExchangeInt32;
+ __shared__ typename LoadInt32::TempStorage loadint32;
+ __shared__ typename ExchangeInt32::TempStorage exchangeint32;
+
+
+ // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
+ float colStat = col >= numCols ? 0.0f : colStats[col];
+ // no block loads for rows for now -- keep it simple
+ for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x)
+ {
+ // todo: is this global mem access slow due to overlaps or does the L1 cache work well here?
+ int row = (base_row+j) % numRows; // wrap around
+ // each warp accesses the same element, for four consecutive elements
+ // todo: update description about striped shared memory, it is not needed
+ // rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consecutive elements
+ smem_rowStats[j] = rowStats[row];
+ }
+ __syncthreads();
+
+
+ // each block processes SUBTILE_ROWS*32 elements
+ const int items_per_load = THREADS*ITEMS_PER_THREAD;
+ const int rows_per_load = items_per_load/32;
+
+ int subtile_base_row = (threadIdx.x / 32)*ITEMS_PER_THREAD; // row within the tile
+ int row_offset = 0;
+ // subtile_idx starts at the base_row*32 + the total offset for a full numRow*32 tile is passed
+ int subtile_start = (blockIdx.x/num_row_tiles)*(numRows*32) + (base_row*32);
+ for(int subtile_idx = subtile_start; subtile_idx < subtile_start + (SUBTILE_ROWS*32); subtile_idx+=items_per_load)
+ {
+ int valid_rows = numRows - (base_row+row_offset) > rows_per_load ? rows_per_load : numRows - (base_row+row_offset);
+ int valid_items = valid_rows*32;
+ if(valid_items <= 0) // the sub-tile might have more elements than the tile itself
+ break;
+
+ // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3])
+ LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0);
+ ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values);
+
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j];
+
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ local_output[j] = __float2half(local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat);
+ //absmax_col = fmax(fabsf(local_output[j]), absmax_col);
+
+ // we store data in row major
+ // to store data efficiently, we want to use block exchange: [0, 32, 64, 96] -> [0, 1, 2, 3]
+ // so that each thread holds ITEMS_PER_THREAD consecutive items for each row
+ // this way throughput into storage is increased by a factor of ~2x
+ // for now we use a simple store
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ {
+ int outIdx = col + ((base_row+subtile_base_row+row_offset+j)*numCols);
+ if(outIdx < n_out && col < numCols)
+ out[outIdx] = local_output[j];
+ }
+
+ row_offset += rows_per_load;
+ }
+}
+
+
+template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols)
+{
+ // assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD
+ // Each thread reads the same column but multiple rows
+ // Rows are loaded in shared memory and access is shared across the threadblock (broadcast)
+
+ // 0. Load row stats data into shared memory; load col stat (1 fixed per thread)
+ // 1. Load data row by row (should be at least with TILE_SIZE = 512)
+ // 2. quantize data with row/col stats
+ // 3. Store data (TILE_SIZE = 512 is a bit slow, but should still be close enough to good performance)
+
+ // each block loads TILE_COLS columns and TILE_ROWS rows
+ // after reading a tile the row counter increases by TILE_ROWS
+ // the col counter resets after reading TILE_COLS elements
+ const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
+ // col increases by TILE_COLS for each block and wraps back to 0 after tiledCols is reached
+ const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
+ const int base_idx = (base_row*cols) + base_col;
+ const int items_per_load = ITEMS_PER_THREAD*THREADS;
+
+ typedef cub::BlockLoad<half, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadHalf;
+ __shared__ typename LoadHalf::TempStorage loadhalf;
+ typedef cub::BlockStore<char, THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_VECTORIZE> StoreInt8;
+ __shared__ typename StoreInt8::TempStorage storeint8;
+
+ __shared__ float smem_row_stats[TILE_ROWS];
+ __shared__ unsigned int smem_nnz_row_idx[TILE_ROWS];
+
+ half local_data[ITEMS_PER_THREAD];
+ float local_col_stats[ITEMS_PER_THREAD];
+ char local_quantized_data[ITEMS_PER_THREAD];
+
+ // 0. Load row stats data into shared memory; load col stat (1 fixed per thread)
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ if(base_col+(threadIdx.x*ITEMS_PER_THREAD) + j < cols)
+ local_col_stats[j] = __fdividef(127.0f, colStats[base_col+(threadIdx.x*ITEMS_PER_THREAD)+j]);
+
+ for(int i = threadIdx.x; i < TILE_ROWS; i+=blockDim.x)
+ {
+ if(base_row + i < rows)
+ smem_row_stats[i] = rowStats[base_row+i];
+
+ if(SPARSE_DECOMP)
+ smem_nnz_row_idx[i] = nnz_block_ptr[(TILE_ROWS*blockIdx.x) + i];
+ }
+ __syncthreads();
+
+ // we load row after row from the base_position
+ // 1. Load data row by row (should be at least with TILE_SIZE = 512)
+ for(int row = 0; row < TILE_ROWS; row++)
+ {
+ if(base_row + row >= rows){ break; }
+ int i = base_idx + (row*cols);
+ int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col;
+
+
+ LoadHalf(loadhalf).Load(&(A[i]), local_data, valid_items, 0.0f);
+ float row_stat = __fdividef(127.0f, smem_row_stats[row]);
+
+ // 2. quantize data with the row stats (sparse outliers are extracted here)
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ {
+ // we already pre-normalized the col/row stat:
+ // what this does is float/absmax*127 = int8
+ if(SPARSE_DECOMP)
+ {
+ if(fabsf((float)local_data[j]) >= threshold)
+ {
+ local_quantized_data[j] = 0;
+
+ int old_idx = atomicInc(&smem_nnz_row_idx[row], UINT_MAX);
+
+ rowidx[old_idx] = base_row+row;
+ colidx[old_idx] = base_col+(threadIdx.x*ITEMS_PER_THREAD)+j;
+ val[old_idx] = local_data[j];
+ }
+ else
+ {
+ local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat));
+ }
+ }
+ else
+ local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat));
+ }
+
+ StoreInt8(storeint8).Store(&(out_row_normed[i]), local_quantized_data, valid_items);
+
+ // 2. quantize data with the col stats
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ {
+ // we already pre-normalized the col/row stat:
+ // what this does is float/absmax*127 = int8
+ local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*local_col_stats[j]));
+ }
+
+ __syncthreads();
+ StoreInt8(storeint8).Store(&(out_col_normed[i]), local_quantized_data, valid_items);
+
+ }
+}
+
+template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols)
+{
+
+ // 0. Load data into 32*32 shared memory tiles
+ // 1. transpose / reorder in shared memory
+ // 2. store
+
+ // COL32 FORMAT:
+ // rows*32 tiles
+
+ // TURING FORMAT:
+ // 8*32 tiles with 4*4 subtiles
+ // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements)
+ // the subsequent 4*4 subtiles are for all odd rows; if some rows/columns are empty, the values are zero
+ // the tiles repeat after the 8*32 tile in column-major order, meaning: (next 8 rows are A[8:16, 0:32])
+ // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
+ // index increases by 32
+
+ // AMPERE FORMAT:
+ // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows:
+ // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
+ // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32]
+
+
+ // To have efficient loads and stores if we transpose we need 128 consecutive bytes, which at 1 byte each are 128 values
+ // As such we need:
+ // at least 32*4 shared memory tiles for col32; preferably 32*32
+ // at least 32*6 shared memory tiles for col32_ampere: preferably 32*32
+ // at least 32*8 shared memory tiles for col4_turing: preferably 32*32
+ // for efficient loading of row major we need to load 128 elements and repeat this for 32 rows
+ // this would imply a 32x128 shared memory tile -> 4kb
+ // It is more efficient to have more than 1 warp, so with 64 threads we need 32x128 -> 8 kb
+ // we have 64k shared mem per SM on Turing, which is 8 blocks per SM, i.e. 2*8 = 16 warps;
+ // occupancy is lower on A100 and RTX 30s / A40, but this is probably good enough
+ // register pressure should be low with: 8 registers from local memory per block and 64 registers per SM
+ //
+ // to make the shared memory work with that occupancy we might need to union the block loads/stores
+
+ // each block loads TILE_COLS columns and TILE_ROWS rows
+ // after reading a tile the row counter increases by TILE_ROWS
+ // the col counter resets after reading TILE_COLS elements
+ const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
+ // col increases by TILE_COLS for each block and wraps back to 0 after tiledCols is reached
+ const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
+ const int base_idx = (base_row*cols) + base_col;
+
+ // we load 128 bytes per warp with
+ // 32 rows for transposes that fill col32 types
+ // so that we can have contiguous stores
+ __shared__ char smem_data[32*33*ITEMS_PER_THREAD];
+ char local_data[ITEMS_PER_THREAD];
+ typedef cub::BlockExchange<char, THREADS, ITEMS_PER_THREAD> BlockExchange;
+ __shared__ typename BlockExchange::TempStorage temp_storage;
+
+ // we load row after row from the base_position
+ // Load data row by row
+ int warps = blockDim.x/32;
+ int warp_id = threadIdx.x/32;
+ int warp_lane = threadIdx.x % 32;
+ int offset = 0;
+
+ int smem_row = 0;
+ // each warp loads one row of 128 bytes
+ for(int row = warp_id; row < TILE_ROWS; row+=warps)
+ {
+ int i = base_idx + (row*cols);
+ // we load up to 128 bytes/items per load
+ int valid_items = cols - base_col > 32*ITEMS_PER_THREAD ? 32*ITEMS_PER_THREAD : cols - base_col;
+
+ // 0. Load data into 32*32 shared memory tiles
+ if(base_row + row < rows)
+ {
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ {
+ int col_idx = warp_lane+(j*32);
+ if(col_idx < valid_items)
+ local_data[j] = A[i+col_idx];
+ else
+ local_data[j] = 0;
+ }
+ }
+ else
+ {
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ local_data[j] = 0;
+ }
+
+ if(TRANSPOSE)
+ {
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ {
+ int local_col = (32*j)+warp_lane;
+ //int local_row = row;
+ // store as 256x32
+ smem_data[(local_col*33) + row] = local_data[j];
+ }
+ }
+ else
+ {
+ // treat smem as 32x256, that is 32 rows and 256 columns
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ smem_data[row*32*ITEMS_PER_THREAD + (warp_lane) + (j*32)] = local_data[j];
+ }
+
+
+
+ smem_row += warps;
+
+ // 1. transpose / reorder in shared memory
+ if(smem_row % 32 == 0)
+ {
+ smem_row = 0;
+ __syncthreads();
+
+ for(int subrow = warp_id; subrow < 32; subrow+=warps)
+ {
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ {
+
+ switch(FORMAT)
+ {
+ case COL32:
+ if(TRANSPOSE)
+ {
+ // data lies in shared memory in the following way:
+ // row0 [col0 col1 ... col31]
+ // row1 [col0 col1 ... col31]
+ // ...
+ //
+ // As such we read consecutive entries with 256 threads (8 rows x 32 columns)
+ // as j increases, the row increases by a factor of 8
+ // We load 8 rows per subrow loop, and subrow increases by 8 per loop
+ // so we have an offset of 8 rows every loop or (subrow/warps)*8 = (subrow/8)*8
+ const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
+ const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
+ //const int local_row = warp_id; // each warp_id is one row
+ //const int block_row = base_col; // block offset for row
+ //const int local_col = warp_lane
+ //const int global_col = base_row; // block offset for col
+ if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
+ {
+ // each row has 32 columns and is offset by 1 to prevent bank conflicts during storage into smem
+ char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
+
+ // each 32 columns we have new tile
+ // each tile has size outRows*32 and base_row is done in increments of 32
+ offset = base_row*outRows;
+ out[offset + (base_col + jrow + subrow_loop_row)*32 + threadIdx.x] = data;
+ }
+ }
+ else
+ {
+ if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
+ {
+ offset = (base_col/32)*(32*rows);
+ char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
+ out[offset+(base_row+subrow)*32 + ((j)*rows*32)+warp_lane] = data;
+ }
+ }
+ break;
+ case COL_TURING:
+ // TURING FORMAT:
+ // 8*32 tiles with 4*4 subtiles
+ // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements)
+ // the subsequent 4*4 subtiles are for all odd rows; if some rows/columns are empty, the values are zero
+ // the tiles repeat after the 8*32 tile in column-major order, meaning: (next 8 rows are A[8:16, 0:32])
+ // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
+ // index increases by 32
+ //
+ // [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...]
+ if(TRANSPOSE)
+ {
+ const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
+ const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
+ //const int local_row = warp_id; // each warp_id is one row
+ //const int block_row = base_col; // block offset for row
+ //const int local_col = warp_lane
+ //const int global_col = base_row; // block offset for col
+ if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
+ {
+ // each row has 32 columns and is offset by 1 to prevent bank conflicts during storage into smem
+ char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
+
+ // each 32 columns we have new tile
+ // each tile has size 8*32 = 256 elements offset
+ // for each row offset of 8 we increase the tile first
+ // after all rows are exhausted, we increase the col
+ int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/8)*256; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows
+
+ // we increase by row_tile_column every 32 columns
+ // base_row increase in increments of 32
+ //int row_tile_column = 256*outRows/8; // there are outRows/8 row tiles, and each tile is 256 elements
+ //int col_offset = (base_row/32)*row_tile_column;
+ // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
+ // 256*outRows/8*base_row/32 = outRows*base_row
+ int col_offset = outRows*base_row;
+
+ offset = row_offset+col_offset;
+
+ // since we process an even number of rows with each j (8) and with each subrow (8j) we can determine
+ // odd or even rows with the warp_id (each warp processes one row)
+ // the col is warp_lane (max 32 columns per row) and the row warp_id
+ if(warp_id % 2 == 1)
+ // odd
+ offset += 128 + (warp_lane/4)*16 + (warp_lane%4) + (((warp_id%8)-1)*2);
+ else
+ // even
+ offset += 0 + (warp_lane/4)*16 + (warp_lane%4) + ((warp_id%8)*2);
+
+ out[offset] = data;
+ }
+ }
+ else
+ {
+ if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
+ {
+ char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
+ // set offset designates the tile offset among the 8*32 tiles
+ // we first increase rows and then columns. Since we load 128 columns at once
+ // we increase the offset by outRows*32 every 32 columns
+ // additionally, we increase the offset by 8*32=256 every 8 rows
+ offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/8)*256); // global offset (8x32 tile)
+ // first 4 rows are reserved for even rows, [0, 2, 4, 6], the next 4 for odd
+ // each of these has 32 values in total for 32*4 = 128 as offset if odd
+ // every set of 4 columns increases the total offset by 16
+ // each even row increases the offset by 4, for example row 2 is offset by 4, row 4 by 8, etc, so: (subrow/2)*4 = subrow*2
+ // this happens every 8 rows anew (subrow % 8)
+ // one writes 4 columns at once, that is (col % 4) for the particular index in the subtile
+ int subcol = warp_lane;
+
+ // add local offset (4x4 sub-tile)
+ if(subrow % 2 == 1)
+ // odd
+ offset += 128 + (subcol/4)*16 + (subcol%4) + (((subrow%8)-1)*2);
+ else
+ // even
+ offset += 0 + (subcol/4)*16 + (subcol%4) + ((subrow%8)*2);
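+
+ // worked example (illustrative, not part of the original code):
+ // (subrow 0, subcol 0) -> 0; (subrow 2, subcol 0) -> 4;
+ // (subrow 1, subcol 0) -> 128; (subrow 0, subcol 5) -> 16+1 = 17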
+
+ out[offset] = data;
+ }
+ }
+ break;
+ case COL_AMPERE:
+ // AMPERE FORMAT:
+ // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows:
+ // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
+ // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32]
+ if(TRANSPOSE)
+ {
+ const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
+ const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
+ //const int local_row = warp_id; // each warp_id is one row
+ //const int block_row = base_col; // block offset for row
+ //const int local_col = warp_lane
+ //const int global_col = base_row; // block offset for col
+ if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
+ {
+ // each row has 32 columns and is offset by 1 to prevent bank conflicts during storage into smem
+ char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
+
+ // each 32 columns we have new tile
+ // each tile has size 32*32 = 1024 elements offset
+ // for each row offset of 32 we increase the tile first
+ // after all rows are exhausted, we increase the col
+ int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/32)*1024; // global_row+jrow+subrow_loop_row+local_row, increase tile(=1024) every 32 rows
+
+ // we increase by row_tile_column every 32 columns
+ // base_row increase in increments of 32
+ //int row_tile_column = 1024*outRows/32; // there are outRows/32 row tiles, and each tile is 1024 elements
+ //int col_offset = (base_row/32)*row_tile_column;
+ // -> we can remove the divisions to speed up compute since outRows is always a multiple of 32
+ // 1024*outRows/32*base_row/32 = outRows*base_row
+ int col_offset = outRows*base_row;
+
+ offset = row_offset+col_offset;
+
+
+ // same as in the non-transpose case (see below)
+ // the difference is that now rows = cols
+ // in this case warp_id = subrow
+
+ // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
+ // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc
+ // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row
+ // every 2 rows, the offset increases by two [0, 1, 8, 9...]
+ // every 2 rows, the row index increases by 8 [0, 1, 8, 9...]
+ int local_row = (jrow + warp_id) % 32; // offset for row > 32 is already calculated into row_offset
+ int ampere_row = ((local_row % 8)/2)*8 + (local_row/8)*2 + (local_row % 2);
+
+ // global offset + row with 32 cols each + 32 cols per j + col_idx=warp_lane
+ out[offset + (ampere_row*32) + warp_lane] = data;
+ }
+ }
+ else
+ {
+ if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
+ {
+ char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
+
+ // set offset designates the tile offset among the 32*32 tiles
+ // we first increase rows and then columns. Since we load 128 columns at once
+ // we increase the offset by outRows*32 every 32 columns
+ // additionally, we increase the offset by 32*32=1024 every 32 rows
+ offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/32)*1024); // global offset (32x32 tile)
+
+ // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
+ // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc
+ // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row
+ // every 2 rows, the offset increases by two [0, 1, 8, 9...]
+ // every 2 rows, the row index increases by 8 [0, 1, 8, 9...]
+ int local_row = ((subrow % 8)/2)*8 + (subrow/8)*2 + (subrow % 2);
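+
+ // worked example (illustrative, not part of the original code):
+ // subrow 0->0, 1->1, 2->8, 3->9, 8->2, 9->3, 16->4, 31->31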
+
+ // global offset + row with 32 cols each + 32 cols per j + col_idx
+ out[offset + (local_row*32) + warp_lane] = data;
+ }
+ }
+ break;
+ }
+ }
+ }
+ }
+ }
+}
+
+#define C (1.0f/127.0f)
+#define MAX_SPARSE_COUNT 32
+#define SMEM_SIZE 8*256
+template <typename T, int SPMM_ITEMS, int BITS>
+__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
+{
+
+ // 0. load balancing: We process rows with the most columns first (count_vec) and we process one row per block
+ // If a block finishes, the next one is scheduled. Since the last blocks likely have fewer
+ // elements they finish faster, "filling up" the gaps left by larger blocks
+
+ // without tensor cores
+ // 1. use rowidx_length to find what to load (as many blocks as there are rows)
+ // 2. Load A into registers
+ // 3. each warp loads all required rows of B but each warp is offset by k
+ // 4. Do mma operations that accumulate into registers
+ // 5. Each warp stores its output row into matrix C
+
+ const int count = max_count[blockIdx.x];
+ const int local_max_idx = max_idx[blockIdx.x];
+ const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1];
+ const int local_row_idx = rowidx[offset];
+
+ const int warp_id = threadIdx.x / 32;
+ const int warp_idx = threadIdx.x % 32;
+ const int warp_offset = (warp_id*32)*SPMM_ITEMS;
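+ // both branches below are 8: a vectorized load is a float2 for int8 data (8 bytes = 8 values)
+ // and a float4 for half data (16 bytes = 8 values), see the loads further down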
+ const int num_items = BITS == 8 ? 8 : 8;
+ int idx_col_B = warp_offset;
+ int local_idx_col_B_offset = 0;
+
+ half local_valA[MAX_SPARSE_COUNT];
+ int local_colidxA[MAX_SPARSE_COUNT];
+ half local_valC[SPMM_ITEMS];
+ T local_valsB[num_items];
+ half local_valOut[num_items];
+ // 128 byte loads per warp == 4 bytes per thread
+
+ // 2. Load A into registers
+ for(int j = 0; j < MAX_SPARSE_COUNT; j++)
+ {
+ local_valA[j] = j < count ? values[offset+j] : __float2half(0.0f);
+ local_colidxA[j] = j < count ? colidx[offset+j] : 0;
+ }
+
+ // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=8192
+ // we expect each warp to be SPMM_ITEMS*32 apart
+ // we have a total of 128 bytes for the bank with a bank size of 4 bytes
+ // adding 3 bytes = 6 values between warps should reduce bank conflicts
+ __shared__ half smem_dequant_stats[SMEM_SIZE];
+
+
+ while(idx_col_B < colsB)
+ {
+
+ if(dequant_stats != NULL)
+ {
+ for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x)
+ if((idx_col_B+i-local_idx_col_B_offset) < colsB)
+ smem_dequant_stats[i] = __ldg(&dequant_stats[idx_col_B+i-local_idx_col_B_offset]);
+
+ __syncthreads();
+ }
+
+ #pragma unroll SPMM_ITEMS
+ for(int j = 0; j < SPMM_ITEMS; j++)
+ local_valC[j] = 0.0f;
+
+ #pragma unroll
+ for(int i = 0; i < count; i++)
+ {
+ // 3. each warp loads all required rows of B but each warp is offset by k
+ int row_offset = colsB*local_colidxA[i];
+
+ #pragma unroll SPMM_ITEMS
+ for(int j = 0; j < SPMM_ITEMS; j+=num_items)
+ {
+ // 4. Multiply the tile -> accumulate outputs in registers until 128 bytes are reached
+ int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j;
+ if(idx >= colsB){ break; }
+ //printf("%i %i\n", (row_offset+idx) % num_items, row_offset+idx);
+ if((idx+num_items < colsB))
+ {
+ if(BITS == 8)
+ reinterpret_cast<float2(&)[num_items]>(local_valsB)[0] = reinterpret_cast<float2*>(B)[(row_offset+ idx)/num_items];
+ else
+ reinterpret_cast<float4(&)[num_items]>(local_valsB)[0] = reinterpret_cast<float4*>(B)[(row_offset+ idx)/num_items];
+ }
+ else
+ {
+ #pragma unroll num_items
+ for(int k = 0; k < num_items; k++)
+ if(idx+k < colsB)
+ local_valsB[k] = B[row_offset+idx+k];
+ else
+ local_valsB[k] = 0.0f;
+ }
+ #pragma unroll num_items
+ for(int k = 0; k < num_items; k++)
+ {
+ //if((float)local_valsB[k] != 0.0)
+ // printf("%f %i %i %i\n", (float)local_valsB[k], k, idx, colsB);
+ if(BITS == 8 && dequant_stats != NULL)
+ // we do texture cache reads (__ldg) on dequant_stats which should be super fast
+ {
+ float valB = local_valsB[k];
+ float valA = local_valA[i];
+ if(valB != 0.0 && valA != 0.0)
+ local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*C*valB*valA;
+ }
+ else
+ local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i];
+ }
+ }
+ }
+
+ int idx_row_C = (colsB*local_row_idx);
+
+ #pragma unroll SPMM_ITEMS
+ for(int j = 0; j < SPMM_ITEMS; j+=num_items)
+ {
+ //int idx_col_C = idx_col_B + (32*j) + warp_idx;
+ int idx_col_C = idx_col_B + warp_idx*SPMM_ITEMS + j;
+ int idx_val = idx_col_C + idx_row_C;
+
+ if(idx_col_C +num_items < colsB)
+ {
+
+ // load outputs to do inplace addition
+ reinterpret_cast<float4(&)[num_items/4]>(local_valOut)[0] = reinterpret_cast<float4*>(out)[idx_val/num_items];
+
+ #pragma unroll num_items
+ for(int k = 0; k < num_items; k++)
+ local_valC[j + k] = (float)local_valC[j + k] + (float)local_valOut[k]; // index with j+k so each num_items chunk lands in its own slot (matches the tail loop below)
+
+ reinterpret_cast<float4*>(out)[idx_val/num_items] = reinterpret_cast<float4(&)[num_items]>(local_valC)[j/num_items];
+ }
+ else
+ {
+ #pragma unroll num_items
+ for(int k = 0; k < num_items; k++)
+ if(idx_col_C + k < colsB)
+ out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k];
+ }
+ }
+
+ idx_col_B += blockDim.x*SPMM_ITEMS;
+ local_idx_col_B_offset += blockDim.x*SPMM_ITEMS;
+ }
+}
+
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
+template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+
+template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+
+template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n);
+
+template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
+template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
+
template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x);
template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x);
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index 0a3676c..cbfbeba 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -106,6 +106,18 @@ template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileCl
__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n);
+
+template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+
+template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(
+ int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats,
+ half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n);
+
+template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
+template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
+
+template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+
#endif
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 40c185c..8946015 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -8,6 +8,7 @@
#include <cub/device/device_scan.cuh>
#include <limits>
#include <BinSearch.h>
+#include <cassert>
#include <common.h>
@@ -188,11 +189,416 @@ template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step,
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
+void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
+{
+ const int falpha = 1;
+ const int fbeta = 0;
+ const void * alpha = &falpha;
+ const void * beta = &fbeta;
+ cublasStatus_t status;
+
+ status = cublasGemmEx(context->m_handle,
+ transposeA ? CUBLAS_OP_T : CUBLAS_OP_N,
+ transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
+ m, n, k,
+ alpha, A, CUDA_R_8I, lda, B, CUDA_R_8I, ldb, beta,
+ C, CUDA_R_32I, ldc,
+ CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+
+ if (status != CUBLAS_STATUS_SUCCESS)
+ {
+ std::cout << "CUBLAS ERROR: Status " << status << std::endl;
+ }
+
+}
+
+void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
+ long long int strideA, long long int strideB, long long int strideC, int batchCount)
+{
+ const int falpha = 1;
+ const int fbeta = 0;
+ const void * alpha = &falpha;
+ const void * beta = &fbeta;
+ cublasStatus_t status;
+
+ //cout << transposeA << transposeB << endl;
+ //printf("%i %i %i\n", m,n,k);
+ //printf("%i %i %i\n", lda,ldb,ldc);
+ //printf("%i %i %i\n", strideA, strideB, strideC);
+ //printf("%i\n", batchCount);
+
+ status = cublasGemmStridedBatchedEx(context->m_handle,
+ transposeA ? CUBLAS_OP_T : CUBLAS_OP_N,
+ transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
+ m, n, k,
+ alpha, A, CUDA_R_8I, lda, (long long int)strideA, B, CUDA_R_8I, ldb, (long long int)strideB, beta,
+ C, CUDA_R_32I, ldc, (long long int)strideC, batchCount,
+ CUDA_R_32I, CUBLAS_GEMM_DEFAULT);
+
+ if (status != CUBLAS_STATUS_SUCCESS)
+ {
+ std::cout << "CUBLAS ERROR: Status " << status << std::endl;
+ }
+
+}
+
+int roundoff(int v, int d) {
+ return (v + d - 1) / d * d;
+}
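+// e.g. roundoff(13, 8) == 16 and roundoff(16, 8) == 16 (round v up to the nearest multiple of d)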
+
+
+template<int ORDER> cublasLtOrder_t get_order()
+{
+ switch(ORDER)
+ {
+ case ROW:
+ return CUBLASLT_ORDER_ROW;
+ break;
+ case COL:
+ return CUBLASLT_ORDER_COL;
+ break;
+ case COL32:
+ return CUBLASLT_ORDER_COL32;
+ break;
+ case COL_TURING:
+ return CUBLASLT_ORDER_COL4_4R2_8C;
+ break;
+ case COL_AMPERE:
+ return CUBLASLT_ORDER_COL32_2R_4R4;
+ break;
+ }
+}
+
+template cublasLtOrder_t get_order<ROW>();
+template cublasLtOrder_t get_order<COL>();
+template cublasLtOrder_t get_order<COL32>();
+template cublasLtOrder_t get_order<COL_TURING>();
+template cublasLtOrder_t get_order<COL_AMPERE>();
+
+
+template<int ORDER> int get_leading_dim(int dim1, int dim2)
+{
+ switch(ORDER)
+ {
+ case ROW:
+ return dim2;
+ break;
+ case COL:
+ return dim1;
+ break;
+ case COL32:
+ // 32*row tiles
+ return dim1*32;
+ break;
+ case COL_TURING:
+ return 32*roundoff(dim1, 8);
+ break;
+ case COL_AMPERE:
+ // 32*32 tiles
+ return 32*roundoff(dim1, 32);
+ break;
+ }
+}
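+// worked example (illustrative): for a 64 x 100 matrix, ROW -> ld 100, COL -> 64, COL32 -> 64*32 = 2048,
+// COL_TURING -> 32*roundoff(64, 8) = 2048, COL_AMPERE -> 32*roundoff(64, 32) = 2048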
+
+template int get_leading_dim<ROW>(int dim1, int dim2);
+template int get_leading_dim<COL>(int dim1, int dim2);
+template int get_leading_dim<COL32>(int dim1, int dim2);
+
+template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
+{
+
+ cublasLtOrder_t orderA = get_order<SRC>();
+ cublasLtOrder_t orderOut = get_order<TARGET>();
+ int ldA = get_leading_dim<SRC>(dim1, dim2);
+ int ldOut = get_leading_dim<TARGET>(dim1, dim2);
+
+ cublasLtMatrixLayout_t A_desc = NULL, out_desc = NULL;
+ cublasLtMatrixTransformDesc_t A2Out_desc = NULL;
+ cublasOperation_t opTranspose = CUBLAS_OP_T;
+ float transformAlpha = 1.0f, transformBeta = 0.0f;
+
+
+ if(DTYPE == 8)
+ {
+ checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_8I, dim1, dim2, ldA));
+ checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_8I, dim1, dim2, ldOut));
+ }
+ else if(DTYPE == 32)
+ {
+ checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_32I, dim1, dim2, ldA));
+ checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_32I, dim1, dim2, ldOut));
+ }
+ else
+ {
+ printf("ERROR WRONG TYPE FOR TRANSFORM: %i\n", DTYPE);
+ }
+
+ checkCublasStatus(cublasLtMatrixLayoutSetAttribute(A_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA)));
+ checkCublasStatus(cublasLtMatrixLayoutSetAttribute(out_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut)));
+
+ checkCublasStatus(cublasLtMatrixTransformDescCreate(&A2Out_desc, CUDA_R_32F));
+
+ if(transpose){ checkCublasStatus(cublasLtMatrixTransformDescSetAttribute(A2Out_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); }
+
+ checkCublasStatus(cublasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, NULL, NULL, out, out_desc, 0));
+
+ if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc));
+ if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc));
+ if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc));
+}
+
+template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int8_t, ROW, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int8_t, ROW, COL32, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int32_t, ROW, COL32, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
+template void transform<int8_t, ROW, COL_TURING, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int8_t, ROW, COL_AMPERE, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int8_t, COL32, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
+
+template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+{
+ int has_error = 0;
+ cublasLtMatmulDesc_t matmulDesc = NULL;
+ cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
+ cublasOperation_t opT = CUBLAS_OP_T;
+ cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO;
+ cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
+ cublasLtOrder_t col_turing = CUBLASLT_ORDER_COL4_4R2_8C;
+ cublasLtOrder_t col_ampere = CUBLASLT_ORDER_COL32_2R_4R4;
+
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, m, k, lda));
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, n, k, ldb));
+
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
+ if(FORMATB == COL_TURING)
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing)));
+ else
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere)));
+
+ if(DTYPE_OUT == 32)
+ {
+ has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32I));
+ has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc));
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
+ int alpha = 1, beta = 0;
+ has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, NULL, NULL, 0, 0));
+ }
+ else
+ {
+ has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32F));
+ has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_8I, m, n, ldc));
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
+ if(!SCALE_ROWS)
+ {
+ float alpha = 1.0f, beta = 0.0f;
+ has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
+ }
+ else
+ {
+ has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec)));
+ has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, NULL, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
+ }
+ }
+
+
+ if (Cdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Cdesc));
+ if (Bdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Bdesc));
+ if (Adesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Adesc));
+ if (matmulDesc) has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc));
+ if(has_error == 1)
+ printf("error detected");
+
+ return has_error;
+}
+
+int fill_up_to_nearest_multiple(int value, int multiple)
+{
+ return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple)));
+}
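+// e.g. fill_up_to_nearest_multiple(100, 32) == 128 and fill_up_to_nearest_multiple(128, 32) == 128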
+
+void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols)
+{
+ int threads = 512;
+ int tileCols = fill_up_to_nearest_multiple(numCols, 32);
+ int n = numRows*tileCols;
+ int subtile_rows = 128;
+ int tilesize = 32*subtile_rows;
+ int num_blocks = numRows/subtile_rows;
+ num_blocks += (numRows % subtile_rows == 0) ? 0 : 1;
+ num_blocks = num_blocks*(tileCols/32);
+ assert(threads <= tilesize);
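+
+ // worked example (illustrative): numRows = numCols = 1024 gives tileCols = 1024 and
+ // num_blocks = 8*32 = 256, matching the "256 tiles" example in the kernel comment above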
+
+ //cout << num_blocks << " blocks" << endl;
+
+ kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
+
+#define STATS_THREADS 64
+#define STATS_ITEMS 4
+#define STATS_ROWS 16
+void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
+{
+ int tile_cols = STATS_THREADS*STATS_ITEMS;
+ int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
+ int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
+ int num_blocks = (tiledCols/tile_cols) * (tiledRows/STATS_ROWS);
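+ // worked example (illustrative): rows=512, cols=1024 -> tiledCols=1024, tiledRows=512, num_blocks = 4*32 = 128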
+
+ if(nnz_threshold == 0.0)
+ kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
+ else
+ kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 1><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+
+}
+
+void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols)
+{
+ int threads = 64;
+ int items_per_thread = 4;
+ int tile_cols = threads*items_per_thread;
+ int tile_rows = 16;
+ int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
+ int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
+ int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
+
+ //cout << cols << " " << tiledCols << " " << tiledRows << endl;
+ //cout << "num blocks " << num_blocks << endl;
+
+ //cout << A << " " << out_col_normed << endl;
+ if(threshold > 0.0f)
+ kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
+ else
+ kDoubleRowColQuant<64, 4, 16, 64*4, 0><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
+
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
+
+template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols)
+{
+ int threads = 256;
+ int items_per_thread = 8;
+ // we load 128 column values per warp
+ int tile_cols = 32*items_per_thread;
+ int tile_rows = 32;
+ int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
+ int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
+ int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
+ int outCols = fill_up_to_nearest_multiple(cols, 32);
+ int outRows = fill_up_to_nearest_multiple(rows, 32);
+ if(FORMAT == COL_TURING)
+ {
+ if(TRANSPOSE)
+ outRows = fill_up_to_nearest_multiple(cols, 8);
+ else
+ outRows = fill_up_to_nearest_multiple(rows, 8);
+ }
+ else if(FORMAT == COL_AMPERE)
+ {
+ if(TRANSPOSE)
+ outRows = fill_up_to_nearest_multiple(cols, 32);
+ else
+ outRows = fill_up_to_nearest_multiple(rows, 32);
+ }
+ else
+ {
+ if(TRANSPOSE)
+ {
+ outCols = fill_up_to_nearest_multiple(rows, 32);
+ outRows = cols;
+ }
+ }
+
+ //cout << cols << " " << tiledCols << " " << tiledRows << " " << outCols << endl;
+ //cout << "num blocks " << num_blocks << endl;
+
+ //cout << A << " " << out_col_normed << endl;
+ kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
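+// Output row padding by target format: COL_TURING pads to a multiple of 8,
+// COL_AMPERE to a multiple of 32; TRANSPOSE swaps the roles of rows and cols.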
+
+void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
+{
+
+ cusparseSpMatDescr_t descA;
+ cusparseDnMatDescr_t descB, descC;
+
+ float alpha = 1.0f;
+ float beta = 0.0f;
+ void *dBuffer = NULL;
+ size_t bufferSize = 0;
+
+ CHECK_CUSPARSE( cusparseCreateCoo(&descA, A_rows, A_cols, A_nnz,
+ A_rowidx, A_colidx, A_vals,
+ CUSPARSE_INDEX_32I,
+ CUSPARSE_INDEX_BASE_ZERO, CUDA_R_16F) );
+ // Create dense matrix C
+ CHECK_CUSPARSE( cusparseCreateDnMat(&descC, A_rows, B_cols, ldc, C,
+ CUDA_R_16F, CUSPARSE_ORDER_ROW) );
+ // B is stored transposed: swap its logical dimensions before creating the descriptor
+ if(transposed_B)
+ {
+ int tmp = A_cols;
+ A_cols = B_cols;
+ B_cols = tmp;
+ }
+
+ // Create dense matrix B
+ CHECK_CUSPARSE( cusparseCreateDnMat(&descB, A_cols, B_cols, ldb, B,
+ CUDA_R_16F, CUSPARSE_ORDER_ROW) );
+ // allocate an external buffer if needed
+ CHECK_CUSPARSE( cusparseSpMM_bufferSize(
+ handle,
+ CUSPARSE_OPERATION_NON_TRANSPOSE,
+ transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE,
+ &alpha, descA, descB, &beta, descC, CUDA_R_32F,
+ CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize) );
+ CUDA_CHECK_RETURN( cudaMalloc(&dBuffer, bufferSize) );
+
+ // execute SpMM
+ CHECK_CUSPARSE( cusparseSpMM(handle,
+ CUSPARSE_OPERATION_NON_TRANSPOSE,
+ transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE,
+ &alpha, descA, descB, &beta, descC, CUDA_R_32F,
+ CUSPARSE_SPMM_ALG_DEFAULT, dBuffer));
+
+ // destroy matrix/vector descriptors
+ CHECK_CUSPARSE( cusparseDestroySpMat(descA) );
+ CHECK_CUSPARSE( cusparseDestroyDnMat(descB) );
+ CHECK_CUSPARSE( cusparseDestroyDnMat(descC) );
+ CUDA_CHECK_RETURN( cudaFree(dBuffer) );
+}
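+// Usage sketch (hypothetical buffers; the handle typically comes from a ContextCusparse):
+//   ContextCusparse ctx;
+//   // C[A_rows x B_cols] = A_coo[A_rows x A_cols] * B[A_cols x B_cols], row-major fp16
+//   spmm_coo(ctx.m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, false);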
+
+template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
+{
+
+ kspmm_coo_very_sparse_naive<T, 8, BITS><<<nnz_rows, 256>>>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
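+// one 256-thread block is launched per non-empty row (nnz_rows blocks in total);
+// BITS=16 pairs with a half B, BITS=8 with a signed char B (see the instantiations below)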
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
+template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
+template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
+
+template int igemmlt<COL_TURING, 32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+template int igemmlt<COL_TURING, 8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+template int igemmlt<COL_TURING, 8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+template int igemmlt<COL_AMPERE, 32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+template int igemmlt<COL_AMPERE, 8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+template int igemmlt<COL_AMPERE, 8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+
+template void transformRowToFormat<COL32, 0>(char * A, char *out, int rows, int cols);
+template void transformRowToFormat<COL32, 1>(char * A, char *out, int rows, int cols);
+template void transformRowToFormat<COL_TURING, 0>(char * A, char *out, int rows, int cols);
+template void transformRowToFormat<COL_TURING, 1>(char * A, char *out, int rows, int cols);
+template void transformRowToFormat<COL_AMPERE, 0>(char * A, char *out, int rows, int cols);
+template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows, int cols);
+
template void estimateQuantiles(half *A, float *code, float offset, int n);
template void estimateQuantiles(float *A, float *code, float offset, int n);
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index 8fb4cec..4e719df 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -14,6 +14,11 @@
#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
+#include <cublas_v2.h>
+#include <cublasLt.h>
+#include <cusparse.h>
+#include <vector>
+#include <functional>
#define CUDA_CHECK_RETURN(value) { \
cudaError_t _m_cudaStat = value; \
@@ -25,6 +30,34 @@
#define THREADS_PER_BLOCKS (512)
+#define CHECK_CUSPARSE(value) { \
+ cusparseStatus_t _m_cusparseStat = value; \
+ if (_m_cusparseStat != CUSPARSE_STATUS_SUCCESS) { \
+ fprintf(stderr, "Error %s at line %d in file %s\n", \
+ cusparseGetErrorString(_m_cusparseStat), __LINE__, __FILE__); \
+ exit(1); \
+ } }
+
+inline void checkCudaStatus(cudaError_t status) {
+ if (status != cudaSuccess) {
+ printf("cuda API failed with status %d: %s\n", status, cudaGetErrorString(status));
+ throw std::logic_error("cuda API failed");
+ }
+}
+
+inline int checkCublasStatus(cublasStatus_t status) {
+ if (status != CUBLAS_STATUS_SUCCESS) {
+ printf("cuBLAS API failed with status %d\n", status);
+ //throw std::logic_error("cuBLAS API failed");
+ return 1;
+ }
+ return 0;
+}
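+// note: reports the error and returns 1 instead of throwing (the throw is commented
+// out), presumably so extern "C" callers can surface the status as a plain return code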
+
typedef enum Operations_t
{
ksmul = 0,
@@ -39,6 +72,57 @@ typedef enum Optimizer_t
ADAGRAD = 4,
} Optimizer_t;
+typedef enum Transform_t
+{
+ ROW = 0,
+ COL = 1,
+ COL32 = 2,
+ COL_TURING = 3,
+ COL_AMPERE = 4,
+} Transform_t;
+
+class Context
+{
+ public:
+ cublasHandle_t m_handle;
+
+ Context()
+ {
+ cublasCreate_v2(&m_handle);
+ }
+
+};
+
+class ContextLt
+{
+ public:
+ cublasLtHandle_t m_handle;
+
+ ContextLt()
+ {
+ cublasLtCreate(&m_handle);
+ }
+
+};
+
+class ContextCusparse
+{
+ public:
+ cusparseHandle_t m_handle;
+
+ ContextCusparse()
+ {
+ cusparseCreate(&m_handle);
+ }
+
+};
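+// Note: each context creates its library handle once and never destroys it; the
+// handles are meant to live for the whole process (see get_context()/get_cusparse()
+// in pythonInterface.c).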
+
template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n);
@@ -70,4 +154,24 @@ template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step,
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n);
+void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
+void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
+ long long int strideA, long long int strideB, long long int strideC, int batchCount);
+
+
+template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+
+template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2);
+void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
+void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols);
+void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols);
+void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed,
+ int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols);
+
+template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols);
+
+void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B);
+
+template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
+
#endif
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index c2fed6b..03c8d92 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -84,6 +84,52 @@ void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
#endif
+#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
+void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
+{ \
+ transform<dtype, src, target, transpose, bits>(ltHandle, A, out, dim1, dim2); \
+} \
+
+MAKE_FUNC_TRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8);
+MAKE_FUNC_TRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8);
+MAKE_FUNC_TRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8);
+MAKE_FUNC_TRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32);
+MAKE_FUNC_TRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8);
+MAKE_FUNC_TRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8);
+MAKE_FUNC_TRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8);
+MAKE_FUNC_TRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32);
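+// e.g. the int8 col32 line above expands to:
+//   void transform_8_row_to_col32_n(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2)
+//   { transform<int8_t, ROW, COL32, false, 8>(ltHandle, A, out, dim1, dim2); }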
+
+void transform_row2col32(char * A, char *out, int rows, int cols){ transformRowToFormat<COL32, 0>(A, out, rows, cols); }
+void transform_row2col32T(char * A, char *out, int rows, int cols){ transformRowToFormat<COL32, 1>(A, out, rows, cols); }
+void transform_row2turing(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_TURING, 0>(A, out, rows, cols); }
+void transform_row2turingT(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_TURING, 1>(A, out, rows, cols); }
+void transform_row2ampere(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 0>(A, out, rows, cols); }
+void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 1>(A, out, rows, cols); }
+
+ int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+ { return igemmlt<COL_TURING, 32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+
+ int igemmlt_turing_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+ { return igemmlt<COL_TURING, 8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+
+ int igemmlt_turing_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+ { return igemmlt<COL_TURING, 8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+
+ int igemmlt_ampere_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+ { return igemmlt<COL_AMPERE, 32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+
+ int igemmlt_ampere_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+ { return igemmlt<COL_AMPERE, 8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+
+ int igemmlt_ampere_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+ { return igemmlt<COL_AMPERE, 8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+
+void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
+{ spmm_coo_very_sparse_naive<half, 16>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
+
+void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
+{ spmm_coo_very_sparse_naive<signed char, 8>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
+
extern "C"
{
#if BUILD_CUDA
@@ -155,7 +201,86 @@ extern "C"
void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); }
- #endif
+ void cigemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
+ { gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc); }
+ void cbatched_igemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
+ long long int strideA, long long int strideB, long long int strideC, int batchCount)
+ { strided_gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc, strideA, strideB, strideC, batchCount); }
+
+ Context *get_context(){ return new Context(); }
+ ContextCusparse *get_cusparse(){ return new ContextCusparse(); }
+
+ int cigemmlt_turing_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+ { return igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+
+ int cigemmlt_turing_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+ { return igemmlt_turing_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+
+ int cigemmlt_turing_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+ { return igemmlt_turing_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+
+ int cigemmlt_ampere_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+ { return igemmlt_ampere_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+
+ int cigemmlt_ampere_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+ { return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+
+ int cigemmlt_ampere_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+ { return igemmlt_ampere_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
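+ // suffix legend for the wrappers above: _32 -> DTYPE_OUT=32 (int32 out),
+ // _8 -> DTYPE_OUT=8 (int8 out), _rowscale -> SCALE_ROWS=1 (row_scale is applied)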
+
+ #define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
+ void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \
+ { \
+ transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose((cublasLtHandle_t) context->m_handle, A, out, dim1, dim2); \
+ } \
+
+ MAKE_FUNC_CTRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8)
+ MAKE_FUNC_CTRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8)
+ MAKE_FUNC_CTRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8)
+ MAKE_FUNC_CTRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32)
+ MAKE_FUNC_CTRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8)
+ MAKE_FUNC_CTRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8)
+ MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8)
+ MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32)
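+ // e.g. the int8 col32 line above expands to:
+ //   void ctransform_8_row_to_col32_n(Context *context, int8_t *A, int8_t *out, int dim1, int dim2)
+ //   { transform_8_row_to_col32_n((cublasLtHandle_t) context->m_handle, A, out, dim1, dim2); }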
+
+ void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols)
+ { dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols); }
+ void cget_col_row_stats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
+ { getColRowStats(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols); }
+
+ void cdouble_rowcol_quant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols)
+ { doubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols); }
+
+ void ctransform_row2col32(char * A, char *out, int rows, int cols)
+ { transform_row2col32(A, out, rows, cols); }
+
+ void ctransform_row2col32T(char * A, char *out, int rows, int cols)
+ { transform_row2col32T(A, out, rows, cols); }
+
+ void ctransform_row2turing(char * A, char *out, int rows, int cols)
+ { transform_row2turing(A, out, rows, cols); }
+
+ void ctransform_row2turingT(char * A, char *out, int rows, int cols)
+ { transform_row2turingT(A, out, rows, cols); }
+
+ void ctransform_row2ampere(char * A, char *out, int rows, int cols)
+ { transform_row2ampere(A, out, rows, cols); }
+
+ void ctransform_row2ampereT(char * A, char *out, int rows, int cols)
+ { transform_row2ampereT(A, out, rows, cols); }
+
+ void cspmm_coo(ContextCusparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
+ { spmm_coo((cusparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); }
+
+ void cspmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
+ { spmm_coo_very_sparse_naive_fp16(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
+
+ void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
+ { spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
+
+#endif
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); }
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); }
}