summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-07-26 19:38:01 -0700
committerTim Dettmers <tim.dettmers@gmail.com>2022-07-26 19:38:01 -0700
commit5737f2b027a1e0ec8540a3aa914632d44ad9c62d (patch)
treeb288c905eaba75dc6b43a8bcebc82720c16e4816
parent47a73d94c3d3284f6073b0ff189ed5bc9e3a8762 (diff)
parentdc8c9efdb33130f960adc864916b67d0cb744dbb (diff)
Merge branch 'patch_merge' into extract_outliers
-rw-r--r--CHANGELOG.md14
-rw-r--r--Makefile37
-rw-r--r--bitsandbytes/functional.py13
-rw-r--r--csrc/kernels.cu9
-rw-r--r--csrc/kernels.cuh2
-rw-r--r--csrc/ops.cu22
-rw-r--r--csrc/pythonInterface.c6
-rw-r--r--cuda_install.sh77
-rw-r--r--deploy_from_slurm.sh267
-rw-r--r--quicktest.py90
-rw-r--r--setup.py7
-rw-r--r--tests/test_functional.py3
12 files changed, 469 insertions, 78 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index fa20b15..285984e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -53,3 +53,17 @@ Bug fixes:
Docs:
- Added instructions how to solve "\_\_fatbinwrap_" errors.
+
+
+### 0.30.0
+
+#### 8-bit Inference Update
+
+Features:
+ - Added 8-bit matrix multiplication form cuBLAS, and cuBLASLt as well as multiple GEMM kernels (GEMM, GEMMEx, GEMMLt)
+ - Added 8-bit Linear layers with 8-bit Params that perform memory efficient inference with an option for 8-bit mixed precision matrix decomposition for inference without performance degradation
+ - Added quantization methods for "fake" quantization as well as optimized kernels vector-wise quantization and equalization as well as optimized cuBLASLt transformations
+ - CPU only build now available (Thank you, @mryab)
+
+Deprecated:
+ - Pre-compiled release for CUDA 9.2, 10.0, 10.2 no longer available
diff --git a/Makefile b/Makefile
index b58e233..10f267a 100644
--- a/Makefile
+++ b/Makefile
@@ -16,7 +16,7 @@ FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu
FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c
INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include
-LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcuda -lcublas -lcurand -lcusparse -L $(CONDA_PREFIX)/lib
+LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcurand -lcusparse -L $(CONDA_PREFIX)/lib
# NVIDIA NVCC compilation flags
COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler
@@ -27,7 +27,6 @@ COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
-COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
# CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
CC_CUDA92 := -gencode arch=compute_30,code=sm_30
@@ -43,31 +42,49 @@ CC_CUDA11x := -gencode arch=compute_75,code=sm_75
CC_CUDA11x += -gencode arch=compute_80,code=sm_80
CC_CUDA11x += -gencode arch=compute_86,code=sm_86
+CC_cublasLt110 := -gencode arch=compute_75,code=sm_75
+CC_cublasLt110 += -gencode arch=compute_80,code=sm_80
+
+CC_cublasLt111 := -gencode arch=compute_75,code=sm_75
+CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
+CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
+
+
all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
- $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+ $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
- $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+ $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
-cuda10x: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
- $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+cuda10x_nomatmul: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
+ $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
-cuda110: $(BUILD_DIR) env
- $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+cuda110_nomatmul: $(BUILD_DIR) env
+ $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
-cuda11x: $(BUILD_DIR) env
- $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+cuda11x_nomatmul: $(BUILD_DIR) env
+ $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
+cuda110: $(BUILD_DIR) env
+ $(NVCC) $(CC_cublasLt110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+ $(NVCC) $(CC_cublasLt110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
+ $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
+
+cuda11x: $(BUILD_DIR) env
+ $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+ $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
+ $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
+
cpuonly: $(BUILD_DIR) env
$(GPP) -std=c++14 -shared -fPIC -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index a9233e2..ac85f88 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -897,7 +897,7 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr
ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
return out
-def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32):
+def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
shapeA = SA[0]
shapeB = SB[0]
dimsA = len(shapeA)
@@ -917,7 +917,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
elif dimsA == 3 and out is None:
out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, 'col32', 'row')
- if row_scale is not None: assert row_scale.numel() == out.shape[0]
assert dimsB != 3, 'len(B.shape)==3 not supported'
assert A.device.type == 'cuda'
assert B.device.type == 'cuda'
@@ -936,7 +935,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
ptrA = get_ptr(A)
ptrB = get_ptr(B)
ptrC = get_ptr(out)
- ptrRowScale = get_ptr(row_scale)
k = shapeA[-1]
lda = ct.c_int32(m*32)
@@ -955,20 +953,17 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32
k = ct.c_int32(k)
has_error = 0
+ ptrRowScale = get_ptr(None)
if formatB == 'col_turing':
if dtype == torch.int32:
has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
- elif row_scale is None:
- has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
else:
- has_error = lib.cigemmlt_turing_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+ has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
elif formatB == 'col_ampere':
if dtype == torch.int32:
has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
- elif row_scale is None:
- has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
else:
- has_error = lib.cigemmlt_ampere_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+ has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
if has_error == 1:
raise Exception('cublasLt ran into an error!')
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 79ad5de..d4eb56c 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -1768,7 +1768,6 @@ template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_
__shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS];
__shared__ int smem_row_nnz_values[TILE_ROWS];
- //__shared__ float smem_col_absmax_values[ITEMS_PER_THREAD*THREADS];
half local_data[ITEMS_PER_THREAD];
float local_data_fp32[ITEMS_PER_THREAD];
@@ -1828,13 +1827,14 @@ template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_
local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j]));
// 3. compute row max (per block); store in smem to accumulate full global mem transation
- __syncthreads();
// this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units)
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
local_data_fp32[j] = local_data[j];
+ __syncthreads();
+
row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, cub::Max());
if(SPARSE_DECOMP)
{
@@ -2166,7 +2166,6 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
__shared__ char smem_data[32*33*ITEMS_PER_THREAD];
char local_data[ITEMS_PER_THREAD];
typedef cub::BlockExchange<char, THREADS, ITEMS_PER_THREAD> BlockExchange;
- __shared__ typename BlockExchange::TempStorage temp_storage;
// we load row after row from the base_position
// Load data row by row
@@ -2446,7 +2445,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
#define MAX_SPARSE_COUNT 32
#define SMEM_SIZE 8*256
template <typename T, int SPMM_ITEMS, int BITS>
-__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
+__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
{
// 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block
@@ -2500,7 +2499,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
{
for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x)
if((idx_col_B+i-local_idx_col_B_offset) < colsB)
- smem_dequant_stats[i] = __ldg(&dequant_stats[idx_col_B+i-local_idx_col_B_offset]);
+ smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset];
__syncthreads();
}
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index eda5ba0..2447494 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -107,7 +107,7 @@ template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileCl
__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n);
-template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(
int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats,
diff --git a/csrc/ops.cu b/csrc/ops.cu
index e6227ae..952894c 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -247,6 +247,8 @@ int roundoff(int v, int d) {
}
+#ifdef NO_CUBLASLT
+#else
template<int ORDER> cublasLtOrder_t get_order()
{
switch(ORDER)
@@ -266,7 +268,11 @@ template<int ORDER> cublasLtOrder_t get_order()
case COL_AMPERE:
return CUBLASLT_ORDER_COL32_2R_4R4;
break;
+ default:
+ break;
}
+
+ return CUBLASLT_ORDER_ROW;
}
template cublasLtOrder_t get_order<ROW>();
@@ -274,6 +280,7 @@ template cublasLtOrder_t get_order<COL>();
template cublasLtOrder_t get_order<COL32>();
template cublasLtOrder_t get_order<COL_TURING>();
template cublasLtOrder_t get_order<COL_AMPERE>();
+#endif
template<int ORDER> int get_leading_dim(int dim1, int dim2)
@@ -297,6 +304,9 @@ template<int ORDER> int get_leading_dim(int dim1, int dim2)
// 32*32 tiles
return 32*roundoff(dim1, 32);
break;
+ default:
+ return 0;
+ break;
}
}
@@ -306,7 +316,8 @@ template int get_leading_dim<COL32>(int dim1, int dim2);
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
{
-
+#ifdef NO_CUBLASLT
+#else
cublasLtOrder_t orderA = get_order<SRC>();
cublasLtOrder_t orderOut = get_order<TARGET>();
int ldA = get_leading_dim<SRC>(dim1, dim2);
@@ -345,6 +356,7 @@ template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void trans
if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc));
if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc));
if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc));
+#endif
}
template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
@@ -358,6 +370,9 @@ template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandl
template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{
+#ifdef NO_CUBLASLT
+ return 0;
+#else
int has_error = 0;
cublasLtMatmulDesc_t matmulDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
@@ -412,6 +427,7 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle
printf("error detected");
return has_error;
+#endif
}
int fill_up_to_nearest_multiple(int value, int multiple)
@@ -523,6 +539,9 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
{
+#ifdef NO_CUBLASLT
+#else
+
cusparseSpMatDescr_t descA;
cusparseDnMatDescr_t descB, descC;
@@ -569,6 +588,7 @@ void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_val
CHECK_CUSPARSE( cusparseDestroyDnMat(descB) );
CHECK_CUSPARSE( cusparseDestroyDnMat(descC) );
CUDA_CHECK_RETURN( cudaFree(dBuffer) );
+#endif
}
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index 2ecbaae..7356c11 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -82,7 +82,6 @@ void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, un
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
-#endif
#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
@@ -132,10 +131,11 @@ void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_r
void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive<signed char, 8>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
+#endif
extern "C"
{
- #if BUILD_CUDA
+#if BUILD_CUDA
void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); }
void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); }
void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
@@ -231,7 +231,7 @@ extern "C"
{ return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int cigemmlt_ampere_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
- { return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+ { return igemmlt_ampere_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
#define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \
diff --git a/cuda_install.sh b/cuda_install.sh
new file mode 100644
index 0000000..856cbe5
--- /dev/null
+++ b/cuda_install.sh
@@ -0,0 +1,77 @@
+URL92=https://developer.nvidia.com/compute/cuda/9.2/Prod2/local_installers/cuda_9.2.148_396.37_linux
+URL100=https://developer.nvidia.com/compute/cuda/10.0/Prod/local_installers/cuda_10.0.130_410.48_linux
+URL101=https://developer.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda_10.1.105_418.39_linux.run
+URL102=https://developer.download.nvidia.com/compute/cuda/10.2/Prod/local_installers/cuda_10.2.89_440.33.01_linux.run
+URL110=https://developer.download.nvidia.com/compute/cuda/11.0.3/local_installers/cuda_11.0.3_450.51.06_linux.run
+URL111=https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run
+URL112=https://developer.download.nvidia.com/compute/cuda/11.2.2/local_installers/cuda_11.2.2_460.32.03_linux.run
+URL113=https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.19.01_linux.run
+URL114=https://developer.download.nvidia.com/compute/cuda/11.4.4/local_installers/cuda_11.4.4_470.82.01_linux.run
+URL115=https://developer.download.nvidia.com/compute/cuda/11.5.2/local_installers/cuda_11.5.2_495.29.05_linux.run
+URL116=https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installers/cuda_11.6.2_510.47.03_linux.run
+URL117=https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run
+
+
+CUDA_VERSION=$1
+BASE_PATH=$2
+
+if [[ -n "$CUDA_VERSION" ]]; then
+ if [[ "$CUDA_VERSION" -eq "92" ]]; then
+ URL=$URL92
+ FOLDER=cuda-9.2
+ elif [[ "$CUDA_VERSION" -eq "100" ]]; then
+ URL=$URL100
+ FOLDER=cuda-10.0
+ elif [[ "$CUDA_VERSION" -eq "101" ]]; then
+ URL=$URL101
+ FOLDER=cuda-10.1
+ elif [[ "$CUDA_VERSION" -eq "102" ]]; then
+ URL=$URL102
+ FOLDER=cuda-10.2
+ elif [[ "$CUDA_VERSION" -eq "110" ]]; then
+ URL=$URL110
+ FOLDER=cuda-11.0
+ elif [[ "$CUDA_VERSION" -eq "111" ]]; then
+ URL=$URL111
+ FOLDER=cuda-11.1
+ elif [[ "$CUDA_VERSION" -eq "112" ]]; then
+ URL=$URL112
+ FOLDER=cuda-11.2
+ elif [[ "$CUDA_VERSION" -eq "113" ]]; then
+ URL=$URL113
+ FOLDER=cuda-11.3
+ elif [[ "$CUDA_VERSION" -eq "114" ]]; then
+ URL=$URL114
+ FOLDER=cuda-11.4
+ elif [[ "$CUDA_VERSION" -eq "115" ]]; then
+ URL=$URL115
+ FOLDER=cuda-11.5
+ elif [[ "$CUDA_VERSION" -eq "116" ]]; then
+ URL=$URL116
+ FOLDER=cuda-11.6
+ elif [[ "$CUDA_VERSION" -eq "117" ]]; then
+ URL=$URL117
+ FOLDER=cuda-11.7
+ else
+ echo "argument error: No cuda version passed as input. Choose among: {111, 115}"
+ fi
+else
+ echo "argument error: No cuda version passed as input. Choose among: {111, 115}"
+fi
+
+FILE=$(basename $URL)
+
+if [[ -n "$CUDA_VERSION" ]]; then
+ echo $URL
+ echo $FILE
+ wget $URL
+ bash $FILE --no-drm --no-man-page --override --installpath=~/local --librarypath=$BASE_PATH/lib --toolkitpath=$BASE_PATH/$FOLDER/ --toolkit --silent
+ echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$BASE_PATH/$FOLDER/lib64/" >> ~/.bashrc
+ echo "export PATH=$PATH:$BASE_PATH/$FOLDER/bin/" >> ~/.bashrc
+ source ~/.bashrc
+else
+ echo ""
+fi
+
+
+
diff --git a/deploy_from_slurm.sh b/deploy_from_slurm.sh
index 6357e1d..664d40e 100644
--- a/deploy_from_slurm.sh
+++ b/deploy_from_slurm.sh
@@ -1,86 +1,261 @@
#!/bin/bash
+BASE_PATH=$1
+
module unload cuda
module unload gcc
+#rm -rf dist build
+#make clean
+#make cleaneggs
+#export CUDA_HOME=
+#make cpuonly
+#
+#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+# # Control will enter here if $DIRECTORY doesn't exist.
+# echo "Compilation unsuccessul!" 1>&2
+# exit 64
+#fi
+#CUDA_VERSION=cpu python -m build
+#python -m twine upload dist/* --verbose --repository testpypi
+
rm -rf dist build
make clean
make cleaneggs
-module load cuda/9.2
-module load gcc/7.3.0
-CUDA_HOME=/public/apps/cuda/9.2
-make
-CUDA_VERSION=92 python -m build
-python -m twine upload dist/* --verbose
-module unload cuda
+export CUDA_HOME=$BASE_PATH/cuda-11.0
+make cuda110
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+ # Control will enter here if $DIRECTORY doesn't exist.
+ echo "Compilation unsuccessul!" 1>&2
+ exit 64
+fi
+CUDA_VERSION=110 python -m build
+python -m twine upload dist/* --verbose --repository testpypi
rm -rf dist build
make clean
make cleaneggs
-module load cuda/10.0
-CUDA_HOME=/public/apps/cuda/10.0
-make cuda10x
-CUDA_VERSION=100 python -m build
-python -m twine upload dist/* --verbose
-module unload cuda
-module unload gcc
-module load gcc/8.4
+export CUDA_HOME=$BASE_PATH/cuda-11.1
+make cuda11x
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+ # Control will enter here if $DIRECTORY doesn't exist.
+ echo "Compilation unsuccessul!" 1>&2
+ exit 64
+fi
+CUDA_VERSION=111 python -m build
+python -m twine upload dist/* --verbose --repository testpypi
rm -rf dist build
make clean
make cleaneggs
-module load cuda/10.1
-CUDA_HOME=/public/apps/cuda/10.1
-make cuda10x
-CUDA_VERSION=101 python -m build
-python -m twine upload dist/* --verbose
-module unload cuda
+export CUDA_HOME=$BASE_PATH/cuda-11.2
+make cuda11x
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+ # Control will enter here if $DIRECTORY doesn't exist.
+ echo "Compilation unsuccessul!" 1>&2
+ exit 64
+fi
+CUDA_VERSION=112 python -m build
+python -m twine upload dist/* --verbose --repository testpypi
rm -rf dist build
make clean
make cleaneggs
-module load cuda/10.2
-CUDA_HOME=/public/apps/cuda/10.2/
-make cuda10x
-CUDA_VERSION=102 python -m build
-python -m twine upload dist/* --verbose
-module unload cuda
+export CUDA_HOME=$BASE_PATH/cuda-11.3
+make cuda11x
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+ # Control will enter here if $DIRECTORY doesn't exist.
+ echo "Compilation unsuccessul!" 1>&2
+ exit 64
+fi
+CUDA_VERSION=113 python -m build
+python -m twine upload dist/* --verbose --repository testpypi
rm -rf dist build
make clean
make cleaneggs
-module load cuda/11.0
-CUDA_HOME=/public/apps/cuda/11.0
-make cuda110
-CUDA_VERSION=110 python -m build
-python -m twine upload dist/* --verbose
-module unload cuda
+export CUDA_HOME=$BASE_PATH/cuda-11.4
+make cuda11x
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+ # Control will enter here if $DIRECTORY doesn't exist.
+ echo "Compilation unsuccessul!" 1>&2
+ exit 64
+fi
+CUDA_VERSION=114 python -m build
+python -m twine upload dist/* --verbose --repository testpypi
rm -rf dist build
make clean
make cleaneggs
-module load cuda/11.1
-CUDA_HOME=/public/apps/cuda/11.1
+export CUDA_HOME=$BASE_PATH/cuda-11.5
make cuda11x
-CUDA_VERSION=111 python -m build
-python -m twine upload dist/* --verbose
-module unload cuda
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+ # Control will enter here if $DIRECTORY doesn't exist.
+ echo "Compilation unsuccessul!" 1>&2
+ exit 64
+fi
+CUDA_VERSION=115 python -m build
+python -m twine upload dist/* --verbose --repository testpypi
+
+#rm -rf dist build
+#make clean
+#make cleaneggs
+#export CUDA_HOME=$BASE_PATH/cuda-11.6
+#
+#make cuda11x
+#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+# # Control will enter here if $DIRECTORY doesn't exist.
+# echo "Compilation unsuccessul!" 1>&2
+# exit 64
+#fi
+#CUDA_VERSION=116 python -m build
+#python -m twine upload dist/* --verbose --repository testpypi
+#
rm -rf dist build
make clean
make cleaneggs
-module load cuda/11.2
-CUDA_HOME=/public/apps/cuda/11.2
+export CUDA_HOME=$BASE_PATH/cuda-11.7
make cuda11x
-CUDA_VERSION=112 python -m build
-python -m twine upload dist/* --verbose
-module unload cuda
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+ # Control will enter here if $DIRECTORY doesn't exist.
+ echo "Compilation unsuccessul!" 1>&2
+ exit 64
+fi
+CUDA_VERSION=117 python -m build
+python -m twine upload dist/* --verbose --repository testpypi
+
rm -rf dist build
make clean
make cleaneggs
-CUDA_HOME=/private/home/timdettmers/git/autoswap/local/cuda-11.3 make cuda11x
-CUDA_VERSION=113 python -m build
+export CUDA_HOME=$BASE_PATH/cuda-10.2
+make cuda10x_nomatmul
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+ # Control will enter here if $DIRECTORY doesn't exist.
+ echo "Compilation unsuccessul!" 1>&2
+ exit 64
+fi
+CUDA_VERSION=102-nomatmul python -m build
+python -m twine upload dist/* --verbose --repository testpypi
+
+
+rm -rf dist build
+make clean
+make cleaneggs
+export CUDA_HOME=$BASE_PATH/cuda-11.0
+make cuda110_nomatmul
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+ # Control will enter here if $DIRECTORY doesn't exist.
+ echo "Compilation unsuccessul!" 1>&2
+ exit 64
+fi
+CUDA_VERSION=110-nomatmul python -m build
+python -m twine upload dist/* --verbose --repository testpypi
+
+
+rm -rf dist build
+make clean
+make cleaneggs
+export CUDA_HOME=$BASE_PATH/cuda-11.1
+make cuda11x_nomatmul
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+ # Control will enter here if $DIRECTORY doesn't exist.
+ echo "Compilation unsuccessul!" 1>&2
+ exit 64
+fi
+CUDA_VERSION=111-nomatmul python -m build
+python -m twine upload dist/* --verbose --repository testpypi
+
+rm -rf dist build
+make clean
+make cleaneggs
+export CUDA_HOME=$BASE_PATH/cuda-11.2
+make cuda11x_nomatmul
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+ # Control will enter here if $DIRECTORY doesn't exist.
+ echo "Compilation unsuccessul!" 1>&2
+ exit 64
+fi
+CUDA_VERSION=112-nomatmul python -m build
+python -m twine upload dist/* --verbose --repository testpypi
+
+rm -rf dist build
+make clean
+make cleaneggs
+export CUDA_HOME=$BASE_PATH/cuda-11.3
+make cuda11x_nomatmul
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+ # Control will enter here if $DIRECTORY doesn't exist.
+ echo "Compilation unsuccessul!" 1>&2
+ exit 64
+fi
+CUDA_VERSION=113-nomatmul python -m build
+python -m twine upload dist/* --verbose --repository testpypi
+
+rm -rf dist build
+make clean
+make cleaneggs
+export CUDA_HOME=$BASE_PATH/cuda-11.4
+make cuda11x_nomatmul
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+ # Control will enter here if $DIRECTORY doesn't exist.
+ echo "Compilation unsuccessul!" 1>&2
+ exit 64
+fi
+CUDA_VERSION=114-nomatmul python -m build
+python -m twine upload dist/* --verbose --repository testpypi
+
+rm -rf dist build
+make clean
+make cleaneggs
+export CUDA_HOME=$BASE_PATH/cuda-11.5
+make cuda11x_nomatmul
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+ # Control will enter here if $DIRECTORY doesn't exist.
+ echo "Compilation unsuccessul!" 1>&2
+ exit 64
+fi
+CUDA_VERSION=115-nomatmul python -m build
+python -m twine upload dist/* --verbose --repository testpypi
+
+rm -rf dist build
+make clean
+make cleaneggs
+export CUDA_HOME=$BASE_PATH/cuda-11.6
+
+make cuda11x_nomatmul
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+ # Control will enter here if $DIRECTORY doesn't exist.
+ echo "Compilation unsuccessul!" 1>&2
+ exit 64
+fi
+CUDA_VERSION=116-nomatmul python -m build
+python -m twine upload dist/* --verbose --repository testpypi
+
+rm -rf dist build
+make clean
+make cleaneggs
+export CUDA_HOME=$BASE_PATH/cuda-11.7
+make cuda11x_nomatmul
+
+if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then
+ # Control will enter here if $DIRECTORY doesn't exist.
+ echo "Compilation unsuccessul!" 1>&2
+ exit 64
+fi
+CUDA_VERSION=117-nomatmul python -m build
python -m twine upload dist/* --verbose
-module unload cuda
+python -m twine upload dist/* --verbose --repository testpypi
diff --git a/quicktest.py b/quicktest.py
new file mode 100644
index 0000000..2db6afa
--- /dev/null
+++ b/quicktest.py
@@ -0,0 +1,90 @@
+import torch
+import bitsandbytes as bnb
+import bitsandbytes.functional as F
+
+from itertools import product
+
+def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
+ k = 25
+ for i in range(k):
+ if dims == 2:
+ A = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
+ elif dims == 3:
+ A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
+ B = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
+ C1 = torch.matmul(A.float(), B.t().float())
+
+ A2, SA = F.transform(A, 'col32')
+ B2, SB = F.transform(B, 'colx')
+ if dims == 2:
+ C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
+ else:
+ C2, SC = F.transform(torch.zeros(A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
+ F.igemmlt(A2, B2, C2, SA, SB, SC)
+ C3, S = F.transform(C2, 'row', state=SC)
+ #torch.testing.assert_allclose(C1, C3.float())
+ #print(C1)
+ #print(C2)
+ #print(C3)
+ allclose = torch.allclose(C1, C3.float())
+ if allclose:
+ print(C1)
+ print(C2)
+ print(C3)
+
+ ## transposed
+ #A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
+ #if dims == 2:
+ # B = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
+ # C1 = torch.matmul(A.float(), B.float().t())
+ #elif dims == 3:
+ # B = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
+ # C1 = torch.matmul(B.float(), A.t().float())
+ # C1 = C1.permute([2, 0, 1])
+
+ #A2, SA = F.transform(A, 'col32')
+ #B2, SB = F.transform(B, 'colx')
+ #if dims == 2:
+ # C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
+ #else:
+ # C2 = torch.zeros(A.shape[0], B.shape[0], B.shape[1], dtype=torch.int32, device='cuda')
+ # state = (C2.shape, 'row', A.shape[0])
+ # C2, SC = F.transform(C2, 'col32', state=state)
+ #F.igemmlt(A2, B2, C2, SA, SB, SC)
+ #C3, S = F.transform(C2, 'row', state=SC, ld=[0])
+ #torch.testing.assert_allclose(C1, C3.float())
+
+ ## weight update
+ #if dims == 3:
+ # A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
+ # B = torch.randint(-128, 127, size=(dim1, dim2, dim4), device='cuda').to(torch.int8)
+ # C1 = torch.matmul(B.view(-1, B.shape[-1]).t().float(), A.view(-1, A.shape[-1]).float())
+
+ # A2, SA = F.transform(A.view(-1, A.shape[-1]).t().contiguous(), 'colx')
+ # B2, SB = F.transform(B.view(-1, B.shape[-1]).t().contiguous(), 'col32')
+ # C2 = torch.zeros(B.shape[-1], A.shape[-1], dtype=torch.int32, device='cuda')
+ # C2, SC = F.transform(C2, 'col32')
+ # F.igemmlt(B2, A2, C2, SB, SA, SC)
+ # C3, S = F.transform(C2, 'row', state=SC)
+ # torch.testing.assert_allclose(C1, C3.float())
+
+
+dims = (2, 3)
+ldb = [0]
+
+n = 2
+dim1 = torch.randint(1,256, size=(n,)).tolist()
+dim2 = torch.randint(32,512, size=(n,)).tolist()
+dim3 = torch.randint(32,1024, size=(n,)).tolist()
+dim4 = torch.randint(32,1024, size=(n,)).tolist()
+values = list(product(dim1,dim2,dim3,dim4,dims, ldb))
+
+for ldb in range(32, 4096, 32):
+#for ldb in [None]:
+ val = test_igemmlt(2, 2, 2, 2, 2, ldb)
+ if val:
+ print(val, ldb)
+ else:
+ print('nope', ldb)
+#for val in values:
+ #test_igemmlt(*val)
diff --git a/setup.py b/setup.py
index 2402c02..6275ddd 100644
--- a/setup.py
+++ b/setup.py
@@ -11,13 +11,14 @@ def read(fname):
version = os.getenv("CUDA_VERSION", "cpu")
+prefix = '' if version == 'cpu' else 'cuda'
setup(
- name="bitsandbytes",
- version=f"0.26.0+{version}",
+ name=f"bitsandbytes-{prefix}{version}",
+ version=f"0.30.0",
author="Tim Dettmers",
author_email="dettmers@cs.washington.edu",
- description="8-bit optimizers and quantization routines.",
+ description="8-bit optimizers and matrix multiplication routines.",
license="MIT",
keywords="gpu optimizers optimization 8-bit quantization compression",
url="http://packages.python.org/bitsandbytes",
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 2d58fac..bfc3e28 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -992,6 +992,7 @@ inner = torch.randint(1,4*1024, size=(n,)).tolist()
values = list(zip(dim1, dim4, inner))
names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
+@pytest.mark.skip("Row scale has some bugs for ampere")
def test_igemmlt_row_scale(dim1, dim4, inner):
formatB = F.get_special_format_str()
err1, err2, err3 = [], [], []
@@ -1064,6 +1065,7 @@ dim4 = [12288, 4096]
values = list(zip(dim1, dim4, inner))
names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
+@pytest.mark.skip("Row scale has some bugs for ampere")
def test_row_scale_bench(dim1, dim4, inner):
err1, err2, err3 = [], [], []
relerr1, relerr2 = [], []
@@ -1183,6 +1185,7 @@ def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut):
def test_overflow():
formatB = F.get_special_format_str()
+ print(formatB)
for i in range(2):
a = torch.arange(5, 15).cuda().to(torch.int8).view(-1,1 )
b = torch.arange(5, 15).cuda().to(torch.int8).view(-1,1 )