summaryrefslogtreecommitdiff
path: root/tests/test_cuda_setup_evaluator.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_cuda_setup_evaluator.py')
-rw-r--r--tests/test_cuda_setup_evaluator.py77
1 files changed, 39 insertions, 38 deletions
diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py
index 72aa3c7..d45354f 100644
--- a/tests/test_cuda_setup_evaluator.py
+++ b/tests/test_cuda_setup_evaluator.py
@@ -1,37 +1,45 @@
-import pytest
import os
+from typing import List, NamedTuple
+
+import pytest
-from typing import List
+from bitsandbytes.cuda_setup import (CUDA_RUNTIME_LIB, evaluate_cuda_setup,
+ get_cuda_runtime_lib_path, tokenize_paths)
-from bitsandbytes.cuda_setup import (
- CUDA_RUNTIME_LIB,
- get_cuda_runtime_lib_path,
- evaluate_cuda_setup,
- tokenize_paths,
-)
+class InputAndExpectedOutput(NamedTuple):
+ input: str
+ output: str
-HAPPY_PATH__LD_LIB_TEST_PATHS: List[tuple[str,str]] = [
+
+HAPPY_PATH__LD_LIB_TEST_PATHS: List[InputAndExpectedOutput] = [
(f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
(f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
(f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:", f"dir/with/{CUDA_RUNTIME_LIB}"),
(f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
(f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir", f"dir/with/{CUDA_RUNTIME_LIB}"),
- (f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so", f"dir/with/{CUDA_RUNTIME_LIB}"),
+ (
+ f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so",
+ f"dir/with/{CUDA_RUNTIME_LIB}",
+ ),
]
-@pytest.mark.parametrize(
- "test_input, expected",
- HAPPY_PATH__LD_LIB_TEST_PATHS
-)
+@pytest.fixture(params=HAPPY_PATH__LD_LIB_TEST_PATHS)
+def happy_path_path_string(tmpdir, request):
+ for path in tokenize_paths(request.param):
+ test_dir.mkdir()
+ if CUDA_RUNTIME_LIB in path:
+ (test_input / CUDA_RUNTIME_LIB).touch()
+
+
+@pytest.mark.parametrize("test_input, expected", HAPPY_PATH__LD_LIB_TEST_PATHS)
def test_get_cuda_runtime_lib_path__happy_path(
- tmp_path, test_input: str, expected: str
+ tmp_path, test_input: str, expected: str
):
for path in tokenize_paths(test_input):
- assert False == tmp_path / test_input
- test_dir.mkdir()
- (test_input / CUDA_RUNTIME_LIB).touch()
+ path.mkdir()
+ (path / CUDA_RUNTIME_LIB).touch()
assert get_cuda_runtime_lib_path(test_input) == expected
@@ -47,40 +55,33 @@ def test_get_cuda_runtime_lib_path__unhappy_path(tmp_path, test_input: str):
(test_input / CUDA_RUNTIME_LIB).touch()
with pytest.raises(FileNotFoundError) as err_info:
get_cuda_runtime_lib_path(test_input)
- assert all(
- match in err_info
- for match in {"duplicate", CUDA_RUNTIME_LIB}
- )
+ assert all(match in err_info for match in {"duplicate", CUDA_RUNTIME_LIB})
def test_get_cuda_runtime_lib_path__non_existent_dir(capsys, tmp_path):
- existent_dir = tmp_path / 'a/b'
+ existent_dir = tmp_path / "a/b"
existent_dir.mkdir()
- non_existent_dir = tmp_path / 'c/d' # non-existent dir
+ non_existent_dir = tmp_path / "c/d" # non-existent dir
test_input = ":".join([str(existent_dir), str(non_existent_dir)])
get_cuda_runtime_lib_path(test_input)
std_err = capsys.readouterr().err
- assert all(
- match in std_err
- for match in {"WARNING", "non-existent"}
- )
+ assert all(match in std_err for match in {"WARNING", "non-existent"})
+
def test_full_system():
## this only tests the cuda version and not compute capability
- ld_path = os.environ['LD_LIBRARY_PATH']
- paths = ld_path.split(':')
- version = ''
+ ld_path = os.environ["LD_LIBRARY_PATH"]
+ paths = ld_path.split(":")
+ version = ""
for p in paths:
- if 'cuda' in p:
- idx = p.rfind('cuda-')
- version = p[idx+5:idx+5+4].replace('/', '')
+ if "cuda" in p:
+ idx = p.rfind("cuda-")
+ version = p[idx + 5 : idx + 5 + 4].replace("/", "")
version = float(version)
break
binary_name = evaluate_cuda_setup()
- binary_name = binary_name.replace('libbitsandbytes_cuda', '')
- assert binary_name.startswith(str(version).replace('.', ''))
-
-
+ binary_name = binary_name.replace("libbitsandbytes_cuda", "")
+ assert binary_name.startswith(str(version).replace(".", ""))