From bfa0e33294f2b1dc25e65a33be2397f989824298 Mon Sep 17 00:00:00 2001 From: Titus von Koeller Date: Mon, 1 Aug 2022 03:31:48 -0700 Subject: ran black and isort for coherent code formatting --- tests/test_cuda_setup_evaluator.py | 77 +++++++++++++++++++------------------- 1 file changed, 39 insertions(+), 38 deletions(-) (limited to 'tests/test_cuda_setup_evaluator.py') diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index 72aa3c7..d45354f 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -1,37 +1,45 @@ -import pytest import os +from typing import List, NamedTuple + +import pytest -from typing import List +from bitsandbytes.cuda_setup import (CUDA_RUNTIME_LIB, evaluate_cuda_setup, + get_cuda_runtime_lib_path, tokenize_paths) -from bitsandbytes.cuda_setup import ( - CUDA_RUNTIME_LIB, - get_cuda_runtime_lib_path, - evaluate_cuda_setup, - tokenize_paths, -) +class InputAndExpectedOutput(NamedTuple): + input: str + output: str -HAPPY_PATH__LD_LIB_TEST_PATHS: List[tuple[str,str]] = [ + +HAPPY_PATH__LD_LIB_TEST_PATHS: List[InputAndExpectedOutput] = [ (f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"), (f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"), (f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:", f"dir/with/{CUDA_RUNTIME_LIB}"), (f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"), (f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir", f"dir/with/{CUDA_RUNTIME_LIB}"), - (f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so", f"dir/with/{CUDA_RUNTIME_LIB}"), + ( + f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so", + f"dir/with/{CUDA_RUNTIME_LIB}", + ), ] -@pytest.mark.parametrize( - "test_input, expected", - HAPPY_PATH__LD_LIB_TEST_PATHS -) +@pytest.fixture(params=HAPPY_PATH__LD_LIB_TEST_PATHS) +def happy_path_path_string(tmpdir, request): + for path in tokenize_paths(request.param): + test_dir.mkdir() + if CUDA_RUNTIME_LIB in path: + (test_input / CUDA_RUNTIME_LIB).touch() + + +@pytest.mark.parametrize("test_input, expected", HAPPY_PATH__LD_LIB_TEST_PATHS) def test_get_cuda_runtime_lib_path__happy_path( - tmp_path, test_input: str, expected: str + tmp_path, test_input: str, expected: str ): for path in tokenize_paths(test_input): - assert False == tmp_path / test_input - test_dir.mkdir() - (test_input / CUDA_RUNTIME_LIB).touch() + path.mkdir() + (path / CUDA_RUNTIME_LIB).touch() assert get_cuda_runtime_lib_path(test_input) == expected @@ -47,40 +55,33 @@ def test_get_cuda_runtime_lib_path__unhappy_path(tmp_path, test_input: str): (test_input / CUDA_RUNTIME_LIB).touch() with pytest.raises(FileNotFoundError) as err_info: get_cuda_runtime_lib_path(test_input) - assert all( - match in err_info - for match in {"duplicate", CUDA_RUNTIME_LIB} - ) + assert all(match in err_info for match in {"duplicate", CUDA_RUNTIME_LIB}) def test_get_cuda_runtime_lib_path__non_existent_dir(capsys, tmp_path): - existent_dir = tmp_path / 'a/b' + existent_dir = tmp_path / "a/b" existent_dir.mkdir() - non_existent_dir = tmp_path / 'c/d' # non-existent dir + non_existent_dir = tmp_path / "c/d" # non-existent dir test_input = ":".join([str(existent_dir), str(non_existent_dir)]) get_cuda_runtime_lib_path(test_input) std_err = capsys.readouterr().err - assert all( - match in std_err - for match in {"WARNING", "non-existent"} - ) + assert all(match in std_err for match in {"WARNING", "non-existent"}) + def test_full_system(): ## this only tests the cuda version and not compute capability - ld_path = os.environ['LD_LIBRARY_PATH'] - paths = ld_path.split(':') - version = '' + ld_path = os.environ["LD_LIBRARY_PATH"] + paths = ld_path.split(":") + version = "" for p in paths: - if 'cuda' in p: - idx = p.rfind('cuda-') - version = p[idx+5:idx+5+4].replace('/', '') + if "cuda" in p: + idx = p.rfind("cuda-") + version = p[idx + 5 : idx + 5 + 4].replace("/", "") version = float(version) break binary_name = evaluate_cuda_setup() - binary_name = binary_name.replace('libbitsandbytes_cuda', '') - assert binary_name.startswith(str(version).replace('.', '')) - - + binary_name = binary_name.replace("libbitsandbytes_cuda", "") + assert binary_name.startswith(str(version).replace(".", "")) -- cgit v1.2.3