summaryrefslogtreecommitdiff
path: root/csrc/ops.cuh
diff options
context:
space:
mode:
author: Tim Dettmers <tim.dettmers@gmail.com> 2022-07-22 14:41:05 -0700
committer: Tim Dettmers <tim.dettmers@gmail.com> 2022-07-22 14:41:05 -0700
commit: c771b3a75a6ebbfbfc398a028a477246b0799cf0 (patch)
tree: 158353d531766ed133be34d3c5085da6e8a4d01e /csrc/ops.cuh
parent: 4cd7ea62b2f51c68aacde2f62e7141765e476111 (diff)
Most tests passing.
Diffstat (limited to 'csrc/ops.cuh')
-rw-r--r-- csrc/ops.cuh 104
1 files changed, 104 insertions, 0 deletions
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index 8fb4cec..4e719df 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -14,6 +14,11 @@
#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
+#include <cublas_v2.h>
+#include <cublasLt.h>
+#include <cusparse.h>
+#include <vector>
+#include <functional>
+#include <stdexcept>
#define CUDA_CHECK_RETURN(value) { \
cudaError_t _m_cudaStat = value; \
@@ -25,6 +30,34 @@
#define THREADS_PER_BLOCKS (512)
+// Evaluates a cuSPARSE call once and terminates the process if it failed.
+// On any status other than CUSPARSE_STATUS_SUCCESS the cuSPARSE error string
+// and the call site (line, file) are written to stderr, then exit(1) is called.
+#define CHECK_CUSPARSE(value) {                                               \
+    cusparseStatus_t _m_cusparseStat = (value);                               \
+    if (_m_cusparseStat != CUSPARSE_STATUS_SUCCESS) {                         \
+        fprintf(stderr, "Error %s at line %d in file %s\n",                   \
+                cusparseGetErrorString(_m_cusparseStat), __LINE__, __FILE__); \
+        exit(1);                                                              \
+    } }
+
+
+// THREADS_PER_BLOCKS is already defined a few lines above in this header; the
+// guard keeps this repeated definition from ever conflicting with that one.
+#ifndef THREADS_PER_BLOCKS
+#define THREADS_PER_BLOCKS (512)
+#endif
+
+
+// Throws if a CUDA runtime call did not succeed.
+// When status is cudaSuccess the function returns without any output.
+// Otherwise the numeric status and its cudaGetErrorString() text are written
+// to stderr and std::logic_error("cuda API failed") is thrown.
+inline void checkCudaStatus(cudaError_t status) {
+    if (status != cudaSuccess) {
+        // stderr rather than stdout: the message must not sit in a stdout buffer
+        // if the exception below ends up terminating the process.
+        fprintf(stderr, "cuda API failed with status %d: %s\n", (int)status, cudaGetErrorString(status));
+        throw std::logic_error("cuda API failed");
+    }
+}
+
+// Reports a failed cuBLAS call without throwing.
+// Returns 0 when status is CUBLAS_STATUS_SUCCESS; otherwise prints the numeric
+// status to stdout and returns 1.
+inline int checkCublasStatus(cublasStatus_t status)
+{
+    if (status == CUBLAS_STATUS_SUCCESS)
+        return 0;
+
+    // Failure is signalled through the return value; the throwing variant stays disabled.
+    printf("cuBLAS API failed with status %d\n", status);
+    //throw std::logic_error("cuBLAS API failed");
+    return 1;
+}
+
typedef enum Operations_t
{
ksmul = 0,
@@ -39,6 +72,57 @@ typedef enum Optimizer_t
ADAGRAD = 4,
} Optimizer_t;
+// Integer identifiers for matrix layouts / transform formats.
+// NOTE(review): presumably these are the values passed as the SRC/TARGET/FORMAT
+// template arguments of the transform functions declared in this header -- confirm.
+typedef enum Transform_t
+{
+    ROW        = 0,
+    COL        = 1,
+    COL32      = 2,
+    COL_TURING = 3,
+    COL_AMPERE = 4,
+} Transform_t;
+
+// Holds a cuBLAS handle that is created when the object is constructed.
+// The handle is exposed through the public m_handle member. No destructor is
+// declared, so this class never destroys the handle itself.
+class Context
+{
+    public:
+        // Handle produced by cublasCreate_v2; nullptr if creation failed.
+        cublasHandle_t m_handle;
+
+        Context() : m_handle(nullptr)
+        {
+            // Check the creation status so a failure is reported (via
+            // checkCublasStatus) and never leaves an indeterminate handle stored.
+            if (checkCublasStatus(cublasCreate_v2(&m_handle)) != 0)
+                m_handle = nullptr;
+        }
+
+};
+
+// Holds a cuBLASLt handle that is created when the object is constructed.
+// The handle is exposed through the public m_handle member. No destructor is
+// declared, so this class never destroys the handle itself.
+class ContextLt
+{
+    public:
+        // Handle produced by cublasLtCreate; nullptr if creation failed.
+        cublasLtHandle_t m_handle;
+
+        ContextLt() : m_handle(nullptr)
+        {
+            // Check the creation status so a failure is reported (via
+            // checkCublasStatus) and never leaves an indeterminate handle stored.
+            if (checkCublasStatus(cublasLtCreate(&m_handle)) != 0)
+                m_handle = nullptr;
+        }
+
+};
+
+// Holds a cuSPARSE handle that is created when the object is constructed.
+// The handle is exposed through the public m_handle member. No destructor is
+// declared, so this class never destroys the handle itself.
+class ContextCusparse
+{
+    public:
+        // Handle produced by cusparseCreate; nullptr if creation failed.
+        cusparseHandle_t m_handle;
+
+        ContextCusparse() : m_handle(nullptr)
+        {
+            cusparseStatus_t status = cusparseCreate(&m_handle);
+            if (status != CUSPARSE_STATUS_SUCCESS)
+            {
+                // Report the failure instead of silently storing an unusable handle;
+                // the process is left running so the caller can decide what to do.
+                fprintf(stderr, "cuSPARSE API failed with status %d: %s\n", (int)status, cusparseGetErrorString(status));
+                m_handle = nullptr;
+            }
+        }
+
+};
+
template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n);
@@ -70,4 +154,24 @@ template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step,
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n);
+void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
+void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
+ long long int strideA, long long int strideB, long long int strideC, int batchCount);
+// The functions in this group are declared only; their definitions are not part of this header section.
+// igemmlt: int8 inputs A and B, untyped output C; returns int -- NOTE(review): presumably a nonzero-on-error flag like checkCublasStatus; confirm.
+template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+// transform: SRC/TARGET are compile-time ints -- NOTE(review): presumably Transform_t values; confirm against the definition.
+template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2);
+void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
+void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols);
+void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols);
+void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed,
+ int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols);
+// transformRowToFormat: char input and output buffers; FORMAT and TRANSPOSE are compile-time ints.
+template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols);
+// spmm_coo: A is supplied as separate row-index, column-index and value arrays with A_nnz entries, using a caller-provided cuSPARSE handle.
+void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B);
+// spmm_coo_very_sparse_naive: templated on the element type T of B and on BITS; the out buffer is half precision.
+template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
+
#endif