path: root/csrc/ops.cu
author    Tim Dettmers <tim.dettmers@gmail.com>    2022-07-22 14:41:05 -0700
committer Tim Dettmers <tim.dettmers@gmail.com>    2022-07-22 14:41:05 -0700
commit    c771b3a75a6ebbfbfc398a028a477246b0799cf0 (patch)
tree      158353d531766ed133be34d3c5085da6e8a4d01e /csrc/ops.cu
parent    4cd7ea62b2f51c68aacde2f62e7141765e476111 (diff)
Most tests passing.
Diffstat (limited to 'csrc/ops.cu')
-rw-r--r--    csrc/ops.cu    406
1 file changed, 406 insertions(+), 0 deletions(-)
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 40c185c..8946015 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -8,6 +8,7 @@
#include <cub/device/device_scan.cuh>
#include <limits>
#include <BinSearch.h>
+#include <cassert>
#include <common.h>
@@ -188,11 +189,416 @@ template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step,
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
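+// 8-bit integer GEMM via cublasGemmEx: int8 inputs, int32 accumulation and
+// output. alpha/beta are int32 because the compute type is CUDA_R_32I.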
+void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
+{
+ const int falpha = 1;
+ const int fbeta = 0;
+ const void * alpha = &falpha;
+ const void * beta = &fbeta;
+ cublasStatus_t status;
+
+ status = cublasGemmEx(context->m_handle,
+ transposeA ? CUBLAS_OP_T : CUBLAS_OP_N,
+ transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
+ m, n, k,
+ alpha, A, CUDA_R_8I, lda, B, CUDA_R_8I, ldb, beta,
+ C, CUDA_R_32I, ldc,
+ CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+
+ if (status != CUBLAS_STATUS_SUCCESS)
+ {
+ std::cout << "CUBLAS ERROR: Status " << status << std::endl;
+ }
+
+}
+
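+// Batched variant of gemmex: runs batchCount int8 GEMMs with fixed strides
+// between consecutive A, B, and C matrices.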
+void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
+ long long int strideA, long long int strideB, long long int strideC, int batchCount)
+{
+ const int falpha = 1;
+ const int fbeta = 0;
+ const void * alpha = &falpha;
+ const void * beta = &fbeta;
+ cublasStatus_t status;
+
+ status = cublasGemmStridedBatchedEx(context->m_handle,
+ transposeA ? CUBLAS_OP_T : CUBLAS_OP_N,
+ transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
+ m, n, k,
+ alpha, A, CUDA_R_8I, lda, (long long int)strideA, B, CUDA_R_8I, ldb, (long long int)strideB, beta,
+ C, CUDA_R_32I, ldc, (long long int)strideC, batchCount,
+ CUDA_R_32I, CUBLAS_GEMM_DEFAULT);
+
+ if (status != CUBLAS_STATUS_SUCCESS)
+ {
+ std::cout << "CUBLAS ERROR: Status " << status << std::endl;
+ }
+
+}
+
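+// round v up to the nearest multiple of d, e.g. roundoff(10, 8) == 16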
+int roundoff(int v, int d) {
+ return (v + d - 1) / d * d;
+}
+
+
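+// Map the library's layout constants to cublasLt memory-order enums;
+// COL4_4R2_8C and COL32_2R_4R4 are the tiled layouts used by the int8
+// tensor-core kernels on Turing and Ampere, respectively.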
+template<int ORDER> cublasLtOrder_t get_order()
+{
+  switch(ORDER)
+  {
+    case ROW:
+      return CUBLASLT_ORDER_ROW;
+    case COL:
+      return CUBLASLT_ORDER_COL;
+    case COL32:
+      return CUBLASLT_ORDER_COL32;
+    case COL_TURING:
+      return CUBLASLT_ORDER_COL4_4R2_8C;
+    case COL_AMPERE:
+      return CUBLASLT_ORDER_COL32_2R_4R4;
+    default:
+      // unreachable for the instantiations below; avoids falling off the end
+      // of a non-void function
+      assert(false);
+      return CUBLASLT_ORDER_ROW;
+  }
+}
+
+template cublasLtOrder_t get_order<ROW>();
+template cublasLtOrder_t get_order<COL>();
+template cublasLtOrder_t get_order<COL32>();
+template cublasLtOrder_t get_order<COL_TURING>();
+template cublasLtOrder_t get_order<COL_AMPERE>();
+
+
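+// Leading dimension of a dim1 x dim2 matrix in the given layout; the tiled
+// layouts pad the row count to their tile height before multiplying by the
+// 32-column tile width.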
+template<int ORDER> int get_leading_dim(int dim1, int dim2)
+{
+  switch(ORDER)
+  {
+    case ROW:
+      return dim2;
+    case COL:
+      return dim1;
+    case COL32:
+      // 32 columns per tile; ld spans all rows
+      return dim1*32;
+    case COL_TURING:
+      // rows padded to a multiple of 8
+      return 32*roundoff(dim1, 8);
+    case COL_AMPERE:
+      // rows padded to a multiple of 32
+      return 32*roundoff(dim1, 32);
+    default:
+      assert(false);
+      return 0;
+  }
+}
+
+template int get_leading_dim<ROW>(int dim1, int dim2);
+template int get_leading_dim<COL>(int dim1, int dim2);
+template int get_leading_dim<COL32>(int dim1, int dim2);
+
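+// Change the memory layout of A (optionally transposing it) with
+// cublasLtMatrixTransform; DTYPE selects int8 (8) or int32 (32) elements.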
+template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
+{
+
+ cublasLtOrder_t orderA = get_order<SRC>();
+ cublasLtOrder_t orderOut = get_order<TARGET>();
+ int ldA = get_leading_dim<SRC>(dim1, dim2);
+ int ldOut = get_leading_dim<TARGET>(dim1, dim2);
+
+ cublasLtMatrixLayout_t A_desc = NULL, out_desc = NULL;
+ cublasLtMatrixTransformDesc_t A2Out_desc = NULL;
+ cublasOperation_t opTranspose = CUBLAS_OP_T;
+ float transformAlpha = 1.0f, transformBeta = 0.0f;
+
+
+ if(DTYPE == 8)
+ {
+ checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_8I, dim1, dim2, ldA));
+ checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_8I, dim1, dim2, ldOut));
+ }
+ else if(DTYPE == 32)
+ {
+ checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_32I, dim1, dim2, ldA));
+ checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_32I, dim1, dim2, ldOut));
+ }
+  else
+  {
+    printf("ERROR WRONG TYPE FOR TRANSFORM: %i\n", DTYPE);
+    return; // bail out instead of using the uninitialized descriptors
+  }
+
+ checkCublasStatus(cublasLtMatrixLayoutSetAttribute(A_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA)));
+ checkCublasStatus(cublasLtMatrixLayoutSetAttribute(out_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut)));
+
+ checkCublasStatus(cublasLtMatrixTransformDescCreate(&A2Out_desc, CUDA_R_32F));
+
+ if(transpose){ checkCublasStatus(cublasLtMatrixTransformDescSetAttribute(A2Out_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); }
+
+ checkCublasStatus(cublasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, NULL, NULL, out, out_desc, 0));
+
+ if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc));
+ if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc));
+ if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc));
+}
+
+template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int8_t, ROW, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int8_t, ROW, COL32, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int32_t, ROW, COL32, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
+template void transform<int8_t, ROW, COL_TURING, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int8_t, ROW, COL_AMPERE, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int8_t, COL32, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
+
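+// Int8 matmul via cublasLtMatmul: A in COL32, B in the tiled format given by
+// FORMATB, C written in COL32 as int32 (DTYPE_OUT == 32) or int8, the latter
+// optionally scaled by a per-row factor from row_scale (SCALE_ROWS).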
+template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+{
+ int has_error = 0;
+ cublasLtMatmulDesc_t matmulDesc = NULL;
+ cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
+ cublasOperation_t opT = CUBLAS_OP_T;
+ cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO;
+ cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
+ cublasLtOrder_t col_turing = CUBLASLT_ORDER_COL4_4R2_8C;
+ cublasLtOrder_t col_ampere = CUBLASLT_ORDER_COL32_2R_4R4;
+
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, m, k, lda));
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, n, k, ldb));
+
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
+ if(FORMATB == COL_TURING)
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing)));
+ else
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere)));
+
+ if(DTYPE_OUT == 32)
+ {
+ has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32I));
+ has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc));
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
+ int alpha = 1, beta = 0;
+ has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, NULL, NULL, 0, 0));
+ }
+ else
+ {
+ has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32F));
+ has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_8I, m, n, ldc));
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
+ if(!SCALE_ROWS)
+ {
+ float alpha = 1.0f, beta = 0.0f;
+ has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
+ }
+ else
+ {
+ has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec)));
+ has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, NULL, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
+ }
+ }
+
+
+ if (Cdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Cdesc));
+ if (Bdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Bdesc));
+ if (Adesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Adesc));
+ if (matmulDesc) has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc));
+  if(has_error)
+    printf("error detected\n");
+
+ return has_error;
+}
+
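+// round value up to the nearest multiple, e.g.
+// fill_up_to_nearest_multiple(10, 8) == 16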
+int fill_up_to_nearest_multiple(int value, int multiple)
+{
+ return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple)));
+}
+
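+// Dequantize the int32 COL32 matmul output back to fp16 using the row and
+// column absmax statistics; each block covers a 128-row x 32-column subtile.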
+void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols)
+{
+ int threads = 512;
+ int tileCols = fill_up_to_nearest_multiple(numCols, 32);
+ int n = numRows*tileCols;
+ int subtile_rows = 128;
+ int tilesize = 32*subtile_rows;
+ int num_blocks = numRows/subtile_rows;
+ num_blocks += (numRows % subtile_rows == 0) ? 0 : 1;
+ num_blocks = num_blocks*(tileCols/32);
+ assert(threads <= tilesize);
+
+ kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
+
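+// Compute per-row and per-column absmax statistics of A; with a nonzero
+// threshold the kernel also counts, per row, the entries whose magnitude
+// exceeds it (nnz_count_row).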
+#define STATS_THREADS 64
+#define STATS_ITEMS 4
+#define STATS_ROWS 16
+void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
+{
+ int tile_cols = STATS_THREADS*STATS_ITEMS;
+ int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
+ int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
+ int num_blocks = (tiledCols/tile_cols) * (tiledRows/STATS_ROWS);
+
+  if(nnz_threshold == 0.0)
+    kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
+  else
+    kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 1><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+
+}
+
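+// Quantize A to int8 twice, once with row-wise and once with column-wise
+// scaling; with threshold > 0, outliers are routed into the COO arrays
+// (rowidx, colidx, val) instead of being quantized.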
+void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols)
+{
+ int threads = 64;
+ int items_per_thread = 4;
+ int tile_cols = threads*items_per_thread;
+ int tile_rows = 16;
+ int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
+ int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
+ int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
+
+ if(threshold > 0.0f)
+ kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
+ else
+ kDoubleRowColQuant<64, 4, 16, 64*4, 0><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
+
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
+
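+// Rearrange a row-major int8 matrix into COL32 or the Turing/Ampere tiled
+// formats; TRANSPOSE additionally transposes the input while rewriting.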
+template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols)
+{
+ int threads = 256;
+ int items_per_thread = 8;
+  // each warp covers 32 threads * 8 items = 256 column values
+  int tile_cols = 32*items_per_thread;
+ int tile_rows = 32;
+ int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
+ int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
+ int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
+ int outCols = fill_up_to_nearest_multiple(cols, 32);
+ int outRows = fill_up_to_nearest_multiple(rows, 32);
+ if(FORMAT == COL_TURING)
+ {
+ if(TRANSPOSE)
+ outRows = fill_up_to_nearest_multiple(cols, 8);
+ else
+ outRows = fill_up_to_nearest_multiple(rows, 8);
+ }
+ else if(FORMAT == COL_AMPERE)
+ {
+ if(TRANSPOSE)
+ outRows = fill_up_to_nearest_multiple(cols, 32);
+ else
+ outRows = fill_up_to_nearest_multiple(rows, 32);
+ }
+ else
+ {
+ if(TRANSPOSE)
+ {
+ outCols = fill_up_to_nearest_multiple(rows, 32);
+ outRows = cols;
+ }
+ }
+
+ kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
+
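+// Sparse (COO) x dense fp16 matmul through cuSPARSE, including allocation of
+// the scratch buffer cusparseSpMM requires.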
+void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
+{
+
+ cusparseSpMatDescr_t descA;
+ cusparseDnMatDescr_t descB, descC;
+
+ float alpha = 1.0f;
+ float beta = 0.0f;
+ void *dBuffer = NULL;
+ size_t bufferSize = 0;
+
+ CHECK_CUSPARSE( cusparseCreateCoo(&descA, A_rows, A_cols, A_nnz,
+ A_rowidx, A_colidx, A_vals,
+ CUSPARSE_INDEX_32I,
+ CUSPARSE_INDEX_BASE_ZERO, CUDA_R_16F) );
+ // Create dense matrix C
+ CHECK_CUSPARSE( cusparseCreateDnMat(&descC, A_rows, B_cols, ldc, C,
+ CUDA_R_16F, CUSPARSE_ORDER_ROW) );
+  // Create dense matrix B; when B is transposed, swap the dimension counts so
+  // the descriptor matches B's storage shape (the transpose op is applied in
+  // cusparseSpMM below)
+  if(transposed_B)
+  {
+    int tmp = A_cols;
+    A_cols = B_cols;
+    B_cols = tmp;
+  }
+
+ CHECK_CUSPARSE( cusparseCreateDnMat(&descB, A_cols, B_cols, ldb, B,
+ CUDA_R_16F, CUSPARSE_ORDER_ROW) );
+ // allocate an external buffer if needed
+ CHECK_CUSPARSE( cusparseSpMM_bufferSize(
+ handle,
+ CUSPARSE_OPERATION_NON_TRANSPOSE,
+ transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE,
+ &alpha, descA, descB, &beta, descC, CUDA_R_32F,
+ CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize) );
+ CUDA_CHECK_RETURN( cudaMalloc(&dBuffer, bufferSize) );
+
+ // execute SpMM
+ CHECK_CUSPARSE( cusparseSpMM(handle,
+ CUSPARSE_OPERATION_NON_TRANSPOSE,
+ transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE,
+ &alpha, descA, descB, &beta, descC, CUDA_R_32F,
+ CUSPARSE_SPMM_ALG_DEFAULT, dBuffer));
+
+ // destroy matrix/vector descriptors
+ CHECK_CUSPARSE( cusparseDestroySpMat(descA) );
+ CHECK_CUSPARSE( cusparseDestroyDnMat(descB) );
+ CHECK_CUSPARSE( cusparseDestroyDnMat(descC) );
+ CUDA_CHECK_RETURN( cudaFree(dBuffer) );
+}
+
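+// Naive SpMM for extremely sparse COO inputs: one block per nonzero row, with
+// B either fp16 (BITS == 16) or int8 dequantized through dequant_stats.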
+template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
+{
+
+ kspmm_coo_very_sparse_naive<T, 8, BITS><<<nnz_rows, 256>>>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
+template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
+template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
+
+template int igemmlt<COL_TURING, 32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+template int igemmlt<COL_TURING, 8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+template int igemmlt<COL_TURING, 8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+template int igemmlt<COL_AMPERE, 32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+template int igemmlt<COL_AMPERE, 8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+template int igemmlt<COL_AMPERE, 8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+
+template void transformRowToFormat<COL32, 0>(char * A, char *out, int rows, int cols);
+template void transformRowToFormat<COL32, 1>(char * A, char *out, int rows, int cols);
+template void transformRowToFormat<COL_TURING, 0>(char * A, char *out, int rows, int cols);
+template void transformRowToFormat<COL_TURING, 1>(char * A, char *out, int rows, int cols);
+template void transformRowToFormat<COL_AMPERE, 0>(char * A, char *out, int rows, int cols);
+template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows, int cols);
+
template void estimateQuantiles(half *A, float *code, float offset, int n);
template void estimateQuantiles(float *A, float *code, float offset, int n);