From 7d2ecd30c044840ba5f161ec73e5eaf30ac8131d Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Fri, 22 Jul 2022 15:21:37 -0700 Subject: Fixed rowcol synchronization bug. --- csrc/kernels.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 1c3e723..4e744fb 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1768,7 +1768,6 @@ template Date: Mon, 25 Jul 2022 14:02:14 -0700 Subject: Fixed makefile; fixed Ampere igemmlt_8 bug. --- CHANGELOG.md | 11 ++++++ Makefile | 2 +- csrc/pythonInterface.c | 2 +- cuda_install_111.sh | 38 ++++++++++++++++++++ quicktest.py | 90 ++++++++++++++++++++++++++++++++++++++++++++++++ tests/test_functional.py | 1 + 6 files changed, 142 insertions(+), 2 deletions(-) create mode 100644 cuda_install_111.sh create mode 100644 quicktest.py diff --git a/CHANGELOG.md b/CHANGELOG.md index fa20b15..08adfce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,3 +53,14 @@ Bug fixes: Docs: - Added instructions how to solve "\_\_fatbinwrap_" errors. + + +### 0.30.0 + +#### 8-bit Inference Update + +Features: + - Added 8-bit matrix multiplication form cuBLAS, and cuBLASLt as well as multiple GEMM kernels (GEMM, GEMMEx, GEMMLt) + - Added 8-bit Linear layers with 8-bit Params that perform memory efficient inference with an option for 8-bit mixed precision matrix decomposition for inference without performance degradation + - Added quantization methods for "fake" quantization as well as optimized kernels vector-wise quantization and equalization as well as optimized cuBLASLt transformations + - CPU only build now available (Thank you, @mryab) diff --git a/Makefile b/Makefile index b58e233..728c8e1 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include -LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcuda -lcublas -lcurand -lcusparse -L $(CONDA_PREFIX)/lib +LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcurand -lcusparse -L $(CONDA_PREFIX)/lib # NVIDIA NVCC compilation flags COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 03c8d92..9b57549 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -228,7 +228,7 @@ extern "C" { return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } int cigemmlt_ampere_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + { return igemmlt_ampere_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } #define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \ diff --git a/cuda_install_111.sh b/cuda_install_111.sh new file mode 100644 index 0000000..476ab59 --- /dev/null +++ b/cuda_install_111.sh @@ -0,0 +1,38 @@ +FILE115=:cuda_11.5.1_495.29.05_linux.run +FILE111=:cuda_11.1.1_455.32.00_linux.run +URL115=:https://developer.download.nvidia.com/compute/cuda/11.5.1/local_installers/cuda_11.5.1_495.29.05_linux.run +URL111=:https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run + + +CUDA_VERSION=$1 + +if [[ -n "$CUDA_VERSION" ]]; then + if [[ "$CUDA_VERSION" -eq "111" ]]; then + FILE=cuda_11.1.1_455.32.00_linux.run + URL=https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run + FOLDER=cuda-11.1 + elif [[ "$CUDA_VERSION" -eq "115" ]]; then + FILE=cuda_11.5.1_495.29.05_linux.run + URL=https://developer.download.nvidia.com/compute/cuda/11.5.1/local_installers/cuda_11.5.1_495.29.05_linux.run + FOLDER=cuda-11.5 + else + echo "argument error: No cuda version passed as input. Choose among: {111, 115}" + fi +else + echo "argument error: No cuda version passed as input. Choose among: {111, 115}" +fi + +if [[ -n "$CUDA_VERSION" ]]; then + echo $URL + echo $FILE + wget $URL + bash $FILE --no-drm --no-man-page --override --installpath=~/local --librarypath=~/local/lib --toolkitpath=~/local/$FOLDER/ --toolkit --silent + echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/local/$FOLDER/lib64/" >> ~/.bashrc + echo "export PATH=$PATH:~/local/$FOLDER/bin/" >> ~/.bashrc + source ~/.bashrc +else + echo "" +fi + + + diff --git a/quicktest.py b/quicktest.py new file mode 100644 index 0000000..2db6afa --- /dev/null +++ b/quicktest.py @@ -0,0 +1,90 @@ +import torch +import bitsandbytes as bnb +import bitsandbytes.functional as F + +from itertools import product + +def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb): + k = 25 + for i in range(k): + if dims == 2: + A = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8) + elif dims == 3: + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8) + B = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8) + C1 = torch.matmul(A.float(), B.t().float()) + + A2, SA = F.transform(A, 'col32') + B2, SB = F.transform(B, 'colx') + if dims == 2: + C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32') + else: + C2, SC = F.transform(torch.zeros(A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device='cuda'), 'col32') + F.igemmlt(A2, B2, C2, SA, SB, SC) + C3, S = F.transform(C2, 'row', state=SC) + #torch.testing.assert_allclose(C1, C3.float()) + #print(C1) + #print(C2) + #print(C3) + allclose = torch.allclose(C1, C3.float()) + if allclose: + print(C1) + print(C2) + print(C3) + + ## transposed + #A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8) + #if dims == 2: + # B = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8) + # C1 = torch.matmul(A.float(), B.float().t()) + #elif dims == 3: + # B = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8) + # C1 = torch.matmul(B.float(), A.t().float()) + # C1 = C1.permute([2, 0, 1]) + + #A2, SA = F.transform(A, 'col32') + #B2, SB = F.transform(B, 'colx') + #if dims == 2: + # C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32') + #else: + # C2 = torch.zeros(A.shape[0], B.shape[0], B.shape[1], dtype=torch.int32, device='cuda') + # state = (C2.shape, 'row', A.shape[0]) + # C2, SC = F.transform(C2, 'col32', state=state) + #F.igemmlt(A2, B2, C2, SA, SB, SC) + #C3, S = F.transform(C2, 'row', state=SC, ld=[0]) + #torch.testing.assert_allclose(C1, C3.float()) + + ## weight update + #if dims == 3: + # A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8) + # B = torch.randint(-128, 127, size=(dim1, dim2, dim4), device='cuda').to(torch.int8) + # C1 = torch.matmul(B.view(-1, B.shape[-1]).t().float(), A.view(-1, A.shape[-1]).float()) + + # A2, SA = F.transform(A.view(-1, A.shape[-1]).t().contiguous(), 'colx') + # B2, SB = F.transform(B.view(-1, B.shape[-1]).t().contiguous(), 'col32') + # C2 = torch.zeros(B.shape[-1], A.shape[-1], dtype=torch.int32, device='cuda') + # C2, SC = F.transform(C2, 'col32') + # F.igemmlt(B2, A2, C2, SB, SA, SC) + # C3, S = F.transform(C2, 'row', state=SC) + # torch.testing.assert_allclose(C1, C3.float()) + + +dims = (2, 3) +ldb = [0] + +n = 2 +dim1 = torch.randint(1,256, size=(n,)).tolist() +dim2 = torch.randint(32,512, size=(n,)).tolist() +dim3 = torch.randint(32,1024, size=(n,)).tolist() +dim4 = torch.randint(32,1024, size=(n,)).tolist() +values = list(product(dim1,dim2,dim3,dim4,dims, ldb)) + +for ldb in range(32, 4096, 32): +#for ldb in [None]: + val = test_igemmlt(2, 2, 2, 2, 2, ldb) + if val: + print(val, ldb) + else: + print('nope', ldb) +#for val in values: + #test_igemmlt(*val) diff --git a/tests/test_functional.py b/tests/test_functional.py index 6cbe58f..dcb0255 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1183,6 +1183,7 @@ def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut): def test_overflow(): formatB = F.get_special_format_str() + print(formatB) for i in range(2): a = torch.arange(5, 15).cuda().to(torch.int8).view(-1,1 ) b = torch.arange(5, 15).cuda().to(torch.int8).view(-1,1 ) -- cgit v1.2.3 From 1e88edd8c096bde5202dd61411d3c8d7eda56645 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Mon, 25 Jul 2022 17:27:57 -0700 Subject: Removed rowscale (segfaults on ampere). --- Makefile | 1 - bitsandbytes/functional.py | 13 ++++--------- tests/test_functional.py | 2 ++ 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/Makefile b/Makefile index 728c8e1..2e1d265 100644 --- a/Makefile +++ b/Makefile @@ -27,7 +27,6 @@ COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta -COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta # CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not CC_CUDA92 := -gencode arch=compute_30,code=sm_30 diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 806c254..0190a7e 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -897,7 +897,7 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch)) return out -def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32): +def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): shapeA = SA[0] shapeB = SB[0] dimsA = len(shapeA) @@ -917,7 +917,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32 elif dimsA == 3 and out is None: out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, 'col32', 'row') - if row_scale is not None: assert row_scale.numel() == out.shape[0] assert dimsB != 3, 'len(B.shape)==3 not supported' assert A.device.type == 'cuda' assert B.device.type == 'cuda' @@ -936,7 +935,6 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32 ptrA = get_ptr(A) ptrB = get_ptr(B) ptrC = get_ptr(out) - ptrRowScale = get_ptr(row_scale) k = shapeA[-1] lda = ct.c_int32(m*32) @@ -955,20 +953,17 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32 k = ct.c_int32(k) has_error = 0 + ptrRowScale = get_ptr(None) if formatB == 'col_turing': if dtype == torch.int32: has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) - elif row_scale is None: - has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) else: - has_error = lib.cigemmlt_turing_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) elif formatB == 'col_ampere': if dtype == torch.int32: has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) - elif row_scale is None: - has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) else: - has_error = lib.cigemmlt_ampere_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) if has_error == 1: raise Exception('cublasLt ran into an error!') diff --git a/tests/test_functional.py b/tests/test_functional.py index dcb0255..d80a4f9 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -992,6 +992,7 @@ inner = torch.randint(1,4*1024, size=(n,)).tolist() values = list(zip(dim1, dim4, inner)) names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) +@pytest.mark.skip("Row scale has some bugs for ampere") def test_igemmlt_row_scale(dim1, dim4, inner): formatB = F.get_special_format_str() err1, err2, err3 = [], [], [] @@ -1064,6 +1065,7 @@ dim4 = [12288, 4096] values = list(zip(dim1, dim4, inner)) names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) +@pytest.mark.skip("Row scale has some bugs for ampere") def test_row_scale_bench(dim1, dim4, inner): err1, err2, err3 = [], [], [] relerr1, relerr2 = [], [] -- cgit v1.2.3 From 9268dc9d887a3d54cd1f008dcb628aaa5b5bd90a Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Mon, 25 Jul 2022 19:30:37 -0700 Subject: Some progress on build script; added multi-cuda install script. --- CHANGELOG.md | 3 ++ Makefile | 7 +-- csrc/kernels.cu | 17 ++++--- csrc/kernels.cuh | 2 +- csrc/ops.cu | 22 ++++++++- cuda_install.sh | 77 +++++++++++++++++++++++++++++++ cuda_install_111.sh | 38 ---------------- deploy_from_slurm.sh | 125 ++++++++++++++++++++++++++++++++------------------- 8 files changed, 192 insertions(+), 99 deletions(-) create mode 100644 cuda_install.sh delete mode 100644 cuda_install_111.sh diff --git a/CHANGELOG.md b/CHANGELOG.md index 08adfce..285984e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -64,3 +64,6 @@ Features: - Added 8-bit Linear layers with 8-bit Params that perform memory efficient inference with an option for 8-bit mixed precision matrix decomposition for inference without performance degradation - Added quantization methods for "fake" quantization as well as optimized kernels vector-wise quantization and equalization as well as optimized cuBLASLt transformations - CPU only build now available (Thank you, @mryab) + +Deprecated: + - Pre-compiled release for CUDA 9.2, 10.0, 10.2 no longer available diff --git a/Makefile b/Makefile index 2e1d265..328faa5 100644 --- a/Makefile +++ b/Makefile @@ -27,13 +27,14 @@ COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta +COMPUTE_CAPABILITY := -gencode arch=compute_75,code=sm_75 # Volta # CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not CC_CUDA92 := -gencode arch=compute_30,code=sm_30 # Later versions of CUDA support the new architectures CC_CUDA10x := -gencode arch=compute_30,code=sm_30 -CC_CUDA10x += -gencode arch=compute_75,code=sm_75 +CC_CUDA10x := -gencode arch=compute_75,code=sm_75 CC_CUDA110 := -gencode arch=compute_75,code=sm_75 CC_CUDA110 += -gencode arch=compute_80,code=sm_80 @@ -43,12 +44,12 @@ CC_CUDA11x += -gencode arch=compute_80,code=sm_80 CC_CUDA11x += -gencode arch=compute_86,code=sm_86 all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env - $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) + $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB) cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) + $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 4e744fb..6eca3aa 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2166,7 +2166,6 @@ template BlockExchange; - __shared__ typename BlockExchange::TempStorage temp_storage; // we load row after row from the base_position // Load data row by row @@ -2446,7 +2445,7 @@ template -__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB) +__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB) { // 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block @@ -2500,7 +2499,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o { for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x) if((idx_col_B+i-local_idx_col_B_offset) < colsB) - smem_dequant_stats[i] = __ldg(&dequant_stats[idx_col_B+i-local_idx_col_B_offset]); + smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset]; __syncthreads(); } @@ -2596,12 +2595,12 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o // TEMPLATE DEFINITIONS //============================================================== -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index cbfbeba..4e65e96 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -107,7 +107,7 @@ template __global__ void kPercentileCl __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n); -template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kdequant_mm_int32_fp16( int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, diff --git a/csrc/ops.cu b/csrc/ops.cu index 8946015..c430d55 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -247,6 +247,8 @@ int roundoff(int v, int d) { } +#ifdef NO_CUBLASLT +#else template cublasLtOrder_t get_order() { switch(ORDER) @@ -266,7 +268,11 @@ template cublasLtOrder_t get_order() case COL_AMPERE: return CUBLASLT_ORDER_COL32_2R_4R4; break; + default: + break; } + + return CUBLASLT_ORDER_ROW; } template cublasLtOrder_t get_order(); @@ -274,6 +280,7 @@ template cublasLtOrder_t get_order(); template cublasLtOrder_t get_order(); template cublasLtOrder_t get_order(); template cublasLtOrder_t get_order(); +#endif template int get_leading_dim(int dim1, int dim2) @@ -297,6 +304,9 @@ template int get_leading_dim(int dim1, int dim2) // 32*32 tiles return 32*roundoff(dim1, 32); break; + default: + return 0; + break; } } @@ -306,7 +316,8 @@ template int get_leading_dim(int dim1, int dim2); template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2) { - +#ifdef NO_CUBLASLT +#else cublasLtOrder_t orderA = get_order(); cublasLtOrder_t orderOut = get_order(); int ldA = get_leading_dim(dim1, dim2); @@ -345,6 +356,7 @@ template void trans if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc)); if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc)); if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc)); +#endif } template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); @@ -358,6 +370,9 @@ template void transform(cublasLtHandle_t ltHandl template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { +#ifdef NO_CUBLASLT + return 0; +#else int has_error = 0; cublasLtMatmulDesc_t matmulDesc = NULL; cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; @@ -412,6 +427,7 @@ template int igemmlt(cublasLtHandle printf("error detected"); return has_error; +#endif } int fill_up_to_nearest_multiple(int value, int multiple) @@ -523,6 +539,9 @@ template void transformRowToFormat(char * A, char *o void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) { +#ifdef NO_CUBLASLT +#else + cusparseSpMatDescr_t descA; cusparseDnMatDescr_t descB, descC; @@ -569,6 +588,7 @@ void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_val CHECK_CUSPARSE( cusparseDestroyDnMat(descB) ); CHECK_CUSPARSE( cusparseDestroyDnMat(descC) ); CUDA_CHECK_RETURN( cudaFree(dBuffer) ); +#endif } template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) diff --git a/cuda_install.sh b/cuda_install.sh new file mode 100644 index 0000000..856cbe5 --- /dev/null +++ b/cuda_install.sh @@ -0,0 +1,77 @@ +URL92=https://developer.nvidia.com/compute/cuda/9.2/Prod2/local_installers/cuda_9.2.148_396.37_linux +URL100=https://developer.nvidia.com/compute/cuda/10.0/Prod/local_installers/cuda_10.0.130_410.48_linux +URL101=https://developer.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda_10.1.105_418.39_linux.run +URL102=https://developer.download.nvidia.com/compute/cuda/10.2/Prod/local_installers/cuda_10.2.89_440.33.01_linux.run +URL110=https://developer.download.nvidia.com/compute/cuda/11.0.3/local_installers/cuda_11.0.3_450.51.06_linux.run +URL111=https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run +URL112=https://developer.download.nvidia.com/compute/cuda/11.2.2/local_installers/cuda_11.2.2_460.32.03_linux.run +URL113=https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.19.01_linux.run +URL114=https://developer.download.nvidia.com/compute/cuda/11.4.4/local_installers/cuda_11.4.4_470.82.01_linux.run +URL115=https://developer.download.nvidia.com/compute/cuda/11.5.2/local_installers/cuda_11.5.2_495.29.05_linux.run +URL116=https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installers/cuda_11.6.2_510.47.03_linux.run +URL117=https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run + + +CUDA_VERSION=$1 +BASE_PATH=$2 + +if [[ -n "$CUDA_VERSION" ]]; then + if [[ "$CUDA_VERSION" -eq "92" ]]; then + URL=$URL92 + FOLDER=cuda-9.2 + elif [[ "$CUDA_VERSION" -eq "100" ]]; then + URL=$URL100 + FOLDER=cuda-10.0 + elif [[ "$CUDA_VERSION" -eq "101" ]]; then + URL=$URL101 + FOLDER=cuda-10.1 + elif [[ "$CUDA_VERSION" -eq "102" ]]; then + URL=$URL102 + FOLDER=cuda-10.2 + elif [[ "$CUDA_VERSION" -eq "110" ]]; then + URL=$URL110 + FOLDER=cuda-11.0 + elif [[ "$CUDA_VERSION" -eq "111" ]]; then + URL=$URL111 + FOLDER=cuda-11.1 + elif [[ "$CUDA_VERSION" -eq "112" ]]; then + URL=$URL112 + FOLDER=cuda-11.2 + elif [[ "$CUDA_VERSION" -eq "113" ]]; then + URL=$URL113 + FOLDER=cuda-11.3 + elif [[ "$CUDA_VERSION" -eq "114" ]]; then + URL=$URL114 + FOLDER=cuda-11.4 + elif [[ "$CUDA_VERSION" -eq "115" ]]; then + URL=$URL115 + FOLDER=cuda-11.5 + elif [[ "$CUDA_VERSION" -eq "116" ]]; then + URL=$URL116 + FOLDER=cuda-11.6 + elif [[ "$CUDA_VERSION" -eq "117" ]]; then + URL=$URL117 + FOLDER=cuda-11.7 + else + echo "argument error: No cuda version passed as input. Choose among: {111, 115}" + fi +else + echo "argument error: No cuda version passed as input. Choose among: {111, 115}" +fi + +FILE=$(basename $URL) + +if [[ -n "$CUDA_VERSION" ]]; then + echo $URL + echo $FILE + wget $URL + bash $FILE --no-drm --no-man-page --override --installpath=~/local --librarypath=$BASE_PATH/lib --toolkitpath=$BASE_PATH/$FOLDER/ --toolkit --silent + echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$BASE_PATH/$FOLDER/lib64/" >> ~/.bashrc + echo "export PATH=$PATH:$BASE_PATH/$FOLDER/bin/" >> ~/.bashrc + source ~/.bashrc +else + echo "" +fi + + + diff --git a/cuda_install_111.sh b/cuda_install_111.sh deleted file mode 100644 index 476ab59..0000000 --- a/cuda_install_111.sh +++ /dev/null @@ -1,38 +0,0 @@ -FILE115=:cuda_11.5.1_495.29.05_linux.run -FILE111=:cuda_11.1.1_455.32.00_linux.run -URL115=:https://developer.download.nvidia.com/compute/cuda/11.5.1/local_installers/cuda_11.5.1_495.29.05_linux.run -URL111=:https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run - - -CUDA_VERSION=$1 - -if [[ -n "$CUDA_VERSION" ]]; then - if [[ "$CUDA_VERSION" -eq "111" ]]; then - FILE=cuda_11.1.1_455.32.00_linux.run - URL=https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run - FOLDER=cuda-11.1 - elif [[ "$CUDA_VERSION" -eq "115" ]]; then - FILE=cuda_11.5.1_495.29.05_linux.run - URL=https://developer.download.nvidia.com/compute/cuda/11.5.1/local_installers/cuda_11.5.1_495.29.05_linux.run - FOLDER=cuda-11.5 - else - echo "argument error: No cuda version passed as input. Choose among: {111, 115}" - fi -else - echo "argument error: No cuda version passed as input. Choose among: {111, 115}" -fi - -if [[ -n "$CUDA_VERSION" ]]; then - echo $URL - echo $FILE - wget $URL - bash $FILE --no-drm --no-man-page --override --installpath=~/local --librarypath=~/local/lib --toolkitpath=~/local/$FOLDER/ --toolkit --silent - echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/local/$FOLDER/lib64/" >> ~/.bashrc - echo "export PATH=$PATH:~/local/$FOLDER/bin/" >> ~/.bashrc - source ~/.bashrc -else - echo "" -fi - - - diff --git a/deploy_from_slurm.sh b/deploy_from_slurm.sh index 6357e1d..5a554bb 100644 --- a/deploy_from_slurm.sh +++ b/deploy_from_slurm.sh @@ -1,86 +1,117 @@ #!/bin/bash +BASE_PATH=$1 + module unload cuda module unload gcc rm -rf dist build make clean make cleaneggs -module load cuda/9.2 -module load gcc/7.3.0 -CUDA_HOME=/public/apps/cuda/9.2 -make -CUDA_VERSION=92 python -m build -python -m twine upload dist/* --verbose -module unload cuda +export CUDA_HOME=$BASE_PATH/cuda-11.0 +make cuda110 +if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi +#CUDA_VERSION=110 python -m build +#python -m twine upload dist/* --verbose rm -rf dist build make clean make cleaneggs -module load cuda/10.0 -CUDA_HOME=/public/apps/cuda/10.0 -make cuda10x -CUDA_VERSION=100 python -m build -python -m twine upload dist/* --verbose -module unload cuda -module unload gcc -module load gcc/8.4 +export CUDA_HOME=$BASE_PATH/cuda-11.1 +make cuda11x + +if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi +#CUDA_VERSION=111 python -m build +#python -m twine upload dist/* --verbose rm -rf dist build make clean make cleaneggs -module load cuda/10.1 -CUDA_HOME=/public/apps/cuda/10.1 -make cuda10x -CUDA_VERSION=101 python -m build -python -m twine upload dist/* --verbose -module unload cuda +export CUDA_HOME=$BASE_PATH/cuda-11.2 +make cuda11x + +if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi +#CUDA_VERSION=112 python -m build +#python -m twine upload dist/* --verbose rm -rf dist build make clean make cleaneggs -module load cuda/10.2 -CUDA_HOME=/public/apps/cuda/10.2/ -make cuda10x -CUDA_VERSION=102 python -m build -python -m twine upload dist/* --verbose -module unload cuda +export CUDA_HOME=$BASE_PATH/cuda-11.3 +make cuda11x +if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi +#CUDA_VERSION=113 python -m build +#python -m twine upload dist/* --verbose rm -rf dist build make clean make cleaneggs -module load cuda/11.0 -CUDA_HOME=/public/apps/cuda/11.0 -make cuda110 -CUDA_VERSION=110 python -m build -python -m twine upload dist/* --verbose -module unload cuda +export CUDA_HOME=$BASE_PATH/cuda-11.4 +make cuda11x + +if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi +#CUDA_VERSION=114 python -m build +##python -m twine upload dist/* --verbose rm -rf dist build make clean make cleaneggs -module load cuda/11.1 -CUDA_HOME=/public/apps/cuda/11.1 +export CUDA_HOME=$BASE_PATH/cuda-11.5 make cuda11x -CUDA_VERSION=111 python -m build -python -m twine upload dist/* --verbose -module unload cuda + +if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi +#CUDA_VERSION=115 python -m build +#python -m twine upload dist/* --verbose rm -rf dist build make clean make cleaneggs -module load cuda/11.2 -CUDA_HOME=/public/apps/cuda/11.2 +export CUDA_HOME=$BASE_PATH/cuda-11.6 + make cuda11x -CUDA_VERSION=112 python -m build -python -m twine upload dist/* --verbose -module unload cuda +if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi +#CUDA_VERSION=116 python -m build +#python -m twine upload dist/* --verbose rm -rf dist build make clean make cleaneggs -CUDA_HOME=/private/home/timdettmers/git/autoswap/local/cuda-11.3 make cuda11x -CUDA_VERSION=113 python -m build -python -m twine upload dist/* --verbose -module unload cuda +export CUDA_HOME=$BASE_PATH/cuda-11.7 +make cuda11x + +if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi +#CUDA_VERSION=117 python -m build +#python -m twine upload dist/* --verbose -- cgit v1.2.3 From f2dd703251aaff826a85c7f77624dfe5cbc91c6c Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Mon, 25 Jul 2022 22:34:14 -0700 Subject: Added matmul build and flags. --- Makefile | 33 ++++++++--- deploy_from_slurm.sh | 161 ++++++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 170 insertions(+), 24 deletions(-) diff --git a/Makefile b/Makefile index 328faa5..10f267a 100644 --- a/Makefile +++ b/Makefile @@ -27,14 +27,13 @@ COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta -COMPUTE_CAPABILITY := -gencode arch=compute_75,code=sm_75 # Volta # CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not CC_CUDA92 := -gencode arch=compute_30,code=sm_30 # Later versions of CUDA support the new architectures CC_CUDA10x := -gencode arch=compute_30,code=sm_30 -CC_CUDA10x := -gencode arch=compute_75,code=sm_75 +CC_CUDA10x += -gencode arch=compute_75,code=sm_75 CC_CUDA110 := -gencode arch=compute_75,code=sm_75 CC_CUDA110 += -gencode arch=compute_80,code=sm_80 @@ -43,6 +42,14 @@ CC_CUDA11x := -gencode arch=compute_75,code=sm_75 CC_CUDA11x += -gencode arch=compute_80,code=sm_80 CC_CUDA11x += -gencode arch=compute_86,code=sm_86 +CC_cublasLt110 := -gencode arch=compute_75,code=sm_75 +CC_cublasLt110 += -gencode arch=compute_80,code=sm_80 + +CC_cublasLt111 := -gencode arch=compute_75,code=sm_75 +CC_cublasLt111 += -gencode arch=compute_80,code=sm_80 +CC_cublasLt111 += -gencode arch=compute_86,code=sm_86 + + all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o @@ -53,21 +60,31 @@ cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB) -cuda10x: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) +cuda10x_nomatmul: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env + $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB) -cuda110: $(BUILD_DIR) env - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) +cuda110_nomatmul: $(BUILD_DIR) env + $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB) -cuda11x: $(BUILD_DIR) env - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) +cuda11x_nomatmul: $(BUILD_DIR) env + $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB) +cuda110: $(BUILD_DIR) env + $(NVCC) $(CC_cublasLt110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) + $(NVCC) $(CC_cublasLt110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB) + +cuda11x: $(BUILD_DIR) env + $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) + $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB) + cpuonly: $(BUILD_DIR) env $(GPP) -std=c++14 -shared -fPIC -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so diff --git a/deploy_from_slurm.sh b/deploy_from_slurm.sh index 5a554bb..93233a4 100644 --- a/deploy_from_slurm.sh +++ b/deploy_from_slurm.sh @@ -4,88 +4,217 @@ BASE_PATH=$1 module unload cuda module unload gcc +#rm -rf dist build +#make clean +#make cleaneggs +#export CUDA_HOME=$BASE_PATH/cuda-11.0 +#make cuda110 +# +#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then +# # Control will enter here if $DIRECTORY doesn't exist. +# echo "Compilation unsuccessul!" 1>&2 +# exit 64 +#fi +##CUDA_VERSION=110 python -m build +##python -m twine upload dist/* --verbose +# +#rm -rf dist build +#make clean +#make cleaneggs +#export CUDA_HOME=$BASE_PATH/cuda-11.1 +#make cuda11x +# +#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then +# # Control will enter here if $DIRECTORY doesn't exist. +# echo "Compilation unsuccessul!" 1>&2 +# exit 64 +#fi +##CUDA_VERSION=111 python -m build +##python -m twine upload dist/* --verbose +# +#rm -rf dist build +#make clean +#make cleaneggs +#export CUDA_HOME=$BASE_PATH/cuda-11.2 +#make cuda11x +# +#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then +# # Control will enter here if $DIRECTORY doesn't exist. +# echo "Compilation unsuccessul!" 1>&2 +# exit 64 +#fi +##CUDA_VERSION=112 python -m build +##python -m twine upload dist/* --verbose +# +#rm -rf dist build +#make clean +#make cleaneggs +#export CUDA_HOME=$BASE_PATH/cuda-11.3 +#make cuda11x +# +#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then +# # Control will enter here if $DIRECTORY doesn't exist. +# echo "Compilation unsuccessul!" 1>&2 +# exit 64 +#fi +##CUDA_VERSION=113 python -m build +##python -m twine upload dist/* --verbose +# +#rm -rf dist build +#make clean +#make cleaneggs +#export CUDA_HOME=$BASE_PATH/cuda-11.4 +#make cuda11x +# +#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then +# # Control will enter here if $DIRECTORY doesn't exist. +# echo "Compilation unsuccessul!" 1>&2 +# exit 64 +#fi +##CUDA_VERSION=114 python -m build +###python -m twine upload dist/* --verbose +# +#rm -rf dist build +#make clean +#make cleaneggs +#export CUDA_HOME=$BASE_PATH/cuda-11.5 +#make cuda11x +# +#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then +# # Control will enter here if $DIRECTORY doesn't exist. +# echo "Compilation unsuccessul!" 1>&2 +# exit 64 +#fi +##CUDA_VERSION=115 python -m build +##python -m twine upload dist/* --verbose +# +#rm -rf dist build +#make clean +#make cleaneggs +#export CUDA_HOME=$BASE_PATH/cuda-11.6 +# +#make cuda11x +#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then +# # Control will enter here if $DIRECTORY doesn't exist. +# echo "Compilation unsuccessul!" 1>&2 +# exit 64 +#fi +##CUDA_VERSION=116 python -m build +##python -m twine upload dist/* --verbose +# +#rm -rf dist build +#make clean +#make cleaneggs +#export CUDA_HOME=$BASE_PATH/cuda-11.7 +#make cuda11x +# +#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then +# # Control will enter here if $DIRECTORY doesn't exist. +# echo "Compilation unsuccessul!" 1>&2 +# exit 64 +#fi +##CUDA_VERSION=117 python -m build +##python -m twine upload dist/* --verbose + + +rm -rf dist build +make clean +make cleaneggs +export CUDA_HOME=$BASE_PATH/cuda-10.2 +make cuda10x_nomatmul + +if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi +#CUDA_VERSION=102-nomatmul python -m build +#python -m twine upload dist/* --verbose + + rm -rf dist build make clean make cleaneggs export CUDA_HOME=$BASE_PATH/cuda-11.0 -make cuda110 +make cuda110_nomatmul if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. echo "Compilation unsuccessul!" 1>&2 exit 64 fi -#CUDA_VERSION=110 python -m build +#CUDA_VERSION=110-nomatmul python -m build #python -m twine upload dist/* --verbose + rm -rf dist build make clean make cleaneggs export CUDA_HOME=$BASE_PATH/cuda-11.1 -make cuda11x +make cuda11x_nomatmul if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. echo "Compilation unsuccessul!" 1>&2 exit 64 fi -#CUDA_VERSION=111 python -m build +#CUDA_VERSION=111-nomatmul python -m build #python -m twine upload dist/* --verbose rm -rf dist build make clean make cleaneggs export CUDA_HOME=$BASE_PATH/cuda-11.2 -make cuda11x +make cuda11x_nomatmul if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. echo "Compilation unsuccessul!" 1>&2 exit 64 fi -#CUDA_VERSION=112 python -m build +#CUDA_VERSION=112-nomatmul python -m build #python -m twine upload dist/* --verbose rm -rf dist build make clean make cleaneggs export CUDA_HOME=$BASE_PATH/cuda-11.3 -make cuda11x +make cuda11x_nomatmul if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. echo "Compilation unsuccessul!" 1>&2 exit 64 fi -#CUDA_VERSION=113 python -m build +#CUDA_VERSION=113-nomatmul python -m build #python -m twine upload dist/* --verbose rm -rf dist build make clean make cleaneggs export CUDA_HOME=$BASE_PATH/cuda-11.4 -make cuda11x +make cuda11x_nomatmul if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. echo "Compilation unsuccessul!" 1>&2 exit 64 fi -#CUDA_VERSION=114 python -m build +#CUDA_VERSION=114-nomatmul python -m build ##python -m twine upload dist/* --verbose rm -rf dist build make clean make cleaneggs export CUDA_HOME=$BASE_PATH/cuda-11.5 -make cuda11x +make cuda11x_nomatmul if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. echo "Compilation unsuccessul!" 1>&2 exit 64 fi -#CUDA_VERSION=115 python -m build +#CUDA_VERSION=115-nomatmul python -m build #python -m twine upload dist/* --verbose rm -rf dist build @@ -93,25 +222,25 @@ make clean make cleaneggs export CUDA_HOME=$BASE_PATH/cuda-11.6 -make cuda11x +make cuda11x_nomatmul if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. echo "Compilation unsuccessul!" 1>&2 exit 64 fi -#CUDA_VERSION=116 python -m build +#CUDA_VERSION=116-nomatmul python -m build #python -m twine upload dist/* --verbose rm -rf dist build make clean make cleaneggs export CUDA_HOME=$BASE_PATH/cuda-11.7 -make cuda11x +make cuda11x_nomatmul if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then # Control will enter here if $DIRECTORY doesn't exist. echo "Compilation unsuccessul!" 1>&2 exit 64 fi -#CUDA_VERSION=117 python -m build +#CUDA_VERSION=117-nomatmul python -m build #python -m twine upload dist/* --verbose -- cgit v1.2.3 From 953b7285ddf55913732bc9f137953dd00ac64c35 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 26 Jul 2022 09:12:16 -0700 Subject: Fixed cpuonly build. --- csrc/pythonInterface.c | 4 +- deploy_from_slurm.sh | 272 ++++++++++++++++++++++++++----------------------- setup.py | 2 +- 3 files changed, 146 insertions(+), 132 deletions(-) diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 9b57549..a6a4b13 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -82,7 +82,6 @@ void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, un void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } \ void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } -#endif #define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \ @@ -129,10 +128,11 @@ void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_r void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) { spmm_coo_very_sparse_naive(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } +#endif extern "C" { - #if BUILD_CUDA +#if BUILD_CUDA void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); } void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); } void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); } diff --git a/deploy_from_slurm.sh b/deploy_from_slurm.sh index 93233a4..d58aa38 100644 --- a/deploy_from_slurm.sh +++ b/deploy_from_slurm.sh @@ -4,117 +4,131 @@ BASE_PATH=$1 module unload cuda module unload gcc -#rm -rf dist build -#make clean -#make cleaneggs -#export CUDA_HOME=$BASE_PATH/cuda-11.0 -#make cuda110 -# -#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then -# # Control will enter here if $DIRECTORY doesn't exist. -# echo "Compilation unsuccessul!" 1>&2 -# exit 64 -#fi -##CUDA_VERSION=110 python -m build -##python -m twine upload dist/* --verbose -# -#rm -rf dist build -#make clean -#make cleaneggs -#export CUDA_HOME=$BASE_PATH/cuda-11.1 -#make cuda11x -# -#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then -# # Control will enter here if $DIRECTORY doesn't exist. -# echo "Compilation unsuccessul!" 1>&2 -# exit 64 -#fi -##CUDA_VERSION=111 python -m build -##python -m twine upload dist/* --verbose -# -#rm -rf dist build -#make clean -#make cleaneggs -#export CUDA_HOME=$BASE_PATH/cuda-11.2 -#make cuda11x -# -#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then -# # Control will enter here if $DIRECTORY doesn't exist. -# echo "Compilation unsuccessul!" 1>&2 -# exit 64 -#fi -##CUDA_VERSION=112 python -m build -##python -m twine upload dist/* --verbose -# -#rm -rf dist build -#make clean -#make cleaneggs -#export CUDA_HOME=$BASE_PATH/cuda-11.3 -#make cuda11x -# -#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then -# # Control will enter here if $DIRECTORY doesn't exist. -# echo "Compilation unsuccessul!" 1>&2 -# exit 64 -#fi -##CUDA_VERSION=113 python -m build -##python -m twine upload dist/* --verbose -# -#rm -rf dist build -#make clean -#make cleaneggs -#export CUDA_HOME=$BASE_PATH/cuda-11.4 -#make cuda11x -# -#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then -# # Control will enter here if $DIRECTORY doesn't exist. -# echo "Compilation unsuccessul!" 1>&2 -# exit 64 -#fi -##CUDA_VERSION=114 python -m build -###python -m twine upload dist/* --verbose -# -#rm -rf dist build -#make clean -#make cleaneggs -#export CUDA_HOME=$BASE_PATH/cuda-11.5 -#make cuda11x -# -#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then -# # Control will enter here if $DIRECTORY doesn't exist. -# echo "Compilation unsuccessul!" 1>&2 -# exit 64 -#fi -##CUDA_VERSION=115 python -m build -##python -m twine upload dist/* --verbose -# -#rm -rf dist build -#make clean -#make cleaneggs -#export CUDA_HOME=$BASE_PATH/cuda-11.6 -# -#make cuda11x -#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then -# # Control will enter here if $DIRECTORY doesn't exist. -# echo "Compilation unsuccessul!" 1>&2 -# exit 64 -#fi -##CUDA_VERSION=116 python -m build -##python -m twine upload dist/* --verbose -# -#rm -rf dist build -#make clean -#make cleaneggs -#export CUDA_HOME=$BASE_PATH/cuda-11.7 -#make cuda11x -# -#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then -# # Control will enter here if $DIRECTORY doesn't exist. -# echo "Compilation unsuccessul!" 1>&2 -# exit 64 -#fi -##CUDA_VERSION=117 python -m build -##python -m twine upload dist/* --verbose +rm -rf dist build +make clean +make cleaneggs +export CUDA_HOME= +make cpuonly + +if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi +CUDA_VERSION=cpu python -m build +python -m twine upload dist/* --verbose + +rm -rf dist build +make clean +make cleaneggs +export CUDA_HOME=$BASE_PATH/cuda-11.0 +make cuda110 + +if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi +CUDA_VERSION=110 python -m build +python -m twine upload dist/* --verbose + +rm -rf dist build +make clean +make cleaneggs +export CUDA_HOME=$BASE_PATH/cuda-11.1 +make cuda11x + +if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi +CUDA_VERSION=111 python -m build +python -m twine upload dist/* --verbose + +rm -rf dist build +make clean +make cleaneggs +export CUDA_HOME=$BASE_PATH/cuda-11.2 +make cuda11x + +if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi +CUDA_VERSION=112 python -m build +python -m twine upload dist/* --verbose + +rm -rf dist build +make clean +make cleaneggs +export CUDA_HOME=$BASE_PATH/cuda-11.3 +make cuda11x + +if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi +CUDA_VERSION=113 python -m build +python -m twine upload dist/* --verbose + +rm -rf dist build +make clean +make cleaneggs +export CUDA_HOME=$BASE_PATH/cuda-11.4 +make cuda11x + +if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi +CUDA_VERSION=114 python -m build +python -m twine upload dist/* --verbose + +rm -rf dist build +make clean +make cleaneggs +export CUDA_HOME=$BASE_PATH/cuda-11.5 +make cuda11x + +if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi +CUDA_VERSION=115 python -m build +python -m twine upload dist/* --verbose + +rm -rf dist build +make clean +make cleaneggs +export CUDA_HOME=$BASE_PATH/cuda-11.6 + +make cuda11x +if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi +CUDA_VERSION=116 python -m build +python -m twine upload dist/* --verbose + +rm -rf dist build +make clean +make cleaneggs +export CUDA_HOME=$BASE_PATH/cuda-11.7 +make cuda11x + +if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi +CUDA_VERSION=117 python -m build +python -m twine upload dist/* --verbose rm -rf dist build @@ -128,8 +142,8 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then echo "Compilation unsuccessul!" 1>&2 exit 64 fi -#CUDA_VERSION=102-nomatmul python -m build -#python -m twine upload dist/* --verbose +CUDA_VERSION=102-nomatmul python -m build +python -m twine upload dist/* --verbose rm -rf dist build @@ -143,8 +157,8 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then echo "Compilation unsuccessul!" 1>&2 exit 64 fi -#CUDA_VERSION=110-nomatmul python -m build -#python -m twine upload dist/* --verbose +CUDA_VERSION=110-nomatmul python -m build +python -m twine upload dist/* --verbose rm -rf dist build @@ -158,8 +172,8 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then echo "Compilation unsuccessul!" 1>&2 exit 64 fi -#CUDA_VERSION=111-nomatmul python -m build -#python -m twine upload dist/* --verbose +CUDA_VERSION=111-nomatmul python -m build +python -m twine upload dist/* --verbose rm -rf dist build make clean @@ -172,8 +186,8 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then echo "Compilation unsuccessul!" 1>&2 exit 64 fi -#CUDA_VERSION=112-nomatmul python -m build -#python -m twine upload dist/* --verbose +CUDA_VERSION=112-nomatmul python -m build +python -m twine upload dist/* --verbose rm -rf dist build make clean @@ -186,8 +200,8 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then echo "Compilation unsuccessul!" 1>&2 exit 64 fi -#CUDA_VERSION=113-nomatmul python -m build -#python -m twine upload dist/* --verbose +CUDA_VERSION=113-nomatmul python -m build +python -m twine upload dist/* --verbose rm -rf dist build make clean @@ -200,8 +214,8 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then echo "Compilation unsuccessul!" 1>&2 exit 64 fi -#CUDA_VERSION=114-nomatmul python -m build -##python -m twine upload dist/* --verbose +CUDA_VERSION=114-nomatmul python -m build +python -m twine upload dist/* --verbose rm -rf dist build make clean @@ -214,8 +228,8 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then echo "Compilation unsuccessul!" 1>&2 exit 64 fi -#CUDA_VERSION=115-nomatmul python -m build -#python -m twine upload dist/* --verbose +CUDA_VERSION=115-nomatmul python -m build +python -m twine upload dist/* --verbose rm -rf dist build make clean @@ -228,8 +242,8 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then echo "Compilation unsuccessul!" 1>&2 exit 64 fi -#CUDA_VERSION=116-nomatmul python -m build -#python -m twine upload dist/* --verbose +CUDA_VERSION=116-nomatmul python -m build +python -m twine upload dist/* --verbose rm -rf dist build make clean @@ -242,5 +256,5 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then echo "Compilation unsuccessul!" 1>&2 exit 64 fi -#CUDA_VERSION=117-nomatmul python -m build -#python -m twine upload dist/* --verbose +CUDA_VERSION=117-nomatmul python -m build +python -m twine upload dist/* --verbose diff --git a/setup.py b/setup.py index 2402c02..6cc091b 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ version = os.getenv("CUDA_VERSION", "cpu") setup( name="bitsandbytes", - version=f"0.26.0+{version}", + version=f"0.30.0", author="Tim Dettmers", author_email="dettmers@cs.washington.edu", description="8-bit optimizers and quantization routines.", -- cgit v1.2.3 From dc8c9efdb33130f960adc864916b67d0cb744dbb Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 26 Jul 2022 10:32:22 -0700 Subject: Changed setup.py; deployed on test pypi. --- deploy_from_slurm.sh | 87 ++++++++++++++++++++++++++-------------------------- setup.py | 5 +-- 2 files changed, 47 insertions(+), 45 deletions(-) diff --git a/deploy_from_slurm.sh b/deploy_from_slurm.sh index d58aa38..664d40e 100644 --- a/deploy_from_slurm.sh +++ b/deploy_from_slurm.sh @@ -4,19 +4,19 @@ BASE_PATH=$1 module unload cuda module unload gcc -rm -rf dist build -make clean -make cleaneggs -export CUDA_HOME= -make cpuonly - -if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then - # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 - exit 64 -fi -CUDA_VERSION=cpu python -m build -python -m twine upload dist/* --verbose +#rm -rf dist build +#make clean +#make cleaneggs +#export CUDA_HOME= +#make cpuonly +# +#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then +# # Control will enter here if $DIRECTORY doesn't exist. +# echo "Compilation unsuccessul!" 1>&2 +# exit 64 +#fi +#CUDA_VERSION=cpu python -m build +#python -m twine upload dist/* --verbose --repository testpypi rm -rf dist build make clean @@ -30,7 +30,7 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then exit 64 fi CUDA_VERSION=110 python -m build -python -m twine upload dist/* --verbose +python -m twine upload dist/* --verbose --repository testpypi rm -rf dist build make clean @@ -44,7 +44,7 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then exit 64 fi CUDA_VERSION=111 python -m build -python -m twine upload dist/* --verbose +python -m twine upload dist/* --verbose --repository testpypi rm -rf dist build make clean @@ -58,7 +58,7 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then exit 64 fi CUDA_VERSION=112 python -m build -python -m twine upload dist/* --verbose +python -m twine upload dist/* --verbose --repository testpypi rm -rf dist build make clean @@ -72,7 +72,7 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then exit 64 fi CUDA_VERSION=113 python -m build -python -m twine upload dist/* --verbose +python -m twine upload dist/* --verbose --repository testpypi rm -rf dist build make clean @@ -86,7 +86,7 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then exit 64 fi CUDA_VERSION=114 python -m build -python -m twine upload dist/* --verbose +python -m twine upload dist/* --verbose --repository testpypi rm -rf dist build make clean @@ -100,22 +100,22 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then exit 64 fi CUDA_VERSION=115 python -m build -python -m twine upload dist/* --verbose - -rm -rf dist build -make clean -make cleaneggs -export CUDA_HOME=$BASE_PATH/cuda-11.6 - -make cuda11x -if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then - # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 - exit 64 -fi -CUDA_VERSION=116 python -m build -python -m twine upload dist/* --verbose - +python -m twine upload dist/* --verbose --repository testpypi + +#rm -rf dist build +#make clean +#make cleaneggs +#export CUDA_HOME=$BASE_PATH/cuda-11.6 +# +#make cuda11x +#if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then +# # Control will enter here if $DIRECTORY doesn't exist. +# echo "Compilation unsuccessul!" 1>&2 +# exit 64 +#fi +#CUDA_VERSION=116 python -m build +#python -m twine upload dist/* --verbose --repository testpypi +# rm -rf dist build make clean make cleaneggs @@ -128,7 +128,7 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then exit 64 fi CUDA_VERSION=117 python -m build -python -m twine upload dist/* --verbose +python -m twine upload dist/* --verbose --repository testpypi rm -rf dist build @@ -143,7 +143,7 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then exit 64 fi CUDA_VERSION=102-nomatmul python -m build -python -m twine upload dist/* --verbose +python -m twine upload dist/* --verbose --repository testpypi rm -rf dist build @@ -158,7 +158,7 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then exit 64 fi CUDA_VERSION=110-nomatmul python -m build -python -m twine upload dist/* --verbose +python -m twine upload dist/* --verbose --repository testpypi rm -rf dist build @@ -173,7 +173,7 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then exit 64 fi CUDA_VERSION=111-nomatmul python -m build -python -m twine upload dist/* --verbose +python -m twine upload dist/* --verbose --repository testpypi rm -rf dist build make clean @@ -187,7 +187,7 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then exit 64 fi CUDA_VERSION=112-nomatmul python -m build -python -m twine upload dist/* --verbose +python -m twine upload dist/* --verbose --repository testpypi rm -rf dist build make clean @@ -201,7 +201,7 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then exit 64 fi CUDA_VERSION=113-nomatmul python -m build -python -m twine upload dist/* --verbose +python -m twine upload dist/* --verbose --repository testpypi rm -rf dist build make clean @@ -215,7 +215,7 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then exit 64 fi CUDA_VERSION=114-nomatmul python -m build -python -m twine upload dist/* --verbose +python -m twine upload dist/* --verbose --repository testpypi rm -rf dist build make clean @@ -229,7 +229,7 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then exit 64 fi CUDA_VERSION=115-nomatmul python -m build -python -m twine upload dist/* --verbose +python -m twine upload dist/* --verbose --repository testpypi rm -rf dist build make clean @@ -243,7 +243,7 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then exit 64 fi CUDA_VERSION=116-nomatmul python -m build -python -m twine upload dist/* --verbose +python -m twine upload dist/* --verbose --repository testpypi rm -rf dist build make clean @@ -258,3 +258,4 @@ if [ ! -f "./bitsandbytes/libbitsandbytes.so" ]; then fi CUDA_VERSION=117-nomatmul python -m build python -m twine upload dist/* --verbose +python -m twine upload dist/* --verbose --repository testpypi diff --git a/setup.py b/setup.py index 6cc091b..6275ddd 100644 --- a/setup.py +++ b/setup.py @@ -11,13 +11,14 @@ def read(fname): version = os.getenv("CUDA_VERSION", "cpu") +prefix = '' if version == 'cpu' else 'cuda' setup( - name="bitsandbytes", + name=f"bitsandbytes-{prefix}{version}", version=f"0.30.0", author="Tim Dettmers", author_email="dettmers@cs.washington.edu", - description="8-bit optimizers and quantization routines.", + description="8-bit optimizers and matrix multiplication routines.", license="MIT", keywords="gpu optimizers optimization 8-bit quantization compression", url="http://packages.python.org/bitsandbytes", -- cgit v1.2.3