summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--bitsandbytes/cextension.py27
-rw-r--r--bitsandbytes/cuda_setup.py84
-rw-r--r--install_cuda.sh5
-rw-r--r--tests/test_cuda_setup_evaluator.py20
4 files changed, 123 insertions, 13 deletions
diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index 2374c35..4bc7bf7 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -1,9 +1,34 @@
import ctypes as ct
import os
from warnings import warn
+from bitsandbytes.cuda_setup import evaluate_cuda_setup
-lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')
+class CUDALibrary_Singleton(object):
+ _instance = None
+
+ def __init__(self):
+ raise RuntimeError('Call get_instance() instead')
+
+ def initialize(self):
+ self.context = {}
+ binary_name = evaluate_cuda_setup()
+ if not os.path.exists(os.path.dirname(__file__) + f'/{binary_name}'):
+ print(f'TODO: compile library for specific version: {binary_name}')
+ print('defaulting to libbitsandbytes.so')
+ self.lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')
+ else:
+ self.lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + f'/{binary_name}')
+
+ @classmethod
+ def get_instance(cls):
+ if cls._instance is None:
+ cls._instance = cls.__new__(cls)
+ cls._instance.initialize()
+ return cls._instance
+
+
+lib = CUDALibrary_Singleton.get_instance().lib
try:
lib.cadam32bit_g32
lib.get_context.restype = ct.c_void_p
diff --git a/bitsandbytes/cuda_setup.py b/bitsandbytes/cuda_setup.py
index 48423b5..6f67275 100644
--- a/bitsandbytes/cuda_setup.py
+++ b/bitsandbytes/cuda_setup.py
@@ -23,6 +23,58 @@ from pathlib import Path
from typing import Set, Union
from .utils import warn_of_missing_prerequisite, print_err
+import ctypes
+import shlex
+import subprocess
+
+def execute_and_return(strCMD):
+ proc = subprocess.Popen(shlex.split(strCMD), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ out, err = proc.communicate()
+ out, err = out.decode("UTF-8").strip(), err.decode("UTF-8").strip()
+ return out, err
+
+def check_cuda_result(cuda, result_val):
+ if result_val != 0:
+ cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
+ print(f"Count not initialize CUDA - failure!")
+ raise Exception('CUDA excepion!')
+ return result_val
+
+# taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
+def get_compute_capability():
+ libnames = ('libcuda.so', 'libcuda.dylib', 'cuda.dll')
+ for libname in libnames:
+ try:
+ cuda = ctypes.CDLL(libname)
+ except OSError:
+ continue
+ else:
+ break
+ else:
+ raise OSError("could not load any of: " + ' '.join(libnames))
+
+
+ nGpus = ctypes.c_int()
+ cc_major = ctypes.c_int()
+ cc_minor = ctypes.c_int()
+
+ result = ctypes.c_int()
+ device = ctypes.c_int()
+ context = ctypes.c_void_p()
+ error_str = ctypes.c_char_p()
+
+ result = check_cuda_result(cuda, cuda.cuInit(0))
+
+ result = check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
+ ccs = []
+ for i in range(nGpus.value):
+ result = check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
+ result = check_cuda_result(cuda, cuda.cuDeviceComputeCapability(ctypes.byref(cc_major), ctypes.byref(cc_minor), device))
+ ccs.append(f'{cc_major.value}.{cc_minor.value}')
+
+ #TODO: handle different compute capabilities; for now, take the max
+ ccs.sort()
+ return ccs[-1]
CUDA_RUNTIME_LIB: str = "libcudart.so"
@@ -72,12 +124,30 @@ def get_cuda_runtime_lib_path(
raise FileNotFoundError(err_msg)
single_cuda_runtime_lib_dir = next(iter(cuda_runtime_libs))
- return ld_library_paths
+ return single_cuda_runtime_lib_dir
def evaluate_cuda_setup():
- # - if paths faulty, return meaningful error
- # - else:
- # - determine CUDA version
- # - determine capabilities
- # - based on that set the default path
- pass
+ cuda_path = get_cuda_runtime_lib_path()
+ cc = get_compute_capability()
+ binary_name = 'libbitsandbytes_cpu.so'
+
+ has_gpu = cc != ''
+ if not has_gpu:
+ print('WARNING: No GPU detected! Check our CUDA paths. Processding to load CPU-only library...')
+ return binary_name
+
+ has_cublaslt = cc in ['7.5', '8.0', '8.6']
+
+ # TODO:
+ # (1) Model missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
+ # (2) Multiple CUDA versions installed
+
+ cuda_home = str(Path(cuda_path).parent.parent)
+ ls_output, err = execute_and_return(f'{cuda_home}/bin/nvcc --version')
+ cuda_version = ls_output.split('\n')[3].split(',')[-1].strip().lower().replace('v', '')
+ major, minor, revision = cuda_version.split('.')
+ cuda_version_string = f'{major}{minor}'
+
+ binary_name = f'libbitsandbytes_cuda{cuda_version_string}_{("cublaslt" if has_cublaslt else "")}.so'
+
+ return binary_name
diff --git a/install_cuda.sh b/install_cuda.sh
deleted file mode 100644
index 6a4ff0c..0000000
--- a/install_cuda.sh
+++ /dev/null
@@ -1,5 +0,0 @@
-wget https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run
-bash cuda_11.1.1_455.32.00_linux.run --no-drm --no-man-page --override --installpath=~/local --librarypath=~/local/lib --toolkitpath=~/local/cuda-11.1/ --toolkit --silent
-echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/local/cuda-11.1/lib64/" >> ~/.bashrc
-echo "export PATH=$PATH:~/local/cuda-11.1/bin/" >> ~/.bashrc
-source ~/.bashrc
diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py
index 96ee6c5..72aa3c7 100644
--- a/tests/test_cuda_setup_evaluator.py
+++ b/tests/test_cuda_setup_evaluator.py
@@ -1,4 +1,5 @@
import pytest
+import os
from typing import List
@@ -16,6 +17,7 @@ HAPPY_PATH__LD_LIB_TEST_PATHS: List[tuple[str,str]] = [
(f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:", f"dir/with/{CUDA_RUNTIME_LIB}"),
(f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
(f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir", f"dir/with/{CUDA_RUNTIME_LIB}"),
+ (f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so", f"dir/with/{CUDA_RUNTIME_LIB}"),
]
@@ -64,3 +66,21 @@ def test_get_cuda_runtime_lib_path__non_existent_dir(capsys, tmp_path):
match in std_err
for match in {"WARNING", "non-existent"}
)
+
+def test_full_system():
+ ## this only tests the cuda version and not compute capability
+ ld_path = os.environ['LD_LIBRARY_PATH']
+ paths = ld_path.split(':')
+ version = ''
+ for p in paths:
+ if 'cuda' in p:
+ idx = p.rfind('cuda-')
+ version = p[idx+5:idx+5+4].replace('/', '')
+ version = float(version)
+ break
+
+ binary_name = evaluate_cuda_setup()
+ binary_name = binary_name.replace('libbitsandbytes_cuda', '')
+ assert binary_name.startswith(str(version).replace('.', ''))
+
+