diff options
-rw-r--r-- | csrc/kernels.cu | 40 | ||||
-rw-r--r-- | tests/test_functional.py | 14 |
2 files changed, 32 insertions, 22 deletions
diff --git a/csrc/kernels.cu b/csrc/kernels.cu index bb36d9b..79ad5de 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2626,28 +2626,32 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char * offset += tile_offset_rows + tile_offset_cols; - - char val = 0; - //printf("(%i (%i %i) (%i %i))\n", offset, tile_offset_rows, tile_offset_cols, row, local_colidx); - if(offset > tiledColsA*tiledRowsA) - printf("(%i (%i %i) (%i %i)\n", offset, tile_offset_rows, tile_offset_cols, row, local_colidx); - else - val = A[offset]; + char val = A[offset]; int out_idx = (row*idx_size) + blockIdx.x; - - //if(out_idx > colsA*idx_size) - if(val != 0) - { - //printf("(%i %i) = (%i) = %i\n", row, local_colidx, out_idx, (int) val); - out[out_idx] = val; - } - else - { - out[out_idx] = val; - } + out[out_idx] = val; } + } + else if(FORMAT == COL_AMPERE) + { + for(int row = threadIdx.x; row < rowsA; row+= blockDim.x) + { + // we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element + // within each tile. + int offset_per_col_tile = ((rowsA+31)/32)*32*32; + int tile_offset_rows = (row/32)*32*32; + int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; + int subtile_col_idx = local_colidx%32; + int subtile_row_idx = row % 32; + // this magic is taken from the cublasLt doc (search for COL32) + int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx; + offset += tile_offset_cols + tile_offset_rows; + + char val = A[offset]; + int out_idx = (row*idx_size) + blockIdx.x; + out[out_idx] = val; + } } } diff --git a/tests/test_functional.py b/tests/test_functional.py index 4d06447..2d58fac 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1859,9 +1859,9 @@ def test_zp(): def test_extract_outliers(): for i in range(k): - shapeA = (4096, 4*4096) + shapeA = (4096, 4096*4) idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda() - #idx = torch.Tensor([32]).int().cuda() + #idx = torch.Tensor([0]).int().cuda() A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) outliers1 = A[:, idx.long()] @@ -1872,7 +1872,13 @@ def test_extract_outliers(): assert outliers2.shape[0] == shapeA[0] assert outliers2.shape[1] == idx.numel() - #print(outliers1) - #print(outliers2) + torch.testing.assert_allclose(outliers1, outliers2) + + CA, SA = F.transform(A, 'col_ampere') + + outliers2 = F.extract_outliers(CA, SA, idx) + + assert outliers2.shape[0] == shapeA[0] + assert outliers2.shape[1] == idx.numel() torch.testing.assert_allclose(outliers1, outliers2) |