summaryrefslogtreecommitdiff
path: root/bitsandbytes/cuda_setup.py
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-08-01 19:22:41 -0700
committerTim Dettmers <tim.dettmers@gmail.com>2022-08-01 19:22:41 -0700
commit8bf3e9faab6dfb04d676a5ea413530cdee09744c (patch)
treea465e10c7c040f2bdc56dfe5afd681538ef8b2af /bitsandbytes/cuda_setup.py
parentc4fe6c69a33bacd292dfab87dc91dad32ba5fcf4 (diff)
Added full env variable search; CONDA_PREFIX priority.
Diffstat (limited to 'bitsandbytes/cuda_setup.py')
-rw-r--r--bitsandbytes/cuda_setup.py54
1 files changed, 28 insertions, 26 deletions
diff --git a/bitsandbytes/cuda_setup.py b/bitsandbytes/cuda_setup.py
index 05d4c7f..59e90e4 100644
--- a/bitsandbytes/cuda_setup.py
+++ b/bitsandbytes/cuda_setup.py
@@ -19,11 +19,11 @@ evaluation:
"""
import ctypes
-from os import environ as env
+import os
from pathlib import Path
from typing import Set, Union
-from .utils import print_err, warn_of_missing_prerequisite
+from .utils import print_err, warn_of_missing_prerequisite, execute_and_return
def check_cuda_result(cuda, result_val):
@@ -88,22 +88,11 @@ def tokenize_paths(paths: str) -> Set[Path]:
return {Path(ld_path) for ld_path in paths.split(":") if ld_path}
-def get_cuda_runtime_lib_path(
- # TODO: replace this with logic for all paths in env vars
- LD_LIBRARY_PATH: Union[str, None] = env.get("LD_LIBRARY_PATH")
-) -> Union[Path, None]:
- """# TODO: add doc-string"""
-
- if not LD_LIBRARY_PATH:
- warn_of_missing_prerequisite(
- "LD_LIBRARY_PATH is completely missing from environment!"
- )
- return None
-
- ld_library_paths: Set[Path] = tokenize_paths(LD_LIBRARY_PATH)
+def resolve_env_variable(env_var):
+ paths: Set[Path] = tokenize_paths(env_var)
non_existent_directories: Set[Path] = {
- path for path in ld_library_paths if not path.exists()
+ path for path in paths if not path.exists()
}
if non_existent_directories:
@@ -114,7 +103,7 @@ def get_cuda_runtime_lib_path(
cuda_runtime_libs: Set[Path] = {
path / CUDA_RUNTIME_LIB
- for path in ld_library_paths
+ for path in paths
if (path / CUDA_RUNTIME_LIB).is_file()
} - non_existent_directories
@@ -123,19 +112,35 @@ def get_cuda_runtime_lib_path(
f"Found duplicate {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
)
raise FileNotFoundError(err_msg)
+ elif len(cuda_runtime_libs) == 0: return None
+ else: return next(iter(cuda_runtime_libs)) # for now just return the first
+
+def get_cuda_runtime_lib_path() -> Union[Path, None]:
+ """# TODO: add doc-string"""
+
+ cuda_runtime_libs = []
+ if 'CONDA_PREFIX' in os.environ:
+ lib_conda_path = f'{os.environ["CONDA_PREFIX"]}/lib/'
+ print(lib_conda_path)
+ cuda_runtime_libs.append(resolve_env_variable(lib_conda_path))
+
+ if len(cuda_runtime_libs) == 1: return cuda_runtime_libs[0]
+
+ for var in os.environ:
+ cuda_runtime_libs.append(resolve_env_variable(var))
- elif len(cuda_runtime_libs) < 1:
+ if len(cuda_runtime_libs) < 1:
err_msg = (
f"Did not find {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
)
raise FileNotFoundError(err_msg)
- single_cuda_runtime_lib_dir = next(iter(cuda_runtime_libs))
- return single_cuda_runtime_lib_dir
+ return cuda_runtime_libs.pop()
def evaluate_cuda_setup():
cuda_path = get_cuda_runtime_lib_path()
+ print(f'CUDA SETUP: CUDA path found: {cuda_path}')
cc = get_compute_capability()
binary_name = "libbitsandbytes_cpu.so"
@@ -152,13 +157,10 @@ def evaluate_cuda_setup():
# (2) Multiple CUDA versions installed
cuda_home = str(Path(cuda_path).parent.parent)
- ls_output, err = execute_and_return(f"{cuda_home}/bin/nvcc --version")
- cuda_version = (
- ls_output.split("\n")[3].split(",")[-1].strip().lower().replace("v", "")
- )
- major, minor, revision = cuda_version.split(".")
+ ls_output, err = execute_and_return(f"ls -l {cuda_path}")
+ major, minor, revision = ls_output.split(' ')[-1].replace('libcudart.so.', '').split('.')
cuda_version_string = f"{major}{minor}"
- binary_name = f'libbitsandbytes_cuda{cuda_version_string}_{("cublaslt" if has_cublaslt else "")}.so'
+ binary_name = f'libbitsandbytes_cuda{cuda_version_string}{("" if has_cublaslt else "_nocublaslt")}.so'
return binary_name