summaryrefslogtreecommitdiff
path: root/bitsandbytes/cextension.py
diff options
context:
space:
mode:
authorTitus von Koeller <titus@vonkoeller.com>2022-08-02 21:26:50 -0700
committerTitus von Koeller <titus@vonkoeller.com>2022-08-02 21:26:50 -0700
commit59a615b3869eb8488a748e2aa51224a5e3d366bb (patch)
tree5f348d63ba837d08bbc5df703a748c0ae6e34ddd /bitsandbytes/cextension.py
parent3809236428e704f9a7e22232701a651aafa5ca1b (diff)
factored cuda_setup.main out into smaller modules and functions
Diffstat (limited to 'bitsandbytes/cextension.py')
-rw-r--r--bitsandbytes/cextension.py21
1 files changed, 10 insertions, 11 deletions
diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index f5b97fd..66c79d8 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -1,8 +1,8 @@
import ctypes as ct
-import os
+from pathlib import Path
from warnings import warn
-from bitsandbytes.cuda_setup.main import evaluate_cuda_setup
+from .cuda_setup.main import evaluate_cuda_setup
class CUDALibrary_Singleton(object):
@@ -12,18 +12,17 @@ class CUDALibrary_Singleton(object):
raise RuntimeError("Call get_instance() instead")
def initialize(self):
- self.context = {}
binary_name = evaluate_cuda_setup()
- if not os.path.exists(os.path.dirname(__file__) + f"/{binary_name}"):
+ package_dir = Path(__file__).parent
+ binary_path = package_dir / binary_name
+
+ if not binary_path.exists():
print(f"TODO: compile library for specific version: {binary_name}")
- print("defaulting to libbitsandbytes.so")
- self.lib = ct.cdll.LoadLibrary(
- os.path.dirname(__file__) + "/libbitsandbytes.so"
- )
+ legacy_binary_name = "libbitsandbytes.so"
+ print(f"Defaulting to {legacy_binary_name}...")
+ self.lib = ct.cdll.LoadLibrary(package_dir / legacy_binary_name)
else:
- self.lib = ct.cdll.LoadLibrary(
- os.path.dirname(__file__) + f"/{binary_name}"
- )
+ self.lib = ct.cdll.LoadLibrary(package_dir / binary_name)
@classmethod
def get_instance(cls):