summaryrefslogtreecommitdiff
path: root/bitsandbytes
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-10-24 11:54:25 -0700
committerTim Dettmers <tim.dettmers@gmail.com>2022-10-24 11:54:25 -0700
commitdf86625a9399d16d6fb2e3bab6bb7bcc729f3b7d (patch)
tree34278a2cfd443d8e6f62aaba0f7a469db2807571 /bitsandbytes
parentb844e104b79ddc06161ff975aa93ffa9a7ec4801 (diff)
Isolated CUDASetup logging; all tests green.
Diffstat (limited to 'bitsandbytes')
-rw-r--r--bitsandbytes/cextension.py46
-rw-r--r--bitsandbytes/cuda_setup/main.py40
-rw-r--r--bitsandbytes/cuda_setup/paths.py27
-rw-r--r--bitsandbytes/nn/__init__.py2
-rw-r--r--bitsandbytes/nn/modules.py44
5 files changed, 61 insertions, 98 deletions
diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index af23c8f..abb3054 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -2,33 +2,49 @@ import ctypes as ct
from pathlib import Path
from warnings import warn
-from .cuda_setup.main import evaluate_cuda_setup
-class CUDALibrary_Singleton(object):
+class CUDASetup(object):
_instance = None
def __init__(self):
raise RuntimeError("Call get_instance() instead")
def initialize(self):
+ self.cuda_setup_log = []
+
+ from .cuda_setup.main import evaluate_cuda_setup
binary_name = evaluate_cuda_setup()
package_dir = Path(__file__).parent
binary_path = package_dir / binary_name
- if not binary_path.exists():
- print(f"CUDA SETUP: TODO: compile library for specific version: {binary_name}")
- legacy_binary_name = "libbitsandbytes.so"
- print(f"CUDA SETUP: Defaulting to {legacy_binary_name}...")
- binary_path = package_dir / legacy_binary_name
+ try:
if not binary_path.exists():
- print('CUDA SETUP: CUDA detection failed. Either CUDA driver not installed, CUDA not installed, or you have multiple conflicting CUDA libraries!')
- print('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.')
- raise Exception('CUDA SETUP: Setup Failed!')
- self.lib = ct.cdll.LoadLibrary(binary_path)
- else:
- print(f"CUDA SETUP: Loading binary {binary_path}...")
- self.lib = ct.cdll.LoadLibrary(binary_path)
+ self.add_log_entry(f"CUDA SETUP: TODO: compile library for specific version: {binary_name}")
+ legacy_binary_name = "libbitsandbytes.so"
+ self.add_log_entry(f"CUDA SETUP: Defaulting to {legacy_binary_name}...")
+ binary_path = package_dir / legacy_binary_name
+ if not binary_path.exists():
+ self.add_log_entry('CUDA SETUP: CUDA detection failed. Either CUDA driver not installed, CUDA not installed, or you have multiple conflicting CUDA libraries!')
+ self.add_log_entry('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.')
+ self.print_log_stack()
+ raise Exception('CUDA SETUP: Setup Failed!')
+ self.lib = ct.cdll.LoadLibrary(binary_path)
+ else:
+ self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...")
+ self.lib = ct.cdll.LoadLibrary(binary_path)
+ except:
+ self.print_log_stack()
+
+ def add_log_entry(self, msg, is_warning=False):
+ self.cuda_setup_log.append((msg, is_warning))
+
+ def print_log_stack(self):
+ for msg, is_warning in self.cuda_setup_log:
+ if is_warning:
+ warn(msg)
+ else:
+ print(msg)
@classmethod
def get_instance(cls):
@@ -38,7 +54,7 @@ class CUDALibrary_Singleton(object):
return cls._instance
-lib = CUDALibrary_Singleton.get_instance().lib
+lib = CUDASetup.get_instance().lib
try:
lib.cadam32bit_g32
lib.get_context.restype = ct.c_void_p
diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py
index f11b430..f8f35f0 100644
--- a/bitsandbytes/cuda_setup/main.py
+++ b/bitsandbytes/cuda_setup/main.py
@@ -19,6 +19,7 @@ evaluation:
import ctypes
from .paths import determine_cuda_runtime_lib_path
+from bitsandbytes.cextension import CUDASetup
def check_cuda_result(cuda, result_val):
@@ -26,15 +27,14 @@ def check_cuda_result(cuda, result_val):
if result_val != 0:
error_str = ctypes.c_char_p()
cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
- print(f"CUDA exception! Error code: {error_str.value.decode()}")
+ CUDASetup.get_instance.add_log_entry(f"CUDA exception! Error code: {error_str.value.decode()}")
def get_cuda_version(cuda, cudart_path):
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
try:
cudart = ctypes.CDLL(cudart_path)
except OSError:
- # TODO: shouldn't we error or at least warn here?
- print(f'ERROR: libcudart.so could not be read from path: {cudart_path}!')
+ CUDASetup.get_instance.add_log_entry(f'ERROR: libcudart.so could not be read from path: {cudart_path}!')
return None
version = ctypes.c_int()
@@ -44,7 +44,7 @@ def get_cuda_version(cuda, cudart_path):
minor = (version-(major*1000))//10
if major < 11:
- print('CUDA SETUP: CUDA version lower than 11 are currenlty not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!')
+ CUDASetup.get_instance().add_log_entry('CUDA SETUP: CUDA version lower than 11 are currenlty not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!')
return f'{major}{minor}'
@@ -54,8 +54,7 @@ def get_cuda_lib_handle():
try:
cuda = ctypes.CDLL("libcuda.so")
except OSError:
- # TODO: shouldn't we error or at least warn here?
- print('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
+ CUDA_RUNTIME_LIB.get_instance().add_log_entry('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
return None
check_cuda_result(cuda, cuda.cuInit(0))
@@ -110,34 +109,33 @@ def get_compute_capability(cuda):
def evaluate_cuda_setup():
- print('')
- print('='*35 + 'BUG REPORT' + '='*35)
- print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
- print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
- print('='*80)
- binary_name = "libbitsandbytes_cpu.so"
+ # we remove this for now and see how things go
+ #print('')
+ #print('='*35 + 'BUG REPORT' + '='*35)
+ #print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
+ #print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
+ #print('='*80)
#if not torch.cuda.is_available():
#print('No GPU detected. Loading CPU library...')
#return binary_name
+ binary_name = "libbitsandbytes_cpu.so"
+
+ cuda_setup = CUDASetup.get_instance()
cudart_path = determine_cuda_runtime_lib_path()
if cudart_path is None:
- print(
- "WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!"
- )
+ cuda_setup.add_log_entry("WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!", is_warning=True)
return binary_name
- print(f"CUDA SETUP: CUDA runtime path found: {cudart_path}")
+ cuda_setup.add_log_entry((f"CUDA SETUP: CUDA runtime path found: {cudart_path}"))
cuda = get_cuda_lib_handle()
cc = get_compute_capability(cuda)
- print(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
+ cuda_setup.add_log_entry(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
cuda_version_string = get_cuda_version(cuda, cudart_path)
if cc == '':
- print(
- "WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..."
- )
+ cuda_setup.add_log_entry("WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library...", is_warning=True)
return binary_name
# 7.5 is the minimum CC vor cublaslt
@@ -149,7 +147,7 @@ def evaluate_cuda_setup():
# we use ls -l instead of nvcc to determine the cuda version
# since most installations will have the libcudart.so installed, but not the compiler
- print(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')
+ cuda_setup.add_log_entry(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')
def get_binary_name():
"if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
diff --git a/bitsandbytes/cuda_setup/paths.py b/bitsandbytes/cuda_setup/paths.py
index ba3f97f..3223359 100644
--- a/bitsandbytes/cuda_setup/paths.py
+++ b/bitsandbytes/cuda_setup/paths.py
@@ -1,7 +1,7 @@
import errno
from pathlib import Path
from typing import Set, Union
-from warnings import warn
+from bitsandbytes.cextension import CUDASetup
from .env_vars import get_potentially_lib_path_containing_env_vars
@@ -24,10 +24,8 @@ def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]:
non_existent_directories: Set[Path] = candidate_paths - existent_directories
if non_existent_directories:
- warn(
- "WARNING: The following directories listed in your path were found to "
- f"be non-existent: {non_existent_directories}"
- )
+ CUDASetup.get_instance().add_log_entry("WARNING: The following directories listed in your path were found to "
+ f"be non-existent: {non_existent_directories}", is_warning=True)
return existent_directories
@@ -62,9 +60,8 @@ def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None:
"Either way, this might cause trouble in the future:\n"
"If you get `CUDA error: invalid device function` errors, the above "
"might be the cause and the solution is to make sure only one "
- f"{CUDA_RUNTIME_LIB} in the paths that we search based on your env."
- )
- warn(warning_msg)
+ f"{CUDA_RUNTIME_LIB} in the paths that we search based on your env.")
+ CUDASetup.get_instance.add_log_entry(warning_msg, is_warning=True)
def determine_cuda_runtime_lib_path() -> Union[Path, None]:
@@ -90,10 +87,8 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
if conda_cuda_libs:
return next(iter(conda_cuda_libs))
- warn(
- f'{candidate_env_vars["CONDA_PREFIX"]} did not contain '
- f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...'
- )
+ CUDASetup.get_instance.add_log_entry(f'{candidate_env_vars["CONDA_PREFIX"]} did not contain '
+ f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...', is_warning=True)
if "LD_LIBRARY_PATH" in candidate_env_vars:
lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"])
@@ -102,10 +97,8 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
return next(iter(lib_ld_cuda_libs))
warn_in_case_of_duplicates(lib_ld_cuda_libs)
- warn(
- f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain '
- f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...'
- )
+ CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain '
+ f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...', is_warning=True)
remaining_candidate_env_vars = {
env_var: value for env_var, value in candidate_env_vars.items()
@@ -117,7 +110,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
cuda_runtime_libs.update(find_cuda_lib_in(value))
if len(cuda_runtime_libs) == 0:
- print('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching /usr/local/cuda/lib64...')
+ CUDASetup.get_instance().add_log_entry('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching /usr/local/cuda/lib64...')
cuda_runtime_libs.update(find_cuda_lib_in('/usr/local/cuda/lib64'))
warn_in_case_of_duplicates(cuda_runtime_libs)
diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py
index 98d4aa0..edc595a 100644
--- a/bitsandbytes/nn/__init__.py
+++ b/bitsandbytes/nn/__init__.py
@@ -2,4 +2,4 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bit, Linear8bitLt, StableEmbedding
+from .modules import Int8Params, Linear8bitLt, StableEmbedding
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 9250fec..4f82cdc 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -271,47 +271,3 @@ class Linear8bitLt(nn.Linear):
del self.state.CxB
return out
-
-
-class Linear8bit(nn.Linear):
- def __init__(
- self,
- input_features,
- output_features,
- bias=True,
- quant_type="vector",
- index=None,
- args=None,
- sparse_decomp=False,
- ):
- super(Linear8bit, self).__init__(input_features, output_features, bias)
- self.quant_type = quant_type
- self.index = index
- self.args = args
- self.iter = 0
-
- def forward(self, x):
- self.iter += 1
- if self.iter % self.args.clip_freq == 0:
- with torch.no_grad():
- maxval, maxidx = torch.topk(
- torch.abs(self.weight.flatten()), k=self.args.clip_idx
- )
- if not dist.is_initialized() or dist.get_rank() == 0:
- print("clip", maxval[-1].item())
- self.weight.clip_(-maxval[-1], maxval[-1])
-
- if self.args is not None:
- out = bnb.nn.functional.sparse_decomposed_linear8bit(
- x,
- self.weight,
- self.bias,
- qval=self.args.sparse_decomp_val,
- quant_type=self.args.quant_type,
- )
- else:
- out = bnb.nn.functional.linear8bit(
- x, self.weight, self.bias, quant_type=self.args.quant_type
- )
-
- return out