diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-08-03 11:54:01 -0700 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-08-03 11:54:01 -0700 |
commit | 451fd9506e215aa25643e9782cb7d8aed2a266cc (patch) | |
tree | a95aac44018b664dcae503918bb551728f8147c3 /csrc | |
parent | 2f01865a2ff4ad3345c156f7a2f76fe79ec4ed9a (diff) |
Added fixes for the case that matmullt dim A is zero, e.g. [0, 768].
Diffstat (limited to 'csrc')
-rw-r--r-- | csrc/ops.cu | 31 |
1 files changed, 17 insertions, 14 deletions
diff --git a/csrc/ops.cu b/csrc/ops.cu index b3d07c6..cfc9605 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -459,8 +459,6 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); assert(threads <= tilesize); - //cout << num_blocks << " blocks" << endl; - kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -473,11 +471,14 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r int tile_cols = STATS_THREADS*STATS_ITEMS; int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS); - int num_blocks = (tiledCols/tile_cols) * (tiledRows/STATS_ROWS); + int row_tiles = (tiledRows/STATS_ROWS); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); - if(nnz_threshold == 0.0) kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); else if(nnz_threshold != 0.0) @@ -494,13 +495,14 @@ void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col int tile_rows = 16; int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); - int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows); - assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); + int row_tiles = (tiledRows/tile_rows); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; - //cout << cols << " " << tiledCols << " " << tiledRows << endl; - //cout << "num blocks " << num_blocks << endl; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); - //cout << A << " " << out_col_normed << endl; if(threshold > 0.0f) kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); else @@ -518,7 +520,12 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o int tile_rows = 32; int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); - int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows); + int row_tiles = (tiledRows/tile_rows); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; + assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded"); int outCols = fill_up_to_nearest_multiple(cols, 32); int outRows = fill_up_to_nearest_multiple(rows, 32); @@ -545,10 +552,6 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o } } - //cout << cols << " " << tiledCols << " " << tiledRows << " " << outCols << endl; - //cout << "num blocks " << num_blocks << endl; - - //cout << A << " " << out_col_normed << endl; kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } |