diff options
Diffstat (limited to 'bitsandbytes')
-rw-r--r-- | bitsandbytes/cuda_setup.py | 54 | ||||
-rw-r--r-- | bitsandbytes/utils.py | 3 |
2 files changed, 30 insertions, 27 deletions
diff --git a/bitsandbytes/cuda_setup.py b/bitsandbytes/cuda_setup.py index 05d4c7f..59e90e4 100644 --- a/bitsandbytes/cuda_setup.py +++ b/bitsandbytes/cuda_setup.py @@ -19,11 +19,11 @@ evaluation: """ import ctypes -from os import environ as env +import os from pathlib import Path from typing import Set, Union -from .utils import print_err, warn_of_missing_prerequisite +from .utils import print_err, warn_of_missing_prerequisite, execute_and_return def check_cuda_result(cuda, result_val): @@ -88,22 +88,11 @@ def tokenize_paths(paths: str) -> Set[Path]: return {Path(ld_path) for ld_path in paths.split(":") if ld_path} -def get_cuda_runtime_lib_path( - # TODO: replace this with logic for all paths in env vars - LD_LIBRARY_PATH: Union[str, None] = env.get("LD_LIBRARY_PATH") -) -> Union[Path, None]: - """# TODO: add doc-string""" - - if not LD_LIBRARY_PATH: - warn_of_missing_prerequisite( - "LD_LIBRARY_PATH is completely missing from environment!" - ) - return None - - ld_library_paths: Set[Path] = tokenize_paths(LD_LIBRARY_PATH) +def resolve_env_variable(env_var): + paths: Set[Path] = tokenize_paths(env_var) non_existent_directories: Set[Path] = { - path for path in ld_library_paths if not path.exists() + path for path in paths if not path.exists() } if non_existent_directories: @@ -114,7 +103,7 @@ def get_cuda_runtime_lib_path( cuda_runtime_libs: Set[Path] = { path / CUDA_RUNTIME_LIB - for path in ld_library_paths + for path in paths if (path / CUDA_RUNTIME_LIB).is_file() } - non_existent_directories @@ -123,19 +112,35 @@ def get_cuda_runtime_lib_path( f"Found duplicate {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.." ) raise FileNotFoundError(err_msg) + elif len(cuda_runtime_libs) == 0: return None + else: return next(iter(cuda_runtime_libs)) # for now just return the first + +def get_cuda_runtime_lib_path() -> Union[Path, None]: + """# TODO: add doc-string""" + + cuda_runtime_libs = [] + if 'CONDA_PREFIX' in os.environ: + lib_conda_path = f'{os.environ["CONDA_PREFIX"]}/lib/' + print(lib_conda_path) + cuda_runtime_libs.append(resolve_env_variable(lib_conda_path)) + + if len(cuda_runtime_libs) == 1: return cuda_runtime_libs[0] + + for var in os.environ: + cuda_runtime_libs.append(resolve_env_variable(var)) - elif len(cuda_runtime_libs) < 1: + if len(cuda_runtime_libs) < 1: err_msg = ( f"Did not find {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.." ) raise FileNotFoundError(err_msg) - single_cuda_runtime_lib_dir = next(iter(cuda_runtime_libs)) - return single_cuda_runtime_lib_dir + return cuda_runtime_libs.pop() def evaluate_cuda_setup(): cuda_path = get_cuda_runtime_lib_path() + print(f'CUDA SETUP: CUDA path found: {cuda_path}') cc = get_compute_capability() binary_name = "libbitsandbytes_cpu.so" @@ -152,13 +157,10 @@ def evaluate_cuda_setup(): # (2) Multiple CUDA versions installed cuda_home = str(Path(cuda_path).parent.parent) - ls_output, err = execute_and_return(f"{cuda_home}/bin/nvcc --version") - cuda_version = ( - ls_output.split("\n")[3].split(",")[-1].strip().lower().replace("v", "") - ) - major, minor, revision = cuda_version.split(".") + ls_output, err = execute_and_return(f"ls -l {cuda_path}") + major, minor, revision = ls_output.split(' ')[-1].replace('libcudart.so.', '').split('.') cuda_version_string = f"{major}{minor}" - binary_name = f'libbitsandbytes_cuda{cuda_version_string}_{("cublaslt" if has_cublaslt else "")}.so' + binary_name = f'libbitsandbytes_cuda{cuda_version_string}{("" if has_cublaslt else "_nocublaslt")}.so' return binary_name diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 8a9fc0e..e1d9460 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -2,6 +2,7 @@ import sys import shlex import subprocess +from typing import Tuple def execute_and_return(command_string: str) -> Tuple[str, str]: def _decode(subprocess_err_out_tuple): @@ -19,7 +20,7 @@ def execute_and_return(command_string: str) -> Tuple[str, str]: ).communicate() ) - std_out, std_err = execute_and_return_decoded_std_streams() + std_out, std_err = execute_and_return_decoded_std_streams(command_string) return std_out, std_err |