From cbb901ac51bd6c41e4243ffb936ef0e2f7ca8ada Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 26 Jul 2022 12:12:38 -0700 Subject: Boilerplate and test for extract_outliers. --- bitsandbytes/functional.py | 26 ++++++++++++++++++++++++++ csrc/kernels.cu | 7 +++++++ csrc/kernels.cuh | 2 ++ csrc/ops.cu | 27 +++++++++++++++++++++++++++ csrc/ops.cuh | 2 ++ csrc/pythonInterface.c | 6 ++++++ tests/test_functional.py | 17 +++++++++++++++++ 7 files changed, 87 insertions(+) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 806c254..a9233e2 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1409,3 +1409,29 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): x *= SA[1]/127 x +=offset return x.to(dtype) + +def extract_outliers(A, SA, idx): + shapeA = SA[0] + formatA = SA[1] + assert formatA in ['col_turing', 'col_ampere'] + assert A.device.type == 'cuda' + + out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device) + + idx_size = ct.c_int32(idx.numel()) + rows = ct.c_int32(shapeA[0]) + cols = ct.c_int32(shapeA[1]) + ptrA = get_ptr(A) + ptrIdx = get_ptr(idx) + ptrOut = get_ptr(out) + + if formatA == 'col_turing': + lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) + elif formatA == 'col_ampere': + lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) + + return out + + + + diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 1c3e723..78170d0 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2592,10 +2592,17 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o } } +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA) +{ +} + //============================================================== // TEMPLATE DEFINITIONS //============================================================== +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA); +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA); + template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index cbfbeba..ec2068e 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -118,6 +118,8 @@ template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int rowsA, int colsA, int tiledRowsA, int tiledColsA); + #endif diff --git a/csrc/ops.cu b/csrc/ops.cu index 8946015..fe2d7fe 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -578,10 +578,37 @@ template void spmm_coo_very_sparse_naive(int *max_count, CUDA_CHECK_RETURN(cudaPeekAtLastError()); } + +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols) +{ + int threads = 256; + // we load 128 column values per warp + int tiledCols = tiledCols = fill_up_to_nearest_multiple(cols, 32); + int tiledRows = 0; + + int elements = idx_size*cols; // matrix A is transposed, so we extract columns + int num_blocks = (elements+threads-1)/threads; + + if(FORMAT == COL_TURING) + { + tiledRows = fill_up_to_nearest_multiple(rows, 8); + } + else if(FORMAT == COL_AMPERE) + { + tiledRows = fill_up_to_nearest_multiple(rows, 32); + } + + kExtractOutliers<<>>(A, idx, out, rows, cols, tiledRows, tiledCols); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + //============================================================== // TEMPLATE DEFINITIONS //============================================================== +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); + template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 4e719df..4b09ecf 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -174,4 +174,6 @@ void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_val template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); + #endif diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 03c8d92..2ecbaae 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -106,6 +106,9 @@ void transform_row2turingT(char * A, char *out, int rows, int cols){ transformRo void transform_row2ampere(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } +void extractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers(A, idx, out, idx_size, rows, cols); } +void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers(A, idx, out, idx_size, rows, cols); } + int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } @@ -280,6 +283,9 @@ extern "C" void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) { spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } + void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); } + void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); } + #endif void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); } void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); } diff --git a/tests/test_functional.py b/tests/test_functional.py index 6cbe58f..b508367 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1856,3 +1856,20 @@ def test_zp(): print(err1, err2, err3, err4, err5, err6) + +def test_extract_outliers(): + shapeA = (128, 128) + idx = torch.randint(0, shapeA[1], size=(10,)).int() + A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) + outliers1 = A[:, idx.long()] + + CA, SA = F.transform(A, 'col_turing') + + outliers2 = F.extract_outliers(CA, SA, idx) + + assert outliers2.shape[0] == shapeA[0] + assert outliers2.shape[1] == idx.numel() + + + + torch.testing.assert_allclose(outliers1, outliers2) -- cgit v1.2.3