From 451fd9506e215aa25643e9782cb7d8aed2a266cc Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 3 Aug 2022 11:54:01 -0700 Subject: Added fixes for the case that matmullt dim A is zero, e.g. [0, 768]. --- csrc/ops.cu | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) (limited to 'csrc') diff --git a/csrc/ops.cu b/csrc/ops.cu index b3d07c6..cfc9605 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -459,8 +459,6 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); assert(threads <= tilesize); - //cout << num_blocks << " blocks" << endl; - kdequant_mm_int32_fp16<4, 128, 512><<>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -473,11 +471,14 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r int tile_cols = STATS_THREADS*STATS_ITEMS; int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS); - int num_blocks = (tiledCols/tile_cols) * (tiledRows/STATS_ROWS); + int row_tiles = (tiledRows/STATS_ROWS); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); - if(nnz_threshold == 0.0) kgetColRowStats<<>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); else if(nnz_threshold != 0.0) @@ -494,13 +495,14 @@ void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col int tile_rows = 16; int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); - int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows); - assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); + int row_tiles = (tiledRows/tile_rows); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; - //cout << cols << " " << tiledCols << " " << tiledRows << endl; - //cout << "num blocks " << num_blocks << endl; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); - //cout << A << " " << out_col_normed << endl; if(threshold > 0.0f) kDoubleRowColQuant<64, 4, 16, 64*4, 1><<>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); else @@ -518,7 +520,12 @@ template void transformRowToFormat(char * A, char *o int tile_rows = 32; int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); - int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows); + int row_tiles = (tiledRows/tile_rows); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); int outCols = fill_up_to_nearest_multiple(cols, 32); int outRows = fill_up_to_nearest_multiple(rows, 32); @@ -545,10 +552,6 @@ template void transformRowToFormat(char * A, char *o } } - //cout << cols << " " << tiledCols << " " << tiledRows << " " << outCols << endl; - //cout << "num blocks " << num_blocks << endl; - - //cout << A << " " << out_col_normed << endl; kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<>>(A, out, rows, cols, tiledCols, outRows, outCols); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -- cgit v1.2.3