diff options
Diffstat (limited to 'bitsandbytes')
-rw-r--r-- | bitsandbytes/__init__.py | 1 | ||||
-rw-r--r-- | bitsandbytes/cextension.py | 2 | ||||
-rw-r--r-- | bitsandbytes/cuda_setup/__init__.py | 0 | ||||
-rw-r--r-- | bitsandbytes/cuda_setup/compute_capability.py | 65 | ||||
-rw-r--r-- | bitsandbytes/cuda_setup/main.py (renamed from bitsandbytes/cuda_setup.py) | 4 |
5 files changed, 69 insertions, 3 deletions
diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 6e5b6ac..76a5b48 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -12,6 +12,7 @@ from .autograd._functions import ( ) from .cextension import COMPILED_WITH_CUDA from .nn import modules +from . import cuda_setup if COMPILED_WITH_CUDA: from .optim import adam diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index bc11474..f5b97fd 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -2,7 +2,7 @@ import ctypes as ct import os from warnings import warn -from bitsandbytes.cuda_setup import evaluate_cuda_setup +from bitsandbytes.cuda_setup.main import evaluate_cuda_setup class CUDALibrary_Singleton(object): diff --git a/bitsandbytes/cuda_setup/__init__.py b/bitsandbytes/cuda_setup/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/bitsandbytes/cuda_setup/__init__.py diff --git a/bitsandbytes/cuda_setup/compute_capability.py b/bitsandbytes/cuda_setup/compute_capability.py new file mode 100644 index 0000000..19ceb3b --- /dev/null +++ b/bitsandbytes/cuda_setup/compute_capability.py @@ -0,0 +1,65 @@ +import ctypes +from dataclasses import dataclass, field + + +CUDA_SUCCESS = 0 + +@dataclass +class CudaLibVals: + # code bits taken from + # https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549 + + nGpus = ctypes.c_int() + cc_major = ctypes.c_int() + cc_minor = ctypes.c_int() + device = ctypes.c_int() + error_str = ctypes.c_char_p() + cuda: ctypes.CDLL = field(init=False, repr=False) + ccs: List[str, ...] = field(init=False) + + def load_cuda_lib(self): + """ + 1. find libcuda.so library (GPU driver) (/usr/lib) + init_device -> init variables -> call function by reference + """ + libnames = ("libcuda.so") + for libname in libnames: + try: + self.cuda = ctypes.CDLL(libname) + except OSError: + continue + else: + break + else: + raise OSError("could not load any of: " + " ".join(libnames)) + + def check_cuda_result(self, result_val): + """ + 2. call extern C function to determine CC + (see https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html) + """ + cls_fields: Tuple[Field, ...] = fields(self.__class__) + + if result_val != 0: + self.cuda.cuGetErrorString(result_val, ctypes.byref(self.error_str)) + print("Count not initialize CUDA - failure!") + raise Exception("CUDA exception!") + return result_val + + def __post_init__(self): + self.load_cuda_lib() + self.check_cuda_result(self.cuda.cuInit(0)) + self.check_cuda_result(self.cuda, self.cuda.cuDeviceGetCount(ctypes.byref(self.nGpus))) + tmp_ccs = [] + for gpu_index in range(self.nGpus.value): + check_cuda_result( + self.cuda, self.cuda.cuDeviceGet(ctypes.byref(self.device), gpu_index) + ) + check_cuda_result( + self.cuda, + self.cuda.cuDeviceComputeCapability( + ctypes.byref(self.cc_major), ctypes.byref(self.cc_minor), self.device + ), + ) + tmp_ccs.append(f"{self.cc_major.value}.{self.cc_minor.value}") + self.ccs = sorted(tmp_ccs, reverse=True) diff --git a/bitsandbytes/cuda_setup.py b/bitsandbytes/cuda_setup/main.py index e68cd5e..6d70c92 100644 --- a/bitsandbytes/cuda_setup.py +++ b/bitsandbytes/cuda_setup/main.py @@ -1,6 +1,6 @@ """ extract factors the build is dependent on: -[X] compute capability +[X] compute capability [ ] TODO: Q - What if we have multiple GPUs of different makes? - CUDA version - Software: @@ -23,7 +23,7 @@ import os from pathlib import Path from typing import Set, Union -from .utils import print_err, warn_of_missing_prerequisite, execute_and_return +from ..utils import print_err, warn_of_missing_prerequisite, execute_and_return def check_cuda_result(cuda, result_val): |