diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-08-01 19:22:41 -0700 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-08-01 19:22:41 -0700 |
commit | 8bf3e9faab6dfb04d676a5ea413530cdee09744c (patch) | |
tree | a465e10c7c040f2bdc56dfe5afd681538ef8b2af /tests | |
parent | c4fe6c69a33bacd292dfab87dc91dad32ba5fcf4 (diff) |
Added full env variable search; CONDA_PREFIX priority.
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_cuda_setup_evaluator.py | 34 |
1 files changed, 22 insertions, 12 deletions
diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index 5a58be4..5da190d 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -1,7 +1,8 @@ import os -from typing import List, NamedTuple - import pytest +import bitsandbytes as bnb + +from typing import List, NamedTuple from bitsandbytes.cuda_setup import ( CUDA_RUNTIME_LIB, @@ -91,16 +92,25 @@ def test_get_cuda_runtime_lib_path__non_existent_dir(capsys, tmp_path): def test_full_system(): ## this only tests the cuda version and not compute capability - ld_path = os.environ["LD_LIBRARY_PATH"] - paths = ld_path.split(":") - version = "" - for p in paths: - if "cuda" in p: - idx = p.rfind("cuda-") - version = p[idx + 5 : idx + 5 + 4].replace("/", "") - version = float(version) - break - + version = '' + if 'CONDA_PREFIX' in os.environ: + ls_output, err = bnb.utils.execute_and_return(f'ls -l {os.environ["CONDA_PREFIX"]}/lib/libcudart.so') + major, minor, revision = ls_output.split(' ')[-1].replace('libcudart.so.', '').split('.') + version = float(f'{major}.{minor}') + + + if version == '' and 'LD_LIBRARY_PATH': + ld_path = os.environ["LD_LIBRARY_PATH"] + paths = ld_path.split(":") + version = "" + for p in paths: + if "cuda" in p: + idx = p.rfind("cuda-") + version = p[idx + 5 : idx + 5 + 4].replace("/", "") + version = float(version) + break + + assert version > 0 binary_name = evaluate_cuda_setup() binary_name = binary_name.replace("libbitsandbytes_cuda", "") assert binary_name.startswith(str(version).replace(".", "")) |