From 8b1fd32e3e4f5073fd055cb5f9261ec585f8cc2c Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Mon, 25 Jul 2022 14:02:14 -0700 Subject: Fixed makefile; fixed Ampere igemmlt_8 bug. --- CHANGELOG.md | 11 ++++++ Makefile | 2 +- csrc/pythonInterface.c | 2 +- cuda_install_111.sh | 38 ++++++++++++++++++++ quicktest.py | 90 ++++++++++++++++++++++++++++++++++++++++++++++++ tests/test_functional.py | 1 + 6 files changed, 142 insertions(+), 2 deletions(-) create mode 100644 cuda_install_111.sh create mode 100644 quicktest.py diff --git a/CHANGELOG.md b/CHANGELOG.md index fa20b15..08adfce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,3 +53,14 @@ Bug fixes: Docs: - Added instructions how to solve "\_\_fatbinwrap_" errors. + + +### 0.30.0 + +#### 8-bit Inference Update + +Features: + - Added 8-bit matrix multiplication from cuBLAS and cuBLASLt, as well as multiple GEMM kernels (GEMM, GEMMEx, GEMMLt) + - Added 8-bit Linear layers with 8-bit Params that perform memory efficient inference with an option for 8-bit mixed precision matrix decomposition for inference without performance degradation + - Added quantization methods for "fake" quantization as well as optimized kernels for vector-wise quantization and equalization as well as optimized cuBLASLt transformations + - CPU only build now available (Thank you, @mryab) diff --git a/Makefile b/Makefile index b58e233..728c8e1 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include -LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcuda -lcublas -lcurand -lcusparse -L $(CONDA_PREFIX)/lib +LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcurand -lcusparse -L $(CONDA_PREFIX)/lib # NVIDIA NVCC compilation flags COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # 
Kepler diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 03c8d92..9b57549 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -228,7 +228,7 @@ extern "C" { return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } int cigemmlt_ampere_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) - { return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + { return igemmlt_ampere_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } #define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \ diff --git a/cuda_install_111.sh b/cuda_install_111.sh new file mode 100644 index 0000000..476ab59 --- /dev/null +++ b/cuda_install_111.sh @@ -0,0 +1,38 @@ +FILE115=:cuda_11.5.1_495.29.05_linux.run +FILE111=:cuda_11.1.1_455.32.00_linux.run +URL115=:https://developer.download.nvidia.com/compute/cuda/11.5.1/local_installers/cuda_11.5.1_495.29.05_linux.run +URL111=:https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run + + +CUDA_VERSION=$1 + +if [[ -n "$CUDA_VERSION" ]]; then + if [[ "$CUDA_VERSION" -eq "111" ]]; then + FILE=cuda_11.1.1_455.32.00_linux.run + URL=https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run + FOLDER=cuda-11.1 + elif [[ "$CUDA_VERSION" -eq "115" ]]; then + FILE=cuda_11.5.1_495.29.05_linux.run + URL=https://developer.download.nvidia.com/compute/cuda/11.5.1/local_installers/cuda_11.5.1_495.29.05_linux.run + FOLDER=cuda-11.5 + else + echo "argument error: No cuda version passed as input. 
def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
    """Check ``F.igemmlt`` against a float ``torch.matmul`` reference on CUDA.

    Runs 25 rounds. Each round draws random int8 operands, computes the
    reference ``A @ B.T`` in float32, then runs the same product through
    ``F.transform`` -> ``F.igemmlt`` -> ``F.transform`` and compares the two
    with ``torch.allclose``.

    Args:
        dim1: Leading dimension of ``A``.
        dim2: Middle dimension of ``A``; only used when ``dims == 3``.
        dim3: Shared inner dimension of ``A`` and ``B``.
        dim4: Leading dimension of ``B``.
        dims: Rank of ``A``; must be 2 or 3.
        ldb: Accepted so the driver can sweep over it, but not forwarded to
            any call in this function.
            NOTE(review): presumably meant to reach ``F.igemmlt`` or
            ``F.transform`` -- confirm against the bitsandbytes API.

    Returns:
        bool: ``True`` if every round matched the reference, else ``False``.
        (The original returned ``None`` unconditionally, so the driver's
        ``if val:`` check could never succeed.)

    Raises:
        ValueError: If ``dims`` is neither 2 nor 3. Previously this surfaced
            later as an ``UnboundLocalError`` on ``A``.
    """
    if dims not in (2, 3):
        raise ValueError(f"dims must be 2 or 3, got {dims!r}")

    rounds = 25
    a_shape = (dim1, dim3) if dims == 2 else (dim1, dim2, dim3)
    all_match = True

    for _ in range(rounds):
        A = torch.randint(-128, 127, size=a_shape, device='cuda').to(torch.int8)
        B = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
        C1 = torch.matmul(A.float(), B.t().float())

        A2, SA = F.transform(A, 'col32')
        B2, SB = F.transform(B, 'colx')
        # The output buffer keeps A's leading dims and takes B's row count as
        # its last dim, i.e. the shape of A @ B.T for both the 2-D and 3-D case.
        out = torch.zeros(*A.shape[:-1], B.shape[0], dtype=torch.int32, device='cuda')
        C2, SC = F.transform(out, 'col32')
        F.igemmlt(A2, B2, C2, SA, SB, SC)
        C3, _ = F.transform(C2, 'row', state=SC)

        if torch.allclose(C1, C3.float()):
            # Same diagnostics as the original: dump the tensors on a match.
            print(C1)
            print(C2)
            print(C3)
        else:
            all_match = False

    return all_match


if __name__ == "__main__":
    # Sweep ldb in steps of 32 and report, per value, whether the int8 GEMM
    # agreed with the float reference. Guarded so that importing this module
    # does not launch the CUDA sweep.
    for ldb in range(32, 4096, 32):
        val = test_igemmlt(2, 2, 2, 2, 2, ldb)
        if val:
            print(val, ldb)
        else:
            print('nope', ldb)
index 6cbe58f..dcb0255 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1183,6 +1183,7 @@ def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut): def test_overflow(): formatB = F.get_special_format_str() + print(formatB) for i in range(2): a = torch.arange(5, 15).cuda().to(torch.int8).view(-1,1 ) b = torch.arange(5, 15).cuda().to(torch.int8).view(-1,1 ) -- cgit v1.2.3