diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_cuda_setup_evaluator.py | 66 |
1 files changed, 66 insertions, 0 deletions
diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py new file mode 100644 index 0000000..96ee6c5 --- /dev/null +++ b/tests/test_cuda_setup_evaluator.py @@ -0,0 +1,66 @@ +import pytest + +from typing import List + +from bitsandbytes.cuda_setup import ( + CUDA_RUNTIME_LIB, + get_cuda_runtime_lib_path, + evaluate_cuda_setup, + tokenize_paths, +) + + +HAPPY_PATH__LD_LIB_TEST_PATHS: List[tuple[str,str]] = [ + (f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"), + (f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"), + (f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:", f"dir/with/{CUDA_RUNTIME_LIB}"), + (f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"), + (f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir", f"dir/with/{CUDA_RUNTIME_LIB}"), +] + + +@pytest.mark.parametrize( + "test_input, expected", + HAPPY_PATH__LD_LIB_TEST_PATHS +) +def test_get_cuda_runtime_lib_path__happy_path( + tmp_path, test_input: str, expected: str +): + for path in tokenize_paths(test_input): + assert False == tmp_path / test_input + test_dir.mkdir() + (test_input / CUDA_RUNTIME_LIB).touch() + assert get_cuda_runtime_lib_path(test_input) == expected + + +UNHAPPY_PATH__LD_LIB_TEST_PATHS = [ + f"a/b/c/{CUDA_RUNTIME_LIB}:d/e/f/{CUDA_RUNTIME_LIB}", + f"a/b/c/{CUDA_RUNTIME_LIB}:d/e/f/{CUDA_RUNTIME_LIB}:g/h/j/{CUDA_RUNTIME_LIB}", +] + + +@pytest.mark.parametrize("test_input", UNHAPPY_PATH__LD_LIB_TEST_PATHS) +def test_get_cuda_runtime_lib_path__unhappy_path(tmp_path, test_input: str): + test_input = tmp_path / test_input + (test_input / CUDA_RUNTIME_LIB).touch() + with pytest.raises(FileNotFoundError) as err_info: + get_cuda_runtime_lib_path(test_input) + assert all( + match in err_info + for match in {"duplicate", CUDA_RUNTIME_LIB} + ) + + +def test_get_cuda_runtime_lib_path__non_existent_dir(capsys, tmp_path): + existent_dir = tmp_path / 'a/b' + existent_dir.mkdir() + non_existent_dir = tmp_path / 'c/d' # non-existent dir + test_input = ":".join([str(existent_dir), str(non_existent_dir)]) + + get_cuda_runtime_lib_path(test_input) + std_err = capsys.readouterr().err + + assert all( + match in std_err + for match in {"WARNING", "non-existent"} + ) |