From c771b3a75a6ebbfbfc398a028a477246b0799cf0 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Fri, 22 Jul 2022 14:41:05 -0700 Subject: Most tests passing. --- csrc/ops.cuh | 104 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) (limited to 'csrc/ops.cuh') diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 8fb4cec..4e719df 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -14,6 +14,11 @@ #include #include +#include +#include +#include +#include +#include #define CUDA_CHECK_RETURN(value) { \ cudaError_t _m_cudaStat = value; \ @@ -25,6 +30,34 @@ #define THREADS_PER_BLOCKS (512) +#define CHECK_CUSPARSE(value) { \ + cusparseStatus_t _m_cudaStat = value; \ + if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) { \ + fprintf(stderr, "Error %s at line %d in file %s\n", \ + cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ + exit(1); \ + } } + + +#define THREADS_PER_BLOCKS (512) + + +inline void checkCudaStatus(cudaError_t status) { + if (status != cudaSuccess) { + printf("cuda API failed with status %d: %s\n", status, cudaGetErrorString(status)); + throw std::logic_error("cuda API failed"); + } +} + +inline int checkCublasStatus(cublasStatus_t status) { + if (status != CUBLAS_STATUS_SUCCESS) { + printf("cuBLAS API failed with status %d\n", status); + //throw std::logic_error("cuBLAS API failed"); + return 1; + } + return 0; +} + typedef enum Operations_t { ksmul = 0, @@ -39,6 +72,57 @@ typedef enum Optimizer_t ADAGRAD = 4, } Optimizer_t; +typedef enum Transform_t +{ + ROW = 0, + COL = 1, + COL32 = 2, + COL_TURING = 3, + COL_AMPERE = 4, +} Transform_t; + +class Context +{ + public: + cublasHandle_t m_handle; + + Context() + { + cublasHandle_t handle; + cublasCreate_v2(&handle); + m_handle = handle; + } + +}; + +class ContextLt +{ + public: + cublasLtHandle_t m_handle; + + ContextLt() + { + cublasLtHandle_t handle; + cublasLtCreate(&handle); + m_handle = handle; + } + +}; + +class ContextCusparse +{ + public: + cusparseHandle_t m_handle; + + ContextCusparse() + { + cusparseHandle_t handle; + cusparseCreate(&handle); + m_handle = handle; + } + +}; + template void estimateQuantiles(T *A, float *code, float offset, int n); @@ -70,4 +154,24 @@ template void percentileClipping(T * g, float *gnorm_vec, int step, void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n); +void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); +void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, + long long int strideA, long long int strideB, long long int strideC, int batchCount); + + +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); + +template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2); +void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols); +void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols); +void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, + int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols); + +template void transformRowToFormat(char * A, char *out, int rows, int cols); + +void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B); + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); + #endif -- cgit v1.2.3