From 9268dc9d887a3d54cd1f008dcb628aaa5b5bd90a Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Mon, 25 Jul 2022 19:30:37 -0700 Subject: Some progress on build script; added multi-cuda install script. --- csrc/kernels.cu | 17 ++++++++--------- csrc/kernels.cuh | 2 +- csrc/ops.cu | 22 +++++++++++++++++++++- 3 files changed, 30 insertions(+), 11 deletions(-) (limited to 'csrc') diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 4e744fb..6eca3aa 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2166,7 +2166,6 @@ template BlockExchange; - __shared__ typename BlockExchange::TempStorage temp_storage; // we load row after row from the base_position // Load data row by row @@ -2446,7 +2445,7 @@ template -__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB) +__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB) { // 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block @@ -2500,7 +2499,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o { for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x) if((idx_col_B+i-local_idx_col_B_offset) < colsB) - smem_dequant_stats[i] = __ldg(&dequant_stats[idx_col_B+i-local_idx_col_B_offset]); + smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset]; __syncthreads(); } @@ -2596,12 +2595,12 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o // TEMPLATE DEFINITIONS //============================================================== -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index cbfbeba..4e65e96 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -107,7 +107,7 @@ template __global__ void kPercentileCl __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kdequant_mm_int32_fp16( int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, diff --git a/csrc/ops.cu b/csrc/ops.cu index 8946015..c430d55 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -247,6 +247,8 @@ int roundoff(int v, int d) { } +#ifdef NO_CUBLASLT +#else template cublasLtOrder_t get_order() { switch(ORDER) @@ -266,7 +268,11 @@ template cublasLtOrder_t get_order() case COL_AMPERE: return CUBLASLT_ORDER_COL32_2R_4R4; break; + default: + break; } + + return CUBLASLT_ORDER_ROW; } template cublasLtOrder_t get_order(); @@ -274,6 +280,7 @@ template cublasLtOrder_t get_order(); template cublasLtOrder_t get_order(); template cublasLtOrder_t get_order(); template cublasLtOrder_t get_order(); +#endif template int get_leading_dim(int dim1, int dim2) @@ -297,6 +304,9 @@ template int get_leading_dim(int dim1, int dim2) // 32*32 tiles return 32*roundoff(dim1, 32); break; + default: + return 0; + break; } } @@ -306,7 +316,8 @@ template int get_leading_dim(int dim1, int dim2); template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2) { - +#ifdef NO_CUBLASLT +#else cublasLtOrder_t orderA = get_order(); cublasLtOrder_t orderOut = get_order(); int ldA = get_leading_dim(dim1, dim2); @@ -345,6 +356,7 @@ template void trans if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc)); if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc)); if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc)); +#endif } template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); @@ -358,6 +370,9 @@ template void transform(cublasLtHandle_t ltHandl template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { +#ifdef NO_CUBLASLT + return 0; +#else int has_error = 0; cublasLtMatmulDesc_t matmulDesc = NULL; cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; @@ -412,6 +427,7 @@ template int igemmlt(cublasLtHandle printf("error detected"); return has_error; +#endif } int fill_up_to_nearest_multiple(int value, int multiple) @@ -523,6 +539,9 @@ template void transformRowToFormat(char * A, char *o void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) { +#ifdef NO_CUBLASLT +#else + cusparseSpMatDescr_t descA; cusparseDnMatDescr_t descB, descC; @@ -569,6 +588,7 @@ void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_val CHECK_CUSPARSE( cusparseDestroyDnMat(descB) ); CHECK_CUSPARSE( cusparseDestroyDnMat(descC) ); CUDA_CHECK_RETURN( cudaFree(dBuffer) ); +#endif } template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) -- cgit v1.2.3