summaryrefslogtreecommitdiff
path: root/csrc
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-08-03 11:54:01 -0700
committerTim Dettmers <tim.dettmers@gmail.com>2022-08-03 11:54:01 -0700
commit451fd9506e215aa25643e9782cb7d8aed2a266cc (patch)
treea95aac44018b664dcae503918bb551728f8147c3 /csrc
parent2f01865a2ff4ad3345c156f7a2f76fe79ec4ed9a (diff)
Added fixes for the case that matmullt dim A is zero, e.g. [0, 768].
Diffstat (limited to 'csrc')
-rw-r--r--csrc/ops.cu31
1 files changed, 17 insertions, 14 deletions
diff --git a/csrc/ops.cu b/csrc/ops.cu
index b3d07c6..cfc9605 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -459,8 +459,6 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out,
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
assert(threads <= tilesize);
- //cout << num_blocks << " blocks" << endl;
-
kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
@@ -473,11 +471,14 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r
int tile_cols = STATS_THREADS*STATS_ITEMS;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
- int num_blocks = (tiledCols/tile_cols) * (tiledRows/STATS_ROWS);
+ int row_tiles = (tiledRows/STATS_ROWS);
+ int col_tiles = (tiledCols/tile_cols);
+ row_tiles = row_tiles > 0 ? row_tiles : 1;
+ col_tiles = col_tiles > 0 ? col_tiles : 1;
+ int num_blocks = row_tiles * col_tiles;
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
-
if(nnz_threshold == 0.0)
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
else if(nnz_threshold != 0.0)
@@ -494,13 +495,14 @@ void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col
int tile_rows = 16;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
- int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
- assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
+ int row_tiles = (tiledRows/tile_rows);
+ int col_tiles = (tiledCols/tile_cols);
+ row_tiles = row_tiles > 0 ? row_tiles : 1;
+ col_tiles = col_tiles > 0 ? col_tiles : 1;
+ int num_blocks = row_tiles * col_tiles;
- //cout << cols << " " << tiledCols << " " << tiledRows << endl;
- //cout << "num blocks " << num_blocks << endl;
+ assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
- //cout << A << " " << out_col_normed << endl;
if(threshold > 0.0f)
kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
else
@@ -518,7 +520,12 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
int tile_rows = 32;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
- int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
+ int row_tiles = (tiledRows/tile_rows);
+ int col_tiles = (tiledCols/tile_cols);
+ row_tiles = row_tiles > 0 ? row_tiles : 1;
+ col_tiles = col_tiles > 0 ? col_tiles : 1;
+ int num_blocks = row_tiles * col_tiles;
+
assert(num_blocks <= 65535 && "CUDA ERROR: Maximum number of blocks for kernel exceeded");
int outCols = fill_up_to_nearest_multiple(cols, 32);
int outRows = fill_up_to_nearest_multiple(rows, 32);
@@ -545,10 +552,6 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
}
}
- //cout << cols << " " << tiledCols << " " << tiledRows << " " << outCols << endl;
- //cout << "num blocks " << num_blocks << endl;
-
- //cout << A << " " << out_col_normed << endl;
kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}