summaryrefslogtreecommitdiff
path: root/bitsandbytes/cuda_setup/compute_capability.py
blob: 19ceb3b8abdcc26e210564927914acc22539906e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import ctypes
from dataclasses import dataclass, field
from typing import List


# Status code the CUDA driver API returns when a call succeeds.
CUDA_SUCCESS = 0


@dataclass
class CudaLibVals:
    """Query the CUDA driver for the compute capability of every visible GPU.

    Instantiating the class loads the driver library, initialises it and
    fills ``ccs`` with one ``"major.minor"`` string per device, ordered from
    the highest capability to the lowest.

    Raises (on construction):
        OSError: the driver library could not be loaded.
        Exception: a driver call returned a status other than CUDA_SUCCESS.
    """

    # code bits taken from
    # https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549

    # Output buffers handed to the driver by reference. They are created per
    # instance (default_factory) so two instances never share a buffer, and
    # they are kept out of repr/eq because they are scratch space, not data.
    nGpus: ctypes.c_int = field(
        init=False, repr=False, compare=False, default_factory=ctypes.c_int
    )
    cc_major: ctypes.c_int = field(
        init=False, repr=False, compare=False, default_factory=ctypes.c_int
    )
    cc_minor: ctypes.c_int = field(
        init=False, repr=False, compare=False, default_factory=ctypes.c_int
    )
    device: ctypes.c_int = field(
        init=False, repr=False, compare=False, default_factory=ctypes.c_int
    )
    error_str: ctypes.c_char_p = field(
        init=False, repr=False, compare=False, default_factory=ctypes.c_char_p
    )
    # Handle to the loaded driver library; set by load_cuda_lib().
    cuda: ctypes.CDLL = field(init=False, repr=False)
    # Compute capabilities as "major.minor" strings, highest first.
    ccs: List[str] = field(init=False)

    def load_cuda_lib(self):
        """Load the CUDA driver library and store the handle in ``self.cuda``.

        Raises:
            OSError: none of the candidate library names could be loaded.
        """
        # The trailing comma matters: without it this is a plain string and
        # the loop below would try to load one character at a time.
        libnames = ("libcuda.so",)
        for libname in libnames:
            try:
                self.cuda = ctypes.CDLL(libname)
            except OSError:
                continue
            else:
                break
        else:
            raise OSError("could not load any of: " + " ".join(libnames))

    def check_cuda_result(self, result_val):
        """Return ``result_val`` if it signals success, otherwise raise.

        Args:
            result_val: status code returned by a CUDA driver API call.

        Returns:
            ``result_val`` unchanged when it equals ``CUDA_SUCCESS``.

        Raises:
            Exception: for any other status; the message carries the status
                code and the driver's description of it when one is available.
        """
        if result_val == CUDA_SUCCESS:
            return result_val

        # Clear any message left over from an earlier failure: the lookup can
        # itself fail, and a stale description would be misleading.
        self.error_str.value = None
        self.cuda.cuGetErrorString(result_val, ctypes.byref(self.error_str))
        raw_message = self.error_str.value
        if raw_message:
            message = raw_message.decode(errors="replace")
        else:
            message = "unknown error"
        print("Could not initialize CUDA - failure!")
        raise Exception(f"CUDA exception! Error code {result_val}: {message}")

    def __post_init__(self):
        """Initialise the driver and collect each GPU's compute capability."""
        self.load_cuda_lib()
        self.check_cuda_result(self.cuda.cuInit(0))
        self.check_cuda_result(
            self.cuda.cuDeviceGetCount(ctypes.byref(self.nGpus))
        )

        capabilities = []
        for gpu_index in range(self.nGpus.value):
            self.check_cuda_result(
                self.cuda.cuDeviceGet(ctypes.byref(self.device), gpu_index)
            )
            self.check_cuda_result(
                self.cuda.cuDeviceComputeCapability(
                    ctypes.byref(self.cc_major),
                    ctypes.byref(self.cc_minor),
                    self.device,
                )
            )
            capabilities.append((self.cc_major.value, self.cc_minor.value))

        # Sort numerically before formatting: as strings, "10.0" would rank
        # below "8.6".
        capabilities.sort(reverse=True)
        self.ccs = [f"{major}.{minor}" for major, minor in capabilities]