diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-07-26 19:38:01 -0700 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-07-26 19:38:01 -0700 |
commit | 5737f2b027a1e0ec8540a3aa914632d44ad9c62d (patch) | |
tree | b288c905eaba75dc6b43a8bcebc82720c16e4816 /csrc | |
parent | 47a73d94c3d3284f6073b0ff189ed5bc9e3a8762 (diff) | |
parent | dc8c9efdb33130f960adc864916b67d0cb744dbb (diff) |
Merge branch 'patch_merge' into extract_outliers
Diffstat (limited to 'csrc')
-rw-r--r-- | csrc/kernels.cu | 9 | ||||
-rw-r--r-- | csrc/kernels.cuh | 2 | ||||
-rw-r--r-- | csrc/ops.cu | 22 | ||||
-rw-r--r-- | csrc/pythonInterface.c | 6 |
4 files changed, 29 insertions, 10 deletions
diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 79ad5de..d4eb56c 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1768,7 +1768,6 @@ template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_ __shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS]; __shared__ int smem_row_nnz_values[TILE_ROWS]; - //__shared__ float smem_col_absmax_values[ITEMS_PER_THREAD*THREADS]; half local_data[ITEMS_PER_THREAD]; float local_data_fp32[ITEMS_PER_THREAD]; @@ -1828,13 +1827,14 @@ template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_ local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j])); // 3. compute row max (per block); store in smem to accumulate full global mem transation - __syncthreads(); // this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units) #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) local_data_fp32[j] = local_data[j]; + __syncthreads(); + row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, cub::Max()); if(SPARSE_DECOMP) { @@ -2166,7 +2166,6 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T __shared__ char smem_data[32*33*ITEMS_PER_THREAD]; char local_data[ITEMS_PER_THREAD]; typedef cub::BlockExchange<char, THREADS, ITEMS_PER_THREAD> BlockExchange; - __shared__ typename BlockExchange::TempStorage temp_storage; // we load row after row from the base_position // Load data row by row @@ -2446,7 +2445,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T #define MAX_SPARSE_COUNT 32 #define SMEM_SIZE 8*256 template <typename T, int SPMM_ITEMS, int BITS> -__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB) +__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB) { // 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block @@ -2500,7 +2499,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o { for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x) if((idx_col_B+i-local_idx_col_B_offset) < colsB) - smem_dequant_stats[i] = __ldg(&dequant_stats[idx_col_B+i-local_idx_col_B_offset]); + smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset]; __syncthreads(); } diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index eda5ba0..2447494 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -107,7 +107,7 @@ template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileCl __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n); -template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16( int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, diff --git a/csrc/ops.cu b/csrc/ops.cu index e6227ae..952894c 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -247,6 +247,8 @@ int roundoff(int v, int d) { } +#ifdef NO_CUBLASLT +#else template<int ORDER> cublasLtOrder_t get_order() { switch(ORDER) @@ -266,7 +268,11 @@ template<int ORDER> cublasLtOrder_t get_order() case COL_AMPERE: return CUBLASLT_ORDER_COL32_2R_4R4; break; + default: + break; } + + return CUBLASLT_ORDER_ROW; } template cublasLtOrder_t get_order<ROW>(); @@ -274,6 +280,7 @@ template cublasLtOrder_t get_order<COL>(); template cublasLtOrder_t get_order<COL32>(); template cublasLtOrder_t get_order<COL_TURING>(); template cublasLtOrder_t get_order<COL_AMPERE>(); +#endif template<int ORDER> int get_leading_dim(int dim1, int dim2) @@ -297,6 +304,9 @@ template<int ORDER> int get_leading_dim(int dim1, int dim2) // 32*32 tiles return 32*roundoff(dim1, 32); break; + default: + return 0; + break; } } @@ -306,7 +316,8 @@ template int get_leading_dim<COL32>(int dim1, int dim2); template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2) { - +#ifdef NO_CUBLASLT +#else cublasLtOrder_t orderA = get_order<SRC>(); cublasLtOrder_t orderOut = get_order<TARGET>(); int ldA = get_leading_dim<SRC>(dim1, dim2); @@ -345,6 +356,7 @@ template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void trans if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc)); if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc)); if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc)); +#endif } template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); @@ -358,6 +370,9 @@ template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandl template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { +#ifdef NO_CUBLASLT + return 0; +#else int has_error = 0; cublasLtMatmulDesc_t matmulDesc = NULL; cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; @@ -412,6 +427,7 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle printf("error detected"); return has_error; +#endif } int fill_up_to_nearest_multiple(int value, int multiple) @@ -523,6 +539,9 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) { +#ifdef NO_CUBLASLT +#else + cusparseSpMatDescr_t descA; cusparseDnMatDescr_t descB, descC; @@ -569,6 +588,7 @@ void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_val CHECK_CUSPARSE( cusparseDestroyDnMat(descB) ); CHECK_CUSPARSE( cusparseDestroyDnMat(descC) ); CUDA_CHECK_RETURN( cudaFree(dBuffer) ); +#endif } template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 2ecbaae..7356c11 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -82,7 +82,6 @@ void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, un void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \ void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); } -#endif #define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \ @@ -132,10 +131,11 @@ void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_r void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) { spmm_coo_very_sparse_naive<signed char, 8>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } +#endif extern "C" { - #if BUILD_CUDA +#if BUILD_CUDA void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); } void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); } void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); } @@ -231,7 +231,7 @@ extern "C" { return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } int cigemmlt_ampere_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + { return igemmlt_ampere_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } #define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \ |