summaryrefslogtreecommitdiff
path: root/csrc/kernels.cu
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-07-26 18:15:51 -0700
committerTim Dettmers <tim.dettmers@gmail.com>2022-07-26 18:15:51 -0700
commit32fa459ed7c812c79e847145004061f21b7ac0d9 (patch)
tree1a11f128f5db119afc76e6f2d649b20f34536a74 /csrc/kernels.cu
parentbcab99ec877ba063543bd7c03ba1cdd1b06e8078 (diff)
Added col_ampere outlier extraction kernel.
Diffstat (limited to 'csrc/kernels.cu')
-rw-r--r--csrc/kernels.cu40
1 files changed, 22 insertions, 18 deletions
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index bb36d9b..79ad5de 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -2626,28 +2626,32 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
offset += tile_offset_rows + tile_offset_cols;
-
- char val = 0;
- //printf("(%i (%i %i) (%i %i))\n", offset, tile_offset_rows, tile_offset_cols, row, local_colidx);
- if(offset > tiledColsA*tiledRowsA)
- printf("(%i (%i %i) (%i %i)\n", offset, tile_offset_rows, tile_offset_cols, row, local_colidx);
- else
- val = A[offset];
+ char val = A[offset];
int out_idx = (row*idx_size) + blockIdx.x;
-
- //if(out_idx > colsA*idx_size)
- if(val != 0)
- {
- //printf("(%i %i) = (%i) = %i\n", row, local_colidx, out_idx, (int) val);
- out[out_idx] = val;
- }
- else
- {
- out[out_idx] = val;
- }
+ out[out_idx] = val;
}
+ }
+ else if(FORMAT == COL_AMPERE)
+ {
+ for(int row = threadIdx.x; row < rowsA; row+= blockDim.x)
+ {
+ // we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element
+ // within each tile.
+ int offset_per_col_tile = ((rowsA+31)/32)*32*32;
+ int tile_offset_rows = (row/32)*32*32;
+ int tile_offset_cols = (local_colidx/32)*offset_per_col_tile;
+ int subtile_col_idx = local_colidx%32;
+ int subtile_row_idx = row % 32;
+ // this magic is taken from the cublasLt doc (search for COL32)
+ int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx;
+ offset += tile_offset_cols + tile_offset_rows;
+
+ char val = A[offset];
+ int out_idx = (row*idx_size) + blockIdx.x;
+ out[out_idx] = val;
+ }
}
}