summaryrefslogtreecommitdiff
path: root/bitsandbytes/cuda_setup/paths.py
blob: c4a74656224f4f8f077a0a9aee6e5057332a772c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from pathlib import Path
from typing import Set, Union
from warnings import warn

from ..utils import print_stderr
from .env_vars import get_potentially_lib_path_containing_env_vars


# File name of the CUDA runtime shared library that the helpers in this module
# look for inside each candidate directory.
CUDA_RUNTIME_LIB: str = "libcudart.so"


def purge_unwanted_semicolon(tentative_path: Path) -> Path:
    """
    Placeholder for cleaning up path-like values that contain a stray `;`.

    Motivating example of such a value:
    __LMOD_REF_COUNT_PATH=/sw/cuda/11.6.2/bin:2;/mmfs1/home/dettmers/git/sched/bin:1;/mmfs1/home/dettmers/data/anaconda3/bin:1;/mmfs1/home/dettmers/data/anaconda3/condabin:1;/mmfs1/home/dettmers/.local/bin:1;/mmfs1/home/dettmers/bin:1;/usr/local/bin:1;/usr/bin:1;/usr/local/sbin:1;/usr/sbin:1;/mmfs1/home/dettmers/.fzf/bin:1;/mmfs1/home/dettmers/data/local/cuda-11.4/bin:1

    NOTE: this is not implemented yet. The argument is ignored and `None` is
    returned, regardless of the `Path` return annotation.
    """
    # Sketch of the handling that was started but never finished:
    #     if ';' in str(tentative_path):
    #         path_as_str, _ = str(tentative_path).split(';')
    return None


def extract_candidate_paths(paths_list_candidate: str) -> Set[Path]:
    """
    Split a `:`-separated list of paths into a set of `Path` objects.

    Empty entries (e.g. produced by `::` or a leading/trailing `:`) are dropped.
    """
    non_empty_entries = filter(None, paths_list_candidate.split(":"))
    return set(map(Path, non_empty_entries))


def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]:
    """
    Filter out the candidate paths that do not exist on the filesystem.

    A warning listing the dropped paths is written via `print_stderr` when at
    least one candidate is dropped.

    :param candidate_paths: paths to check; the set itself is not modified.
    :return: a new set containing only the paths for which `Path.exists()`
        succeeded and returned True.
    """
    existent_directories: Set[Path] = set()
    for path in candidate_paths:
        try:
            if path.exists():
                existent_directories.add(path)
        except OSError:
            # The candidates come from arbitrary environment-variable values,
            # so `exists()` may fail outright (e.g. "File name too long").
            # Such an entry cannot be a usable directory: treat it as
            # non-existent instead of aborting the whole search.
            continue

    non_existent_directories: Set[Path] = candidate_paths - existent_directories

    if non_existent_directories:
        print_stderr(
            "WARNING: The following directories listed in your path were found to "
            f"be non-existent: {non_existent_directories}"
        )

    return existent_directories


def get_cuda_runtime_lib_paths(candidate_paths: Set[Path]) -> Set[Path]:
    """
    Collect the full paths of every `CUDA_RUNTIME_LIB` file that sits directly
    inside one of the given directories.

    Directories that do not contain such a regular file contribute nothing.
    """
    found_libs: Set[Path] = set()
    for directory in candidate_paths:
        lib_file = directory / CUDA_RUNTIME_LIB
        if lib_file.is_file():
            found_libs.add(lib_file)
    return found_libs


def resolve_paths_list(paths_list_candidate: str) -> Set[Path]:
    """
    Turn a `:`-separated path list (typically the value of an environment
    variable) into the set of listed paths that actually exist.

    This is the first step of the search for the CUDA runtime library,
    i.e. `libcudart.so`; the library lookup itself happens in the caller.
    """
    listed_paths = extract_candidate_paths(paths_list_candidate)
    existing_paths = remove_non_existent_dirs(listed_paths)
    return existing_paths


def find_cuda_lib_in(paths_list_candidate: str) -> Set[Path]:
    """
    Return the paths of all CUDA runtime libraries found in the existing
    directories named by a `:`-separated path list.
    """
    existing_dirs = resolve_paths_list(paths_list_candidate)
    return get_cuda_runtime_lib_paths(existing_dirs)


def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None:
    """
    Emit a warning when more than one CUDA runtime library was found.

    Nothing happens for an empty or single-element set.
    """
    if len(results_paths) <= 1:
        return

    warn(
        f"Found duplicate {CUDA_RUNTIME_LIB} files: {results_paths}.. "
        "We'll flip a coin and try one of these, in order to fail forward.\n"
        "Either way, this might cause trouble in the future:\n"
        "If you get `CUDA error: invalid device function` errors, the above "
        "might be the cause and the solution is to make sure only one "
        f"{CUDA_RUNTIME_LIB} in the paths that we search based on your env."
    )


def determine_cuda_runtime_lib_path() -> Union[Path, None]:
    """
        Searches for a cuda installations, in the following order of priority:
            1. active conda env
            2. LD_LIBRARY_PATH
            3. any other env vars, while ignoring those that
                - are known to be unrelated (see `bnb.cuda_setup.env_vars.to_be_ignored`)
                - don't contain the path separator `/`

        If multiple libraries are found within one of these steps, we
        optimistically try one, while giving a warning message.

        Returns the path of the chosen CUDA runtime library, or `None` if no
        library could be found in any of the candidate locations.
    """
    candidate_env_vars = get_potentially_lib_path_containing_env_vars()

    if "CONDA_PREFIX" in candidate_env_vars:
        conda_libs_path = Path(candidate_env_vars["CONDA_PREFIX"]) / "lib"

        conda_cuda_libs = find_cuda_lib_in(str(conda_libs_path))
        warn_in_case_of_duplicates(conda_cuda_libs)

        if conda_cuda_libs:
            return next(iter(conda_cuda_libs))

        warn(
            f'{candidate_env_vars["CONDA_PREFIX"]} did not contain '
            f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...'
        )

    if "LD_LIBRARY_PATH" in candidate_env_vars:
        lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"])
        # Must run before the early return below, otherwise duplicates found
        # via LD_LIBRARY_PATH would never be reported.
        warn_in_case_of_duplicates(lib_ld_cuda_libs)

        if lib_ld_cuda_libs:
            return next(iter(lib_ld_cuda_libs))

        warn(
            f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain '
            f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...'
        )

    # Step 3: scan every remaining env var; the two handled above are skipped.
    already_searched = {"CONDA_PREFIX", "LD_LIBRARY_PATH"}
    cuda_runtime_libs = set()
    # Iterate over `.items()`: looping over the mapping itself yields only the
    # keys, which cannot be unpacked into an (env_var, value) pair.
    for env_var, value in candidate_env_vars.items():
        if env_var in already_searched:
            continue
        cuda_runtime_libs.update(find_cuda_lib_in(value))

    warn_in_case_of_duplicates(cuda_runtime_libs)

    # `None` (not an empty set) signals "nothing found", matching the
    # declared return type.
    return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else None