diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-07-26 19:38:01 -0700 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-07-26 19:38:01 -0700 |
commit | 5737f2b027a1e0ec8540a3aa914632d44ad9c62d (patch) | |
tree | b288c905eaba75dc6b43a8bcebc82720c16e4816 /csrc/ops.cu | |
parent | 47a73d94c3d3284f6073b0ff189ed5bc9e3a8762 (diff) | |
parent | dc8c9efdb33130f960adc864916b67d0cb744dbb (diff) |
Merge branch 'patch_merge' into extract_outliers
Diffstat (limited to 'csrc/ops.cu')
-rw-r--r-- | csrc/ops.cu | 22 |
1 files changed, 21 insertions, 1 deletions
diff --git a/csrc/ops.cu b/csrc/ops.cu index e6227ae..952894c 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -247,6 +247,8 @@ int roundoff(int v, int d) { } +#ifdef NO_CUBLASLT +#else template<int ORDER> cublasLtOrder_t get_order() { switch(ORDER) @@ -266,7 +268,11 @@ template<int ORDER> cublasLtOrder_t get_order() case COL_AMPERE: return CUBLASLT_ORDER_COL32_2R_4R4; break; + default: + break; } + + return CUBLASLT_ORDER_ROW; } template cublasLtOrder_t get_order<ROW>(); @@ -274,6 +280,7 @@ template cublasLtOrder_t get_order<COL>(); template cublasLtOrder_t get_order<COL32>(); template cublasLtOrder_t get_order<COL_TURING>(); template cublasLtOrder_t get_order<COL_AMPERE>(); +#endif template<int ORDER> int get_leading_dim(int dim1, int dim2) @@ -297,6 +304,9 @@ template<int ORDER> int get_leading_dim(int dim1, int dim2) // 32*32 tiles return 32*roundoff(dim1, 32); break; + default: + return 0; + break; } } @@ -306,7 +316,8 @@ template int get_leading_dim<COL32>(int dim1, int dim2); template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2) { - +#ifdef NO_CUBLASLT +#else cublasLtOrder_t orderA = get_order<SRC>(); cublasLtOrder_t orderOut = get_order<TARGET>(); int ldA = get_leading_dim<SRC>(dim1, dim2); @@ -345,6 +356,7 @@ template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void trans if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc)); if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc)); if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc)); +#endif } template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); @@ -358,6 +370,9 @@ template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandl template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { +#ifdef NO_CUBLASLT + return 0; +#else int has_error = 0; cublasLtMatmulDesc_t matmulDesc = NULL; cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; @@ -412,6 +427,7 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle printf("error detected"); return has_error; +#endif } int fill_up_to_nearest_multiple(int value, int multiple) @@ -523,6 +539,9 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) { +#ifdef NO_CUBLASLT +#else + cusparseSpMatDescr_t descA; cusparseDnMatDescr_t descB, descC; @@ -569,6 +588,7 @@ void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_val CHECK_CUSPARSE( cusparseDestroyDnMat(descB) ); CHECK_CUSPARSE( cusparseDestroyDnMat(descC) ); CUDA_CHECK_RETURN( cudaFree(dBuffer) ); +#endif } template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) |