summaryrefslogtreecommitdiff
path: root/tests/test_cuda_setup_evaluator.py
blob: 5da190d3a7ecd16ebc55c76e28394db266f7c35d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import pytest
import bitsandbytes as bnb

from typing import List, NamedTuple

from bitsandbytes.cuda_setup import (
    CUDA_RUNTIME_LIB,
    evaluate_cuda_setup,
    get_cuda_runtime_lib_path,
    tokenize_paths,
)


class InputAndExpectedOutput(NamedTuple):
    """One test case: an input string paired with the expected result."""

    # Colon-separated path string handed to the function under test.
    input: str
    # Value the test expects to get back for that input.
    output: str


# Happy-path cases: each input lists exactly one entry naming CUDA_RUNTIME_LIB.
# Entries are built as InputAndExpectedOutput so they match the annotation;
# being NamedTuples they still unpack like plain 2-tuples for parametrization.
HAPPY_PATH__LD_LIB_TEST_PATHS: List[InputAndExpectedOutput] = [
    # Runtime-lib entry last.
    InputAndExpectedOutput(
        input=f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
        output=f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    # Leading separator (empty first entry).
    InputAndExpectedOutput(
        input=f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
        output=f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    # Trailing separator (empty last entry).
    InputAndExpectedOutput(
        input=f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:",
        output=f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    # Doubled separator (empty entry in the middle).
    InputAndExpectedOutput(
        input=f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}",
        output=f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    # Runtime-lib entry first.
    InputAndExpectedOutput(
        input=f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir",
        output=f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
    # Runtime-lib entry next to an entry naming a different library.
    InputAndExpectedOutput(
        input=f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so",
        output=f"dir/with/{CUDA_RUNTIME_LIB}",
    ),
]


@pytest.fixture(params=HAPPY_PATH__LD_LIB_TEST_PATHS)
def happy_path_path_string(tmpdir, request):
    """Create the directory layout of one happy-path case below ``tmpdir``.

    ``request.param`` is one ``(input, expected)`` entry of
    ``HAPPY_PATH__LD_LIB_TEST_PATHS``; only the input string is used. Every
    entry of that string becomes a directory under ``tmpdir``, and an empty
    ``CUDA_RUNTIME_LIB`` file is created in each directory whose entry
    mentions ``CUDA_RUNTIME_LIB``.

    The fixture provides no value (``None``).
    """
    # Each param is a pair, so pick out the path string before tokenizing.
    path_string = request.param[0]
    for path in tokenize_paths(path_string):
        # Anchor the entry under the per-test temporary directory.
        test_dir = tmpdir.join(str(path))
        test_dir.ensure(dir=True)  # also creates intermediate directories
        # Tokens are used as path objects elsewhere in this module, so
        # convert to text before doing a substring check.
        if CUDA_RUNTIME_LIB in str(path):
            test_dir.join(CUDA_RUNTIME_LIB).ensure()


@pytest.mark.parametrize("test_input, expected", HAPPY_PATH__LD_LIB_TEST_PATHS)
def test_get_cuda_runtime_lib_path__happy_path(
    tmp_path, test_input: str, expected: str
):
    """Check that the single runtime-lib entry of a path string is returned.

    All entries are rebased under ``tmp_path`` so the test never writes into
    the current working directory and can be re-run safely.
    """
    # Rebase non-empty entries only; empty entries (leading, trailing or
    # doubled separators) are exactly the edge cases being exercised.
    rebased_input = ":".join(
        str(tmp_path / entry) if entry else entry
        for entry in test_input.split(":")
    )
    rebased_expected = str(tmp_path / expected)

    for path in tokenize_paths(rebased_input):
        # Entries are nested, so intermediate directories must be created too.
        path.mkdir(parents=True, exist_ok=True)
        # Only the entry naming the runtime lib gets the file; placing it in
        # every directory would build the "duplicate" scenario that the
        # unhappy-path test expects to raise.
        if CUDA_RUNTIME_LIB in str(path):
            (path / CUDA_RUNTIME_LIB).touch()

    # NOTE(review): the expected value is the entry itself, not
    # entry / CUDA_RUNTIME_LIB -- confirm against the return value of
    # get_cuda_runtime_lib_path.
    # str() makes the comparison independent of whether a str or a path
    # object is returned.
    assert str(get_cuda_runtime_lib_path(rebased_input)) == rebased_expected


# Unhappy-path inputs: every entry of each path string names CUDA_RUNTIME_LIB,
# first with two such entries, then with three.
UNHAPPY_PATH__LD_LIB_TEST_PATHS = [
    ":".join(f"{lib_dir}/{CUDA_RUNTIME_LIB}" for lib_dir in lib_dirs)
    for lib_dirs in (
        ("a/b/c", "d/e/f"),
        ("a/b/c", "d/e/f", "g/h/j"),
    )
]


@pytest.mark.parametrize("test_input", UNHAPPY_PATH__LD_LIB_TEST_PATHS)
def test_get_cuda_runtime_lib_path__unhappy_path(tmp_path, test_input: str):
    """Check that several runtime-lib candidates raise ``FileNotFoundError``.

    Each entry of ``test_input`` is created as a directory under ``tmp_path``
    holding an empty ``CUDA_RUNTIME_LIB`` file; the error message is expected
    to mention both "duplicate" and the library name.
    """
    # Rebase each entry separately: joining tmp_path with the whole
    # colon-separated string would only prefix the first entry.
    lib_dirs = [tmp_path / entry for entry in test_input.split(":")]
    for lib_dir in lib_dirs:
        lib_dir.mkdir(parents=True, exist_ok=True)
        (lib_dir / CUDA_RUNTIME_LIB).touch()

    # Pass a colon-separated string, like the other tests in this module.
    test_input = ":".join(str(lib_dir) for lib_dir in lib_dirs)
    with pytest.raises(FileNotFoundError) as err_info:
        get_cuda_runtime_lib_path(test_input)

    # Substring checks must run against the message text; the ExceptionInfo
    # wrapper itself does not support ``in``.
    err_msg = str(err_info.value)
    assert all(match in err_msg for match in {"duplicate", CUDA_RUNTIME_LIB})


def test_get_cuda_runtime_lib_path__non_existent_dir(capsys, tmp_path):
    """Check that a missing directory in the path string triggers a warning.

    One existing and one missing directory are passed; the captured stderr is
    expected to contain both "WARNING" and "non-existent".
    """
    existent_dir = tmp_path / "a/b"
    # "a" does not exist yet, so the intermediate directory must be created.
    existent_dir.mkdir(parents=True)
    non_existent_dir = tmp_path / "c/d"  # deliberately never created
    test_input = ":".join([str(existent_dir), str(non_existent_dir)])

    get_cuda_runtime_lib_path(test_input)
    std_err = capsys.readouterr().err

    assert all(match in std_err for match in {"WARNING", "non-existent"})


def test_full_system():
    """Check that the selected binary name matches the installed CUDA version.

    The CUDA version is read from the ``libcudart.so`` listing under
    ``CONDA_PREFIX`` when that variable is set; otherwise it is parsed from
    the first ``LD_LIBRARY_PATH`` entry containing ``cuda-<version>``. The
    name returned by ``evaluate_cuda_setup()`` must then start with the
    version digits (e.g. ``11.0`` -> ``110``).
    """
    ## this only tests the cuda version and not compute capability
    version = None
    if 'CONDA_PREFIX' in os.environ:
        ls_output, _ = bnb.utils.execute_and_return(f'ls -l {os.environ["CONDA_PREFIX"]}/lib/libcudart.so')
        # Last field of the listing is expected to look like
        # libcudart.so.<major>.<minor>.<revision>.
        major, minor, _ = ls_output.split(' ')[-1].replace('libcudart.so.', '').split('.')
        version = float(f'{major}.{minor}')

    # Membership test: a bare string literal is always truthy and would lead
    # to a KeyError when the variable is unset.
    if version is None and 'LD_LIBRARY_PATH' in os.environ:
        for p in os.environ["LD_LIBRARY_PATH"].split(":"):
            idx = p.rfind("cuda-")
            if idx == -1:
                # e.g. ".../cuda/lib64": no version suffix to parse here.
                continue
            # Take up to four characters after "cuda-" (such as "11.0").
            version = float(p[idx + 5 : idx + 5 + 4].replace("/", ""))
            break

    assert version is not None and version > 0, (
        "could not determine the CUDA version from CONDA_PREFIX or LD_LIBRARY_PATH"
    )
    binary_name = evaluate_cuda_setup()
    binary_name = binary_name.replace("libbitsandbytes_cuda", "")
    assert binary_name.startswith(str(version).replace(".", ""))