From 96bc209baf55f2e05e649e555c2de5fc478c24dc Mon Sep 17 00:00:00 2001 From: Titus von Koeller Date: Tue, 2 Aug 2022 21:27:36 -0700 Subject: tentative refactoring of the compute capabilities code --- bitsandbytes/cuda_setup/compute_capability.py | 56 +++++++++++++++++---------- 1 file changed, 35 insertions(+), 21 deletions(-) (limited to 'bitsandbytes') diff --git a/bitsandbytes/cuda_setup/compute_capability.py b/bitsandbytes/cuda_setup/compute_capability.py index 19ceb3b..7a3f463 100644 --- a/bitsandbytes/cuda_setup/compute_capability.py +++ b/bitsandbytes/cuda_setup/compute_capability.py @@ -2,27 +2,28 @@ import ctypes from dataclasses import dataclass, field -CUDA_SUCCESS = 0 - @dataclass class CudaLibVals: - # code bits taken from + # code bits taken from # https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549 - nGpus = ctypes.c_int() - cc_major = ctypes.c_int() - cc_minor = ctypes.c_int() - device = ctypes.c_int() - error_str = ctypes.c_char_p() + nGpus: ctypes.c_int = field(default=ctypes.c_int()) + cc_major: ctypes.c_int = field(default=ctypes.c_int()) + cc_minor: ctypes.c_int = field(default=ctypes.c_int()) + device: ctypes.c_int = field(default=ctypes.c_int()) + error_str: ctypes.c_char_p = field(default=ctypes.c_char_p()) cuda: ctypes.CDLL = field(init=False, repr=False) ccs: List[str, ...] = field(init=False) - def load_cuda_lib(self): + def _initialize_driver_API(self): + self.check_cuda_result(self.cuda.cuInit(0)) + + def _load_cuda_lib(self): """ 1. find libcuda.so library (GPU driver) (/usr/lib) init_device -> init variables -> call function by reference """ - libnames = ("libcuda.so") + libnames = "libcuda.so" for libname in libnames: try: self.cuda = ctypes.CDLL(libname) @@ -33,32 +34,45 @@ class CudaLibVals: else: raise OSError("could not load any of: " + " ".join(libnames)) - def check_cuda_result(self, result_val): + def call_cuda_func(self, function_obj, **kwargs): + CUDA_SUCCESS = 0 # constant taken from cuda.h + pass + # if (CUDA_SUCCESS := function_obj( + + def _error_handle(cuda_lib_call_return_value): """ - 2. call extern C function to determine CC + 2. call extern C function to determine CC (see https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html) """ - cls_fields: Tuple[Field, ...] = fields(self.__class__) + CUDA_SUCCESS = 0 # constant taken from cuda.h - if result_val != 0: - self.cuda.cuGetErrorString(result_val, ctypes.byref(self.error_str)) + if cuda_lib_call_return_value != CUDA_SUCCESS: + self.cuda.cuGetErrorString( + cuda_lib_call_return_value, + ctypes.byref(self.error_str), + ) print("Count not initialize CUDA - failure!") raise Exception("CUDA exception!") - return result_val + return cuda_lib_call_return_value def __post_init__(self): - self.load_cuda_lib() - self.check_cuda_result(self.cuda.cuInit(0)) - self.check_cuda_result(self.cuda, self.cuda.cuDeviceGetCount(ctypes.byref(self.nGpus))) + self._load_cuda_lib() + self._initialize_driver_API() + self.check_cuda_result( + self.cuda, self.cuda.cuDeviceGetCount(ctypes.byref(self.nGpus)) + ) tmp_ccs = [] for gpu_index in range(self.nGpus.value): check_cuda_result( - self.cuda, self.cuda.cuDeviceGet(ctypes.byref(self.device), gpu_index) + self.cuda, + self.cuda.cuDeviceGet(ctypes.byref(self.device), gpu_index), ) check_cuda_result( self.cuda, self.cuda.cuDeviceComputeCapability( - ctypes.byref(self.cc_major), ctypes.byref(self.cc_minor), self.device + ctypes.byref(self.cc_major), + ctypes.byref(self.cc_minor), + self.device, ), ) tmp_ccs.append(f"{self.cc_major.value}.{self.cc_minor.value}") -- cgit v1.2.3