summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-08-01 19:22:41 -0700
committerTim Dettmers <tim.dettmers@gmail.com>2022-08-01 19:22:41 -0700
commit8bf3e9faab6dfb04d676a5ea413530cdee09744c (patch)
treea465e10c7c040f2bdc56dfe5afd681538ef8b2af /tests
parentc4fe6c69a33bacd292dfab87dc91dad32ba5fcf4 (diff)
Added full env variable search; CONDA_PREFIX priority.
Diffstat (limited to 'tests')
-rw-r--r--tests/test_cuda_setup_evaluator.py34
1 files changed, 22 insertions, 12 deletions
diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py
index 5a58be4..5da190d 100644
--- a/tests/test_cuda_setup_evaluator.py
+++ b/tests/test_cuda_setup_evaluator.py
@@ -1,7 +1,8 @@
import os
-from typing import List, NamedTuple
-
import pytest
+import bitsandbytes as bnb
+
+from typing import List, NamedTuple
from bitsandbytes.cuda_setup import (
CUDA_RUNTIME_LIB,
@@ -91,16 +92,25 @@ def test_get_cuda_runtime_lib_path__non_existent_dir(capsys, tmp_path):
def test_full_system():
## this only tests the cuda version and not compute capability
- ld_path = os.environ["LD_LIBRARY_PATH"]
- paths = ld_path.split(":")
- version = ""
- for p in paths:
- if "cuda" in p:
- idx = p.rfind("cuda-")
- version = p[idx + 5 : idx + 5 + 4].replace("/", "")
- version = float(version)
- break
-
+ version = ''
+ if 'CONDA_PREFIX' in os.environ:
+ ls_output, err = bnb.utils.execute_and_return(f'ls -l {os.environ["CONDA_PREFIX"]}/lib/libcudart.so')
+ major, minor, revision = ls_output.split(' ')[-1].replace('libcudart.so.', '').split('.')
+ version = float(f'{major}.{minor}')
+
+
+ if version == '' and 'LD_LIBRARY_PATH':
+ ld_path = os.environ["LD_LIBRARY_PATH"]
+ paths = ld_path.split(":")
+ version = ""
+ for p in paths:
+ if "cuda" in p:
+ idx = p.rfind("cuda-")
+ version = p[idx + 5 : idx + 5 + 4].replace("/", "")
+ version = float(version)
+ break
+
+ assert version > 0
binary_name = evaluate_cuda_setup()
binary_name = binary_name.replace("libbitsandbytes_cuda", "")
assert binary_name.startswith(str(version).replace(".", ""))