summaryrefslogtreecommitdiff
path: root/bitsandbytes/cuda_setup
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-08-04 09:16:00 -0700
committerTim Dettmers <tim.dettmers@gmail.com>2022-08-04 09:16:00 -0700
commit8f84674d6774c351b1e69dfede2c11a370e334b9 (patch)
tree2312dde5180f471cd94b9ade5941c0e374386615 /bitsandbytes/cuda_setup
parent758c7175a24df307c40b743b1def8b4c34f68674 (diff)
Fixed bugs in cuda setup.
Diffstat (limited to 'bitsandbytes/cuda_setup')
-rw-r--r--bitsandbytes/cuda_setup/__init__.py2
-rw-r--r--bitsandbytes/cuda_setup/main.py14
2 files changed, 11 insertions, 5 deletions
diff --git a/bitsandbytes/cuda_setup/__init__.py b/bitsandbytes/cuda_setup/__init__.py
index e69de29..d8ebba8 100644
--- a/bitsandbytes/cuda_setup/__init__.py
+++ b/bitsandbytes/cuda_setup/__init__.py
@@ -0,0 +1,2 @@
+from .paths import CUDA_RUNTIME_LIB, extract_candidate_paths, determine_cuda_runtime_lib_path
+from .main import evaluate_cuda_setup
diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py
index e96ac70..1e52f89 100644
--- a/bitsandbytes/cuda_setup/main.py
+++ b/bitsandbytes/cuda_setup/main.py
@@ -47,6 +47,7 @@ def get_compute_capabilities():
cuda = ctypes.CDLL("libcuda.so")
except OSError:
# TODO: shouldn't we error or at least warn here?
+ print('ERROR: libcuda.so not found!')
return None
nGpus = ctypes.c_int()
@@ -70,7 +71,7 @@ def get_compute_capabilities():
)
ccs.append(f"{cc_major.value}.{cc_minor.value}")
- return ccs.sort()
+ return ccs
# def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error
@@ -80,7 +81,8 @@ def get_compute_capability():
capabilities are downwards compatible. If no GPUs are detected, it returns
None.
"""
- if ccs := get_compute_capabilities() is not None:
+ ccs = get_compute_capabilities()
+ if ccs is not None:
# TODO: handle different compute capabilities; for now, take the max
return ccs[-1]
return None
@@ -92,8 +94,7 @@ def evaluate_cuda_setup():
cc = get_compute_capability()
binary_name = "libbitsandbytes_cpu.so"
- # FIXME: has_gpu is still unused
- if not (has_gpu := bool(cc)):
+ if cc == '':
print(
"WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..."
)
@@ -115,6 +116,7 @@ def evaluate_cuda_setup():
ls_output.split(" ")[-1].replace("libcudart.so.", "").split(".")
)
cuda_version_string = f"{major}{minor}"
+ print(f'CUDA_SETUP: Detected CUDA version {cuda_version_string}')
def get_binary_name():
"if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
@@ -122,6 +124,8 @@ def evaluate_cuda_setup():
if has_cublaslt:
return f"{bin_base_name}{cuda_version_string}.so"
else:
- return f"{bin_base_name}_nocublaslt.so"
+ return f"{bin_base_name}{cuda_version_string}_nocublaslt.so"
+
+ binary_name = get_binary_name()
return binary_name