diff options
Diffstat (limited to 'bitsandbytes')
-rw-r--r-- | bitsandbytes/__init__.py | 28 | ||||
-rw-r--r-- | bitsandbytes/__main__.py | 96 | ||||
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 170 | ||||
-rw-r--r-- | bitsandbytes/cextension.py | 39 | ||||
-rw-r--r-- | bitsandbytes/cuda_setup/__init__.py | 0 | ||||
-rw-r--r-- | bitsandbytes/cuda_setup/compute_capability.py | 79 | ||||
-rw-r--r-- | bitsandbytes/cuda_setup/env_vars.py | 51 | ||||
-rw-r--r-- | bitsandbytes/cuda_setup/main.py | 127 | ||||
-rw-r--r-- | bitsandbytes/cuda_setup/paths.py | 126 | ||||
-rw-r--r-- | bitsandbytes/debug_cli.py | 26 | ||||
-rw-r--r-- | bitsandbytes/functional.py | 1303 | ||||
-rw-r--r-- | bitsandbytes/nn/__init__.py | 8 | ||||
-rw-r--r-- | bitsandbytes/nn/modules.py | 223 | ||||
-rw-r--r-- | bitsandbytes/optim/__init__.py | 6 | ||||
-rw-r--r-- | bitsandbytes/optim/adagrad.py | 126 | ||||
-rw-r--r-- | bitsandbytes/optim/adam.py | 198 | ||||
-rw-r--r-- | bitsandbytes/optim/adamw.py | 104 | ||||
-rw-r--r-- | bitsandbytes/optim/lamb.py | 117 | ||||
-rw-r--r-- | bitsandbytes/optim/lars.py | 183 | ||||
-rw-r--r-- | bitsandbytes/optim/optimizer.py | 622 | ||||
-rw-r--r-- | bitsandbytes/optim/rmsprop.py | 121 | ||||
-rw-r--r-- | bitsandbytes/optim/sgd.py | 109 | ||||
-rw-r--r-- | bitsandbytes/utils.py | 32 |
23 files changed, 2996 insertions, 898 deletions
diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 3c3affa..7901f96 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -1,16 +1,26 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .nn import modules -from .autograd._functions import mm_cublas, bmm_cublas, matmul_cublas, matmul, MatmulLtState +from .autograd._functions import ( + MatmulLtState, + bmm_cublas, + matmul, + matmul_cublas, + mm_cublas, +) from .cextension import COMPILED_WITH_CUDA +from .nn import modules +from . import cuda_setup if COMPILED_WITH_CUDA: from .optim import adam -__pdoc__ = {'libbitsandbytes': False, - 'optim.optimizer.Optimizer8bit': False, - 'optim.optimizer.MockArgs': False - } +__pdoc__ = { + "libbitsandbytes": False, + "optim.optimizer.Optimizer8bit": False, + "optim.optimizer.MockArgs": False, +} + +PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes" diff --git a/bitsandbytes/__main__.py b/bitsandbytes/__main__.py new file mode 100644 index 0000000..7f3d24c --- /dev/null +++ b/bitsandbytes/__main__.py @@ -0,0 +1,96 @@ +# from bitsandbytes.debug_cli import cli + +# cli() +import os +import sys +import torch + + +HEADER_WIDTH = 60 + + +def print_header( + txt: str, width: int = HEADER_WIDTH, filler: str = "+" +) -> None: + txt = f" {txt} " if txt else "" + print(txt.center(width, filler)) + + +def print_debug_info() -> None: + print( + "\nAbove we output some debug information. Please provide this info when " + f"creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose ...\n" + ) + + +print_header("") +print_header("DEBUG INFORMATION") +print_header("") +print() + + +from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL +from .cuda_setup.main import get_compute_capabilities +from .cuda_setup.env_vars import to_be_ignored +from .utils import print_stderr + + +print_header("POTENTIALLY LIBRARY-PATH-LIKE ENV VARS") +for k, v in os.environ.items(): + if "/" in v and not to_be_ignored(k, v): + print(f"'{k}': '{v}'") +print_header("") + +print( + "\nWARNING: Please be sure to sanitize sensible info from any such env vars!\n" +) + +print_header("OTHER") +print(f"{COMPILED_WITH_CUDA = }") +print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities()}") +print_header("") +print_header("DEBUG INFO END") +print_header("") +print( + """ +Running a quick check that: + + library is importable + + CUDA function is callable +""" +) + +try: + from bitsandbytes.optim import Adam + + p = torch.nn.Parameter(torch.rand(10, 10).cuda()) + a = torch.rand(10, 10).cuda() + + p1 = p.data.sum().item() + + adam = Adam([p]) + + out = a * p + loss = out.sum() + loss.backward() + adam.step() + + p2 = p.data.sum().item() + + assert p1 != p2 + print("SUCCESS!") + print("Installation was successful!") + sys.exit(0) + +except ImportError: + print() + print_stderr( + f"WARNING: {__package__} is currently running as CPU-only!\n" + "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" + f"If you think that this is so erroneously,\nplease report an issue!" + ) + print_debug_info() + sys.exit(0) +except Exception as e: + print(e) + print_debug_info() + sys.exit(1) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 607d868..14f2660 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -1,22 +1,24 @@ +from dataclasses import dataclass + import torch import math import bitsandbytes as bnb import bitsandbytes.functional as F -from dataclasses import dataclass - tensor = torch.Tensor -''' +""" This class pools outlier dimensions across layers. This is particularly important for small models where outlier features are less systematic and occur with low frequency. -''' +""" + + class GlobalOutlierPooler(object): _instance = None def __init__(self): - raise RuntimeError('Call get_instance() instead') + raise RuntimeError("Call get_instance() instead") def initialize(self): self.outliers = set() @@ -30,25 +32,29 @@ class GlobalOutlierPooler(object): return cls._instance def add_outliers(self, outlier_idx, feature_dim): - if self.model_dim is None: self.model_dim = feature_dim - if feature_dim != self.model_dim: return # we do not encode outliers for the 2nd FFN layer + if self.model_dim is None: + self.model_dim = feature_dim + if feature_dim != self.model_dim: + return # we do not encode outliers for the 2nd FFN layer self.outliers.update(outlier_idx.tolist()) def get_current_outlier_idx(self): return torch.Tensor(list(self.outliers)).to(torch.int64) -class MatMul8bit(torch.autograd.Function): +class MatMul8bit(torch.autograd.Function): @staticmethod - def forward(ctx, A, B, out=None, quant_type='vector', precision=[8, 8, 8]): + def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]): if precision[0] != 8: with torch.no_grad(): output = torch.matmul(A, B) else: - if len(B.shape) == 2: dim = 0 - else: dim = 1 + if len(B.shape) == 2: + dim = 0 + else: + dim = 1 qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type) qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type) iout = F.igemm(qA, qB) @@ -85,21 +91,43 @@ class MatMul8bit(torch.autograd.Function): else: if len(B.shape) == 2 and len(A.shape) == 3: grad_output = grad_output.contiguous() - if not grad_output.is_contiguous(): grad_output.contiguous() - qgrad_output, S1 = F.vectorwise_quant(grad_output.view(-1, grad_output.shape[2]), dim=0, quant_type=quant_type) - if not A.is_contiguous(): A = A.contiguous() - qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type) + if not grad_output.is_contiguous(): + grad_output.contiguous() + qgrad_output, S1 = F.vectorwise_quant( + grad_output.view(-1, grad_output.shape[2]), + dim=0, + quant_type=quant_type, + ) + if not A.is_contiguous(): + A = A.contiguous() + qA, S2 = F.vectorwise_quant( + A.view(-1, A.shape[2]), dim=0, quant_type=quant_type + ) igrad_B = F.igemm(qA.t(), qgrad_output) - grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type) + grad_B = F.vectorwise_mm_dequant( + igrad_B, S2.t(), S1, grad_output.dtype, quant_type + ) else: - qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type) - qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type) + qgrad_output, S1 = F.vectorwise_quant( + grad_output, dim=dims, quant_type=quant_type + ) + qA, S2 = F.vectorwise_quant( + A, dim=dims, quant_type=quant_type + ) igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output) - grad_B = F.vectorwise_mm_dequant(igrad_B, S2.permute(permute_dim), S1, grad_output.dtype, quant_type) + grad_B = F.vectorwise_mm_dequant( + igrad_B, + S2.permute(permute_dim), + S1, + grad_output.dtype, + quant_type, + ) if A.requires_grad: - if len(grad_output.shape) == 3: dims = [2] - else: dims = [1] + if len(grad_output.shape) == 3: + dims = [2] + else: + dims = [1] if len(B.shape) == 3: # bio -> boi @@ -114,10 +142,18 @@ class MatMul8bit(torch.autograd.Function): with torch.no_grad(): grad_A = torch.matmul(grad_output, B.permute(permute_dim)) else: - qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type) + qgrad_output, S1 = F.vectorwise_quant( + grad_output, dim=dims, quant_type=quant_type + ) qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type) igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim)) - grad_A = F.vectorwise_mm_dequant(igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type) + grad_A = F.vectorwise_mm_dequant( + igrad_A, + S1, + S3.permute(permute_dim), + grad_output.dtype, + quant_type, + ) return grad_A, grad_B, None, None, None @@ -126,6 +162,7 @@ mm_cublas = MatMul8bit.apply bmm_cublas = MatMul8bit.apply matmul_cublas = MatMul8bit.apply + @dataclass class MatmulLtState: CB = None @@ -160,7 +197,6 @@ class MatmulLtState: class MatMul8bitLt(torch.autograd.Function): - @staticmethod def forward(ctx, A, B, out=None, state=MatmulLtState()): # default to pytorch behavior if inputs are empty @@ -183,12 +219,18 @@ class MatMul8bitLt(torch.autograd.Function): requires_gradB = B.requires_grad formatB = state.formatB input_shape = A.shape - if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance() - assert A.dtype == torch.float16, f'The input data type needs to be fp16 but {A.dtype} was found!' + if state.outlier_pool is None: + state.outlier_pool = GlobalOutlierPooler.get_instance() + assert ( + A.dtype == torch.float16 + ), f"The input data type needs to be fp16 but {A.dtype} was found!" # 1. Quantize A - if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous() - CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold) + if len(A.shape) == 3: + A = A.view(-1, A.shape[-1]).contiguous() + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant( + A, threshold=state.threshold + ) if state.threshold > 0.0 and coo_tensorA is not None: if state.has_fp16_weights: @@ -202,9 +244,11 @@ class MatMul8bitLt(torch.autograd.Function): if state.CxB is None: # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions # we also need to convert it to the turing/ampere format - state.CxB, state.SB = F.transform(state.CB, to_order=formatB) - #state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half() - #if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None: + state.CxB, state.SB = F.transform( + state.CB, to_order=formatB + ) + # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half() + # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None: # # generate outlier index and subB # outlier_idx = torch.unique(coo_tensorA.colidx).long() # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) @@ -215,28 +259,34 @@ class MatMul8bitLt(torch.autograd.Function): # state.idx = outlier_idx # state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half() - #if state.idx is not None: + # if state.idx is not None: # # extract outliers # CA[:, state.idx] = 0 # CAt[:, state.idx] = 0 # subA = A[:, state.idx] - #else: + # else: # subA = None else: if not state.has_fp16_weights and state.CxB is None: state.CxB, state.SB = F.transform(state.CB, to_order=formatB) subA = None - # 2. Quantize B if state.has_fp16_weights: - has_grad = (True if (getattr(B, 'grad', None) is not None) else False) + has_grad = True if (getattr(B, "grad", None) is not None) else False is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1) - if is_transposed: B = B.contiguous() + if is_transposed: + B = B.contiguous() if (state.is_training and not has_grad) or state.CxB is None: state.reset_grads() - CB, state.CBt, state.SCB, state.SCBt, coo_tensorB = F.double_quant(B) + ( + CB, + state.CBt, + state.SCB, + state.SCBt, + coo_tensorB, + ) = F.double_quant(B) state.CxB, state.SB = F.transform(CB, to_order=formatB) else: has_grad = False @@ -246,14 +296,19 @@ class MatMul8bitLt(torch.autograd.Function): outlier_idx = torch.unique(coo_tensorA.colidx) state.idx = outlier_idx - #state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) - #if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: + # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) + # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: # # do not use pool for 2nd FFN layer # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) - #else: + # else: # state.idx = outlier_idx outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) - state.subB = (outliers*state.SCB.view(-1, 1)/127.0).t().contiguous().half() + state.subB = ( + (outliers * state.SCB.view(-1, 1) / 127.0) + .t() + .contiguous() + .half() + ) CA[:, state.idx.long()] = 0 CAt[:, state.idx.long()] = 0 subA = A[:, state.idx.long()] @@ -266,7 +321,7 @@ class MatMul8bitLt(torch.autograd.Function): output_shape = (input_shape[0], shapeB[0]) # 3. Matmul - C32A, SA = F.transform(CA, 'col32') + C32A, SA = F.transform(CA, "col32") out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) output = F.mm_dequant(out32, Sout32, SCA, state.SCB) @@ -289,7 +344,7 @@ class MatMul8bitLt(torch.autograd.Function): ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) - #clone_func = torch.clone if len(output_shape) == 3 else lambda x : x + # clone_func = torch.clone if len(output_shape) == 3 else lambda x : x clone_func = torch.clone return clone_func(output.view(output_shape)) @@ -302,28 +357,36 @@ class MatMul8bitLt(torch.autograd.Function): SCAt, idx = ctx.tensor_states formatB = ctx.formatB state = ctx.state - assert state.has_fp16_weights, 'Backprop only supported for fp16 weights.' + assert ( + state.has_fp16_weights + ), "Backprop only supported for fp16 weights." if len(grad_output.shape) == 3: - grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous() + grad_output = grad_output.view( + -1, grad_output.shape[-1] + ).contiguous() grad_A = grad_B = None Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output) if req_gradB: CxAt, SAt = F.transform(CAt, formatB, transpose=True) - C32grad, Sgrad = F.transform(Cgradt, 'col32', transpose=True) + C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) if state.threshold > 0.0 and subA is not None: grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradA: - C32grad, Sgrad = F.transform(Cgrad, 'col32') + C32grad, Sgrad = F.transform(Cgrad, "col32") if state.CxBt is None: - state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True) + state.CxBt, state.SBt = F.transform( + state.CBt, to_order=formatB, transpose=True + ) gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) - grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) + grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view( + ctx.grad_shape + ) return grad_A, grad_B, None, None @@ -331,9 +394,14 @@ class MatMul8bitLt(torch.autograd.Function): matmul = MatMul8bitLt.apply -def matmul(A : tensor, B : tensor, out : tensor=None, state : MatmulLtState = None, threshold=0.0): +def matmul( + A: tensor, + B: tensor, + out: tensor = None, + state: MatmulLtState = None, + threshold=0.0, +): state = state or MatmulLtState() if threshold > 0.0: state.threshold = threshold return MatMul8bitLt.apply(A, B, out, state) - diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 2374c35..66c79d8 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -1,15 +1,46 @@ import ctypes as ct -import os +from pathlib import Path from warnings import warn -lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so') +from .cuda_setup.main import evaluate_cuda_setup + +class CUDALibrary_Singleton(object): + _instance = None + + def __init__(self): + raise RuntimeError("Call get_instance() instead") + + def initialize(self): + binary_name = evaluate_cuda_setup() + package_dir = Path(__file__).parent + binary_path = package_dir / binary_name + + if not binary_path.exists(): + print(f"TODO: compile library for specific version: {binary_name}") + legacy_binary_name = "libbitsandbytes.so" + print(f"Defaulting to {legacy_binary_name}...") + self.lib = ct.cdll.LoadLibrary(package_dir / legacy_binary_name) + else: + self.lib = ct.cdll.LoadLibrary(package_dir / binary_name) + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.initialize() + return cls._instance + + +lib = CUDALibrary_Singleton.get_instance().lib try: lib.cadam32bit_g32 lib.get_context.restype = ct.c_void_p lib.get_cusparse.restype = ct.c_void_p COMPILED_WITH_CUDA = True except AttributeError: - warn("The installed version of bitsandbytes was compiled without GPU support. " - "8-bit optimizers and GPU quantization are unavailable.") + warn( + "The installed version of bitsandbytes was compiled without GPU support. " + "8-bit optimizers and GPU quantization are unavailable." + ) COMPILED_WITH_CUDA = False diff --git a/bitsandbytes/cuda_setup/__init__.py b/bitsandbytes/cuda_setup/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/bitsandbytes/cuda_setup/__init__.py diff --git a/bitsandbytes/cuda_setup/compute_capability.py b/bitsandbytes/cuda_setup/compute_capability.py new file mode 100644 index 0000000..7a3f463 --- /dev/null +++ b/bitsandbytes/cuda_setup/compute_capability.py @@ -0,0 +1,79 @@ +import ctypes +from dataclasses import dataclass, field + + +@dataclass +class CudaLibVals: + # code bits taken from + # https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549 + + nGpus: ctypes.c_int = field(default=ctypes.c_int()) + cc_major: ctypes.c_int = field(default=ctypes.c_int()) + cc_minor: ctypes.c_int = field(default=ctypes.c_int()) + device: ctypes.c_int = field(default=ctypes.c_int()) + error_str: ctypes.c_char_p = field(default=ctypes.c_char_p()) + cuda: ctypes.CDLL = field(init=False, repr=False) + ccs: List[str, ...] = field(init=False) + + def _initialize_driver_API(self): + self.check_cuda_result(self.cuda.cuInit(0)) + + def _load_cuda_lib(self): + """ + 1. find libcuda.so library (GPU driver) (/usr/lib) + init_device -> init variables -> call function by reference + """ + libnames = "libcuda.so" + for libname in libnames: + try: + self.cuda = ctypes.CDLL(libname) + except OSError: + continue + else: + break + else: + raise OSError("could not load any of: " + " ".join(libnames)) + + def call_cuda_func(self, function_obj, **kwargs): + CUDA_SUCCESS = 0 # constant taken from cuda.h + pass + # if (CUDA_SUCCESS := function_obj( + + def _error_handle(cuda_lib_call_return_value): + """ + 2. call extern C function to determine CC + (see https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html) + """ + CUDA_SUCCESS = 0 # constant taken from cuda.h + + if cuda_lib_call_return_value != CUDA_SUCCESS: + self.cuda.cuGetErrorString( + cuda_lib_call_return_value, + ctypes.byref(self.error_str), + ) + print("Count not initialize CUDA - failure!") + raise Exception("CUDA exception!") + return cuda_lib_call_return_value + + def __post_init__(self): + self._load_cuda_lib() + self._initialize_driver_API() + self.check_cuda_result( + self.cuda, self.cuda.cuDeviceGetCount(ctypes.byref(self.nGpus)) + ) + tmp_ccs = [] + for gpu_index in range(self.nGpus.value): + check_cuda_result( + self.cuda, + self.cuda.cuDeviceGet(ctypes.byref(self.device), gpu_index), + ) + check_cuda_result( + self.cuda, + self.cuda.cuDeviceComputeCapability( + ctypes.byref(self.cc_major), + ctypes.byref(self.cc_minor), + self.device, + ), + ) + tmp_ccs.append(f"{self.cc_major.value}.{self.cc_minor.value}") + self.ccs = sorted(tmp_ccs, reverse=True) diff --git a/bitsandbytes/cuda_setup/env_vars.py b/bitsandbytes/cuda_setup/env_vars.py new file mode 100644 index 0000000..536a7d8 --- /dev/null +++ b/bitsandbytes/cuda_setup/env_vars.py @@ -0,0 +1,51 @@ +import os +from typing import Dict + + +def to_be_ignored(env_var: str, value: str) -> bool: + ignorable = { + "PWD", # PWD: this is how the shell keeps track of the current working dir + "OLDPWD", + "SSH_AUTH_SOCK", # SSH stuff, therefore unrelated + "SSH_TTY", + "HOME", # Linux shell default + "TMUX", # Terminal Multiplexer + "XDG_DATA_DIRS", # XDG: Desktop environment stuff + "XDG_RUNTIME_DIR", + "MAIL", # something related to emails + "SHELL", # binary for currently invoked shell + "DBUS_SESSION_BUS_ADDRESS", # hardware related + "PATH", # this is for finding binaries, not libraries + "LESSOPEN", # related to the `less` command + "LESSCLOSE", + "_", # current Python interpreter + } + return env_var in ignorable + + +def might_contain_a_path(candidate: str) -> bool: + return "/" in candidate + + +def is_active_conda_env(env_var: str) -> bool: + return "CONDA_PREFIX" == env_var + + +def is_other_conda_env_var(env_var: str) -> bool: + return "CONDA" in env_var + + +def is_relevant_candidate_env_var(env_var: str, value: str) -> bool: + return is_active_conda_env(env_var) or ( + might_contain_a_path(value) and not + is_other_conda_env_var(env_var) and not + to_be_ignored(env_var, value) + ) + + +def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]: + return { + env_var: value + for env_var, value in os.environ.items() + if is_relevant_candidate_env_var(env_var, value) + } diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py new file mode 100644 index 0000000..e96ac70 --- /dev/null +++ b/bitsandbytes/cuda_setup/main.py @@ -0,0 +1,127 @@ +""" +extract factors the build is dependent on: +[X] compute capability + [ ] TODO: Q - What if we have multiple GPUs of different makes? +- CUDA version +- Software: + - CPU-only: only CPU quantization functions (no optimizer, no matrix multipl) + - CuBLAS-LT: full-build 8-bit optimizer + - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`) + +evaluation: + - if paths faulty, return meaningful error + - else: + - determine CUDA version + - determine capabilities + - based on that set the default path +""" + +import ctypes +from pathlib import Path + +from ..utils import execute_and_return +from .paths import determine_cuda_runtime_lib_path + + +def check_cuda_result(cuda, result_val): + # 3. Check for CUDA errors + if result_val != 0: + error_str = ctypes.c_char_p() + cuda.cuGetErrorString(result_val, ctypes.byref(error_str)) + raise Exception(f"CUDA exception! ERROR: {error_str}") + + +def get_compute_capabilities(): + """ + 1. find libcuda.so library (GPU driver) (/usr/lib) + init_device -> init variables -> call function by reference + 2. call extern C function to determine CC + (https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html) + 3. Check for CUDA errors + https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api + # bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549 + """ + + # 1. find libcuda.so library (GPU driver) (/usr/lib) + try: + cuda = ctypes.CDLL("libcuda.so") + except OSError: + # TODO: shouldn't we error or at least warn here? + return None + + nGpus = ctypes.c_int() + cc_major = ctypes.c_int() + cc_minor = ctypes.c_int() + + result = ctypes.c_int() + device = ctypes.c_int() + + check_cuda_result(cuda, cuda.cuInit(0)) + + check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus))) + ccs = [] + for i in range(nGpus.value): + check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i)) + ref_major = ctypes.byref(cc_major) + ref_minor = ctypes.byref(cc_minor) + # 2. call extern C function to determine CC + check_cuda_result( + cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device) + ) + ccs.append(f"{cc_major.value}.{cc_minor.value}") + + return ccs.sort() + + +# def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error +def get_compute_capability(): + """ + Extracts the highest compute capbility from all available GPUs, as compute + capabilities are downwards compatible. If no GPUs are detected, it returns + None. + """ + if ccs := get_compute_capabilities() is not None: + # TODO: handle different compute capabilities; for now, take the max + return ccs[-1] + return None + + +def evaluate_cuda_setup(): + cuda_path = determine_cuda_runtime_lib_path() + print(f"CUDA SETUP: CUDA path found: {cuda_path}") + cc = get_compute_capability() + binary_name = "libbitsandbytes_cpu.so" + + # FIXME: has_gpu is still unused + if not (has_gpu := bool(cc)): + print( + "WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..." + ) + return binary_name + + # 7.5 is the minimum CC vor cublaslt + has_cublaslt = cc in ["7.5", "8.0", "8.6"] + + # TODO: + # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) + # (2) Multiple CUDA versions installed + + # FIXME: cuda_home is still unused + cuda_home = str(Path(cuda_path).parent.parent) + # we use ls -l instead of nvcc to determine the cuda version + # since most installations will have the libcudart.so installed, but not the compiler + ls_output, err = execute_and_return(f"ls -l {cuda_path}") + major, minor, revision = ( + ls_output.split(" ")[-1].replace("libcudart.so.", "").split(".") + ) + cuda_version_string = f"{major}{minor}" + + def get_binary_name(): + "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so" + bin_base_name = "libbitsandbytes_cuda" + if has_cublaslt: + return f"{bin_base_name}{cuda_version_string}.so" + else: + return f"{bin_base_name}_nocublaslt.so" + + return binary_name diff --git a/bitsandbytes/cuda_setup/paths.py b/bitsandbytes/cuda_setup/paths.py new file mode 100644 index 0000000..c4a7465 --- /dev/null +++ b/bitsandbytes/cuda_setup/paths.py @@ -0,0 +1,126 @@ +from pathlib import Path +from typing import Set, Union +from warnings import warn + +from ..utils import print_stderr +from .env_vars import get_potentially_lib_path_containing_env_vars + + +CUDA_RUNTIME_LIB: str = "libcudart.so" + + +def purge_unwanted_semicolon(tentative_path: Path) -> Path: + """ + Special function to handle the following exception: + __LMOD_REF_COUNT_PATH=/sw/cuda/11.6.2/bin:2;/mmfs1/home/dettmers/git/sched/bin:1;/mmfs1/home/dettmers/data/anaconda3/bin:1;/mmfs1/home/dettmers/data/anaconda3/condabin:1;/mmfs1/home/dettmers/.local/bin:1;/mmfs1/home/dettmers/bin:1;/usr/local/bin:1;/usr/bin:1;/usr/local/sbin:1;/usr/sbin:1;/mmfs1/home/dettmers/.fzf/bin:1;/mmfs1/home/dettmers/data/local/cuda-11.4/bin:1 + """ + # if ';' in str(tentative_path): + # path_as_str, _ = str(tentative_path).split(';') + pass + + +def extract_candidate_paths(paths_list_candidate: str) -> Set[Path]: + return {Path(ld_path) for ld_path in paths_list_candidate.split(":") if ld_path} + + +def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]: + non_existent_directories: Set[Path] = { + path for path in candidate_paths if not path.exists() + } + + if non_existent_directories: + print_stderr( + "WARNING: The following directories listed in your path were found to " + f"be non-existent: {non_existent_directories}" + ) + + return candidate_paths - non_existent_directories + + +def get_cuda_runtime_lib_paths(candidate_paths: Set[Path]) -> Set[Path]: + return { + path / CUDA_RUNTIME_LIB + for path in candidate_paths + if (path / CUDA_RUNTIME_LIB).is_file() + } + + +def resolve_paths_list(paths_list_candidate: str) -> Set[Path]: + """ + Searches a given environmental var for the CUDA runtime library, + i.e. `libcudart.so`. + """ + return remove_non_existent_dirs(extract_candidate_paths(paths_list_candidate)) + + +def find_cuda_lib_in(paths_list_candidate: str) -> Set[Path]: + return get_cuda_runtime_lib_paths( + resolve_paths_list(paths_list_candidate) + ) + + +def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None: + if len(results_paths) > 1: + warning_msg = ( + f"Found duplicate {CUDA_RUNTIME_LIB} files: {results_paths}.. " + "We'll flip a coin and try one of these, in order to fail forward.\n" + "Either way, this might cause trouble in the future:\n" + "If you get `CUDA error: invalid device function` errors, the above " + "might be the cause and the solution is to make sure only one " + f"{CUDA_RUNTIME_LIB} in the paths that we search based on your env." + ) + warn(warning_msg) + + +def determine_cuda_runtime_lib_path() -> Union[Path, None]: + """ + Searches for a cuda installations, in the following order of priority: + 1. active conda env + 2. LD_LIBRARY_PATH + 3. any other env vars, while ignoring those that + - are known to be unrelated (see `bnb.cuda_setup.env_vars.to_be_ignored`) + - don't contain the path separator `/` + + If multiple libraries are found in part 3, we optimistically try one, + while giving a warning message. + """ + candidate_env_vars = get_potentially_lib_path_containing_env_vars() + + if "CONDA_PREFIX" in candidate_env_vars: + conda_libs_path = Path(candidate_env_vars["CONDA_PREFIX"]) / "lib" + + conda_cuda_libs = find_cuda_lib_in(str(conda_libs_path)) + warn_in_case_of_duplicates(conda_cuda_libs) + + if conda_cuda_libs: + return next(iter(conda_cuda_libs)) + + warn( + f'{candidate_env_vars["CONDA_PREFIX"]} did not contain ' + f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...' + ) + + if "LD_LIBRARY_PATH" in candidate_env_vars: + lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"]) + + if lib_ld_cuda_libs: + return next(iter(lib_ld_cuda_libs)) + warn_in_case_of_duplicates(lib_ld_cuda_libs) + + warn( + f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain ' + f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...' + ) + + remaining_candidate_env_vars = { + env_var: value for env_var, value in candidate_env_vars.items() + if env_var not in {"CONDA_PREFIX", "LD_LIBRARY_PATH"} + } + + cuda_runtime_libs = set() + for env_var, value in remaining_candidate_env_vars: + cuda_runtime_libs.update(find_cuda_lib_in(value)) + + warn_in_case_of_duplicates(cuda_runtime_libs) + + return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else set() diff --git a/bitsandbytes/debug_cli.py b/bitsandbytes/debug_cli.py new file mode 100644 index 0000000..4306bc0 --- /dev/null +++ b/bitsandbytes/debug_cli.py @@ -0,0 +1,26 @@ +import typer + +cli = typer.Typer() + + +@cli.callback() +def callback(): + """ + Awesome Portal Gun + """ + + +@cli.command() +def shoot(): + """ + Shoot the portal gun + """ + typer.echo("Shooting portal gun") + + +@cli.command() +def load(): + """ + Load the portal gun + """ + typer.echo("Loading portal gun") diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index ad85f53..b4409e4 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1,6 +1,6 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import ctypes as ct import random @@ -10,47 +10,86 @@ import torch from typing import Tuple from torch import Tensor -from .cextension import lib, COMPILED_WITH_CUDA +from .cextension import COMPILED_WITH_CUDA, lib name2qmap = {} if COMPILED_WITH_CUDA: - ''' C FUNCTIONS FOR OPTIMIZERS ''' + """C FUNCTIONS FOR OPTIMIZERS""" str2optimizer32bit = {} - str2optimizer32bit['adam'] = (lib.cadam32bit_g32, lib.cadam32bit_g16) - str2optimizer32bit['momentum'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16) - str2optimizer32bit['rmsprop'] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16) - str2optimizer32bit['adagrad'] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16) - str2optimizer32bit['lars'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16) - str2optimizer32bit['lamb'] = (lib.cadam32bit_g32, lib.cadam32bit_g16) + str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16) + str2optimizer32bit["momentum"] = ( + lib.cmomentum32bit_g32, + lib.cmomentum32bit_g16, + ) + str2optimizer32bit["rmsprop"] = ( + lib.crmsprop32bit_g32, + lib.crmsprop32bit_g16, + ) + str2optimizer32bit["adagrad"] = ( + lib.cadagrad32bit_g32, + lib.cadagrad32bit_g16, + ) + str2optimizer32bit["lars"] = ( + lib.cmomentum32bit_g32, + lib.cmomentum32bit_g16, + ) + str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16) str2optimizer8bit = {} - str2optimizer8bit['adam'] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16) - str2optimizer8bit['momentum'] = (lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16) - str2optimizer8bit['rmsprop'] = (lib.crmsprop_static_8bit_g32, lib.crmsprop_static_8bit_g16) - str2optimizer8bit['lamb'] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16) - str2optimizer8bit['lars'] = (lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16) + str2optimizer8bit["adam"] = ( + lib.cadam_static_8bit_g32, + lib.cadam_static_8bit_g16, + ) + str2optimizer8bit["momentum"] = ( + lib.cmomentum_static_8bit_g32, + lib.cmomentum_static_8bit_g16, + ) + str2optimizer8bit["rmsprop"] = ( + lib.crmsprop_static_8bit_g32, + lib.crmsprop_static_8bit_g16, + ) + str2optimizer8bit["lamb"] = ( + lib.cadam_static_8bit_g32, + lib.cadam_static_8bit_g16, + ) + str2optimizer8bit["lars"] = ( + lib.cmomentum_static_8bit_g32, + lib.cmomentum_static_8bit_g16, + ) str2optimizer8bit_blockwise = {} - str2optimizer8bit_blockwise['adam'] = (lib.cadam_8bit_blockwise_fp32, lib.cadam_8bit_blockwise_fp16) - str2optimizer8bit_blockwise['momentum'] = (lib.cmomentum_8bit_blockwise_fp32, lib.cmomentum_8bit_blockwise_fp16) - str2optimizer8bit_blockwise['rmsprop'] = (lib.crmsprop_8bit_blockwise_fp32, lib.crmsprop_8bit_blockwise_fp16) - str2optimizer8bit_blockwise['adagrad'] = (lib.cadagrad_8bit_blockwise_fp32, lib.cadagrad_8bit_blockwise_fp16) + str2optimizer8bit_blockwise["adam"] = ( + lib.cadam_8bit_blockwise_fp32, + lib.cadam_8bit_blockwise_fp16, + ) + str2optimizer8bit_blockwise["momentum"] = ( + lib.cmomentum_8bit_blockwise_fp32, + lib.cmomentum_8bit_blockwise_fp16, + ) + str2optimizer8bit_blockwise["rmsprop"] = ( + lib.crmsprop_8bit_blockwise_fp32, + lib.crmsprop_8bit_blockwise_fp16, + ) + str2optimizer8bit_blockwise["adagrad"] = ( + lib.cadagrad_8bit_blockwise_fp32, + lib.cadagrad_8bit_blockwise_fp16, + ) class CUBLAS_Context(object): _instance = None def __init__(self): - raise RuntimeError('Call get_instance() instead') + raise RuntimeError("Call get_instance() instead") def initialize(self): self.context = {} - #prev_device = torch.cuda.current_device() - #for i in range(torch.cuda.device_count()): + # prev_device = torch.cuda.current_device() + # for i in range(torch.cuda.device_count()): # torch.cuda.set_device(torch.device('cuda', i)) # self.context.append(ct.c_void_p(lib.get_context())) - #torch.cuda.set_device(prev_device) + # torch.cuda.set_device(prev_device) @classmethod def get_instance(cls): @@ -67,11 +106,12 @@ class CUBLAS_Context(object): torch.cuda.set_device(prev_device) return self.context[device.index] + class Cusparse_Context(object): _instance = None def __init__(self): - raise RuntimeError('Call get_instance() instead') + raise RuntimeError("Call get_instance() instead") def initialize(self): self.context = ct.c_void_p(lib.get_cusparse()) @@ -83,14 +123,16 @@ class Cusparse_Context(object): cls._instance.initialize() return cls._instance + def create_linear_map(signed=True): if signed: return torch.linspace(-1.0, 1.0, 256) else: return torch.linspace(0.0, 1.0, 256) + def create_dynamic_map(signed=True, n=7): - ''' + """ Creates the dynamic quantiztion map. The dynamic data type is made up of a dynamic exponent and @@ -104,43 +146,53 @@ def create_dynamic_map(signed=True, n=7): For more details see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561] - ''' + """ data = [] # these are additional items that come from the case # where all the exponent bits are zero and no # indicator bit is present - additional_items = 2**(7-n)-1 - if not signed: additional_items = 2*additional_items + additional_items = 2 ** (7 - n) - 1 + if not signed: + additional_items = 2 * additional_items for i in range(n): - fraction_items = 2**(i+7-n)+1 if signed else 2**(i+7-n+1)+1 + fraction_items = ( + 2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1 + ) boundaries = torch.linspace(0.1, 1, fraction_items) - means = (boundaries[:-1]+boundaries[1:])/2.0 - data += ((10**(-(n-1)+i))*means).tolist() + means = (boundaries[:-1] + boundaries[1:]) / 2.0 + data += ((10 ** (-(n - 1) + i)) * means).tolist() if signed: - data += (-(10**(-(n-1)+i))*means).tolist() + data += (-(10 ** (-(n - 1) + i)) * means).tolist() if additional_items > 0: - boundaries = torch.linspace(0.1, 1, additional_items+1) - means = (boundaries[:-1]+boundaries[1:])/2.0 - data += ((10**(-(n-1)+i))*means).tolist() + boundaries = torch.linspace(0.1, 1, additional_items + 1) + means = (boundaries[:-1] + boundaries[1:]) / 2.0 + data += ((10 ** (-(n - 1) + i)) * means).tolist() if signed: - data += (-(10**(-(n-1)+i))*means).tolist() + data += (-(10 ** (-(n - 1) + i)) * means).tolist() data.append(0) data.append(1.0) data.sort() return Tensor(data) + def get_special_format_str(): major, minor = torch.cuda.get_device_capability() if major < 7: - print(f'Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!') + print( + f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!" + ) assert major >= 7 - if major == 7: return 'col_turing' - elif major == 8: return 'col_ampere' - else: return 'col_turing' + if major == 7: + return "col_turing" + elif major == 8: + return "col_ampere" + else: + return "col_turing" + def is_on_gpu(tensors): @@ -151,7 +203,7 @@ def is_on_gpu(tensors): return on_gpu def get_ptr(A: Tensor) -> ct.c_void_p: - ''' + """ Get the ctypes pointer from a PyTorch Tensor. Parameters @@ -162,31 +214,39 @@ def get_ptr(A: Tensor) -> ct.c_void_p: Returns ------- ctypes.c_void_p - ''' - if A is None: return None - else: return ct.c_void_p(A.data.storage().data_ptr()) + """ + if A is None: + return None + else: + return ct.c_void_p(A.data.storage().data_ptr()) + def pre_call(device): prev_device = torch.cuda.current_device() torch.cuda.set_device(device) return prev_device + def post_call(prev_device): torch.cuda.set_device(prev_device) + def get_transform_func(dtype, orderA, orderOut, transpose=False): name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}' if not hasattr(lib, name): print(name) - raise ValueError(f'Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}') + raise ValueError( + f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}" + ) else: return getattr(lib, name) + class GlobalData(object): _instance = None def __init__(self): - raise RuntimeError('Call get_instance() instead') + raise RuntimeError("Call get_instance() instead") def initialize(self): self.data = {} @@ -199,15 +259,17 @@ class GlobalData(object): return cls._instance -def get_transform_buffer(shape, dtype, device, to_order, from_order='row', transpose=False): - #init_func = torch.empty +def get_transform_buffer( + shape, dtype, device, to_order, from_order="row", transpose=False +): + # init_func = torch.empty init_func = torch.zeros dims = len(shape) if dims == 2: rows = shape[0] elif dims == 3: - rows = shape[0]*shape[1] + rows = shape[0] * shape[1] cols = shape[-1] state = (shape, to_order) @@ -218,30 +280,45 @@ def get_transform_buffer(shape, dtype, device, to_order, from_order='row', trans cols = tmp state = (shape[::-1], to_order) - if to_order == 'row' or to_order == 'col': + if to_order == "row" or to_order == "col": return init_func(shape, dtype=dtype, device=device), state - elif to_order == 'col32': + elif to_order == "col32": # blocks of 32 columns (padded) - cols = 32*((cols+31)//32) + cols = 32 * ((cols + 31) // 32) return init_func((rows, cols), dtype=dtype, device=device), state - elif to_order == 'col_turing': + elif to_order == "col_turing": # blocks of 32 columns and 8 rows - cols = 32*((cols+31)//32) - rows = 8*((rows+7)//8) + cols = 32 * ((cols + 31) // 32) + rows = 8 * ((rows + 7) // 8) return init_func((rows, cols), dtype=dtype, device=device), state - elif to_order == 'col_ampere': + elif to_order == "col_ampere": # blocks of 32 columns and 32 rows - cols = 32*((cols+31)//32) - rows = 32*((rows+31)//32) + cols = 32 * ((cols + 31) // 32) + rows = 32 * ((rows + 31) // 32) return init_func((rows, cols), dtype=dtype, device=device), state else: - raise NotImplementedError(f'To_order not supported: {to_order}') - -def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): - if state is None: state = (A.shape, from_order) - else: from_order = state[1] - if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1]) - else: new_state = (state[1], to_order) + raise NotImplementedError(f"To_order not supported: {to_order}") + + +def nvidia_transform( + A, + to_order, + from_order="row", + out=None, + transpose=False, + state=None, + ld=None, +): + if state is None: + state = (A.shape, from_order) + else: + from_order = state[1] + if out is None: + out, new_state = get_transform_buffer( + state[0], A.dtype, A.device, to_order, state[1] + ) + else: + new_state = (state[1], to_order) func = get_transform_func(A.dtype, from_order, to_order, transpose) shape = state[0] @@ -251,10 +328,10 @@ def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, s elif ld is not None: n = math.prod(shape) dim1 = math.prod([shape[i] for i in ld]) - dim2 = ct.c_int32(n//dim1) + dim2 = ct.c_int32(n // dim1) dim1 = ct.c_int32(dim1) else: - dim1 = ct.c_int32(shape[0]*shape[1]) + dim1 = ct.c_int32(shape[0] * shape[1]) dim2 = ct.c_int32(shape[2]) ptr = CUBLAS_Context.get_instance().get_context(A.device) @@ -262,10 +339,12 @@ def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, s ptrOut = get_ptr(out) func(ptr, get_ptr(A), get_ptr(out), dim1, dim2) - return out, new_state -def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tensor: + +def estimate_quantiles( + A: Tensor, out: Tensor = None, offset: float = 1 / 512 +) -> Tensor: ''' Estimates 256 equidistant quantiles on the input tensor eCDF. @@ -295,15 +374,26 @@ def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tens if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device) is_on_gpu([A, out]) if A.dtype == torch.float32: - lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) + lib.cestimate_quantiles_fp32( + get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()) + ) elif A.dtype == torch.float16: - lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) + lib.cestimate_quantiles_fp16( + get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()) + ) else: - raise NotImplementedError(f'Not supported data type {A.dtype}') + raise NotImplementedError(f"Not supported data type {A.dtype}") return out -def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=None, out: Tensor=None) -> Tensor: - ''' + +def quantize_blockwise( + A: Tensor, + code: Tensor = None, + absmax: Tensor = None, + rand=None, + out: Tensor = None, +) -> Tensor: + """ Quantize tensor A in blocks of size 4096 values. Quantizes tensor A by dividing it into blocks of 4096 values. @@ -329,22 +419,23 @@ def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=N The 8-bit tensor. tuple(torch.Tensor, torch.Tensor): The quantization state to undo the quantization. - ''' + """ if code is None: - if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device) - code = name2qmap['dynamic'] + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] code = code.to(A.device) if absmax is None: n = A.numel() num_blocks = 4096 - blocks = n//num_blocks + blocks = n // num_blocks blocks += 1 if n % num_blocks > 0 else 0 absmax = torch.zeros((blocks,), device=A.device) - if out is None: out = torch.zeros_like(A, dtype=torch.uint8) - + if out is None: + out = torch.zeros_like(A, dtype=torch.uint8) if A.device.type != 'cpu': is_on_gpu([code, A, absmax, out, rand]) @@ -352,29 +443,73 @@ def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=N assert rand.numel() >= 1024 rand_offset = random.randint(0, 1023) if A.dtype == torch.float32: - lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel())) + lib.cquantize_blockwise_stochastic_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + get_ptr(rand), + ct.c_int32(rand_offset), + ct.c_int(A.numel()), + ) elif A.dtype == torch.float16: - lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel())) + lib.cquantize_blockwise_stochastic_fp16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + get_ptr(rand), + ct.c_int32(rand_offset), + ct.c_int(A.numel()), + ) else: - raise ValueError(f'Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}') + raise ValueError( + f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" + ) else: if A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel())) + lib.cquantize_blockwise_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(A.numel()), + ) elif A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel())) + lib.cquantize_blockwise_fp16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(A.numel()), + ) else: - raise ValueError(f'Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}') + raise ValueError( + f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" + ) else: # cpu assert rand is None - lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel())) + lib.cquantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(A.numel()), + ) return out, (absmax, code) -def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None, - absmax: Tensor=None, code: Tensor=None, out: Tensor=None, - blocksize: int=4096) -> Tensor: - ''' + +def dequantize_blockwise( + A: Tensor, + quant_state: Tuple[Tensor, Tensor] = None, + absmax: Tensor = None, + code: Tensor = None, + out: Tensor = None, + blocksize: int = 4096, +) -> Tensor: + """ Dequantizes blockwise quantized values. Dequantizes the tensor A with maximum absolute values absmax in @@ -385,7 +520,7 @@ def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None, A : torch.Tensor The input 8-bit tensor. quant_state : tuple(torch.Tensor, torch.Tensor) - Tuple of code and absmax values. + Tuple of code and absmax values. absmax : torch.Tensor The absmax values. code : torch.Tensor @@ -398,57 +533,94 @@ def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None, ------- torch.Tensor: Dequantized tensor (default: float32) - ''' + """ assert quant_state is not None or absmax is not None if code is None and quant_state is None: - if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device) - code = name2qmap['dynamic'] + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] code = code.to(A.device) - if out is None: out = torch.zeros_like(A, dtype=torch.float32) - if quant_state is None: quant_state = (absmax, code) + if out is None: + out = torch.zeros_like(A, dtype=torch.float32) + if quant_state is None: + quant_state = (absmax, code) if blocksize not in [2048, 4096]: - raise ValueError(f'The blockwise of {blocksize} is not supported. Supported values: [2048 4096]') + raise ValueError( + f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]" + ) if A.device.type != 'cpu': is_on_gpu([A, out]) if out.dtype == torch.float32: - lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_fp32( + get_ptr(quant_state[1]), + get_ptr(A), + get_ptr(quant_state[0]), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + ) elif out.dtype == torch.float16: - lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_fp16( + get_ptr(quant_state[1]), + get_ptr(A), + get_ptr(quant_state[0]), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + ) else: - raise ValueError(f'Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}') + raise ValueError( + f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" + ) else: - lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(A.numel())) - + lib.cdequantize_blockwise_cpu_fp32( + get_ptr(quant_state[1]), + get_ptr(A), + get_ptr(quant_state[0]), + get_ptr(out), + ct.c_int(A.numel()), + ) return out -def quantize(A: Tensor, code: Tensor=None, out: Tensor=None) -> Tensor: +def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: if code is None: - if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device) - code = name2qmap['dynamic'] + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] code = code.to(A.device) absmax = torch.abs(A).max() - inp = A/absmax + inp = A / absmax out = quantize_no_absmax(inp, code, out) return out, (absmax, code) -def dequantize(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None, absmax: Tensor=None, code: Tensor=None, out: Tensor=None) -> Tensor: + +def dequantize( + A: Tensor, + quant_state: Tuple[Tensor, Tensor] = None, + absmax: Tensor = None, + code: Tensor = None, + out: Tensor = None, +) -> Tensor: assert quant_state is not None or absmax is not None if code is None and quant_state is None: - if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device) - code = name2qmap['dynamic'] + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] code = code.to(A.device) - if quant_state is None: quant_state = (absmax, code) + if quant_state is None: + quant_state = (absmax, code) out = dequantize_no_absmax(A, quant_state[1], out) - return out*quant_state[0] + return out * quant_state[0] + -def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor: +def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: ''' Quantizes input tensor to 8-bit. @@ -474,7 +646,8 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor: lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) return out -def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor: + +def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: ''' Dequantizes the 8-bit tensor to 32-bit. @@ -500,12 +673,25 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor: lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) return out -def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Tensor, - beta1: float, eps: float, step: int, lr: float, - state2: Tensor=None, beta2: float=0.0, - weight_decay: float=0.0, gnorm_scale: float=1.0, - unorm_vec: Tensor=None, max_unorm: float=0.0, skip_zeros=False) -> None: - ''' + +def optimizer_update_32bit( + optimizer_name: str, + g: Tensor, + p: Tensor, + state1: Tensor, + beta1: float, + eps: float, + step: int, + lr: float, + state2: Tensor = None, + beta2: float = 0.0, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Tensor = None, + max_unorm: float = 0.0, + skip_zeros=False, +) -> None: + """ Performs an inplace optimizer update with one or two optimizer states. Universal optimizer update for 32-bit state and 32/16-bit gradients/weights. @@ -542,33 +728,84 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten The maximum update norm relative to the weight norm. skip_zeros : bool Whether to skip zero-valued gradients or not (default: False). - ''' + """ param_norm = 0.0 if max_unorm > 0.0: param_norm = torch.norm(p.data.float()) if optimizer_name not in str2optimizer32bit: - raise NotImplementedError(f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}') + raise NotImplementedError( + f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}' + ) if g.dtype == torch.float32 and state1.dtype == torch.float32: - str2optimizer32bit[optimizer_name][0](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm), - ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay), - ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel())) + str2optimizer32bit[optimizer_name][0]( + get_ptr(g), + get_ptr(p), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_float(weight_decay), + ct.c_int32(step), + ct.c_float(lr), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) elif g.dtype == torch.float16 and state1.dtype == torch.float32: - str2optimizer32bit[optimizer_name][1](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm), - ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay), - ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel())) + str2optimizer32bit[optimizer_name][1]( + get_ptr(g), + get_ptr(p), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_float(weight_decay), + ct.c_int32(step), + ct.c_float(lr), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) else: - raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}') - -def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Tensor, state2: Tensor, - beta1: float, beta2: float, eps: float, - step: int, lr: float, qmap1: Tensor, qmap2: Tensor, - max1: Tensor, max2: Tensor, new_max1: Tensor, new_max2: Tensor, - weight_decay: float=0.0, gnorm_scale: float=1.0, - unorm_vec: Tensor=None, max_unorm: float=0.0) -> None: - ''' + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" + ) + + +def optimizer_update_8bit( + optimizer_name: str, + g: Tensor, + p: Tensor, + state1: Tensor, + state2: Tensor, + beta1: float, + beta2: float, + eps: float, + step: int, + lr: float, + qmap1: Tensor, + qmap2: Tensor, + max1: Tensor, + max2: Tensor, + new_max1: Tensor, + new_max2: Tensor, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Tensor = None, + max_unorm: float = 0.0, +) -> None: + """ Performs an inplace Adam update. Universal Adam update for 32/8-bit state and 32/16-bit gradients/weights. @@ -616,56 +853,135 @@ def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Ten The tensor for the update norm. max_unorm : float The maximum update norm relative to the weight norm. - ''' + """ param_norm = 0.0 if max_unorm > 0.0: param_norm = torch.norm(p.data.float()) if g.dtype == torch.float32 and state1.dtype == torch.uint8: - str2optimizer8bit[optimizer_name][0](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2), - get_ptr(unorm_vec), ct.c_float(max_unorm), ct.c_float(param_norm), - ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), - ct.c_int32(step), ct.c_float(lr), - get_ptr(qmap1), get_ptr(qmap2), - get_ptr(max1), get_ptr(max2), get_ptr(new_max1), get_ptr(new_max2), - ct.c_float(weight_decay),ct.c_float(gnorm_scale), ct.c_int32(g.numel())) + str2optimizer8bit[optimizer_name][0]( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(max1), + get_ptr(max2), + get_ptr(new_max1), + get_ptr(new_max2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_int32(g.numel()), + ) elif g.dtype == torch.float16 and state1.dtype == torch.uint8: - str2optimizer8bit[optimizer_name][1](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2), - get_ptr(unorm_vec), ct.c_float(max_unorm), ct.c_float(param_norm), - ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), - ct.c_int32(step), ct.c_float(lr), - get_ptr(qmap1), get_ptr(qmap2), - get_ptr(max1), get_ptr(max2), get_ptr(new_max1), get_ptr(new_max2), - ct.c_float(weight_decay),ct.c_float(gnorm_scale), ct.c_int32(g.numel())) + str2optimizer8bit[optimizer_name][1]( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(max1), + get_ptr(max2), + get_ptr(new_max1), + get_ptr(new_max2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_int32(g.numel()), + ) else: - raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}') - - -def optimizer_update_8bit_blockwise(optimizer_name: str, g: Tensor, p: Tensor, state1: Tensor, state2: Tensor, - beta1: float, beta2: float, eps: float, - step: int, lr: float, qmap1: Tensor, qmap2: Tensor, - absmax1: Tensor, absmax2: Tensor, weight_decay: float=0.0, gnorm_scale: float=1.0, - skip_zeros=False) -> None: - + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" + ) + + +def optimizer_update_8bit_blockwise( + optimizer_name: str, + g: Tensor, + p: Tensor, + state1: Tensor, + state2: Tensor, + beta1: float, + beta2: float, + eps: float, + step: int, + lr: float, + qmap1: Tensor, + qmap2: Tensor, + absmax1: Tensor, + absmax2: Tensor, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, +) -> None: if g.dtype == torch.float32 and state1.dtype == torch.uint8: - str2optimizer8bit_blockwise[optimizer_name][0](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2), - ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), - ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2), - get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), ct.c_int32(g.numel())) + str2optimizer8bit_blockwise[optimizer_name][0]( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(absmax1), + get_ptr(absmax2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) elif g.dtype == torch.float16 and state1.dtype == torch.uint8: - str2optimizer8bit_blockwise[optimizer_name][1](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2), - ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), - ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2), - get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), ct.c_int32(g.numel())) + str2optimizer8bit_blockwise[optimizer_name][1]( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(absmax1), + get_ptr(absmax2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) else: - raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}') + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" + ) -def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int=5): +def percentile_clipping( + grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5 +): """Applies percentile clipping grad: torch.Tensor @@ -678,11 +994,21 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: """ is_on_gpu([grad, gnorm_vec]) if grad.dtype == torch.float32: - lib.cpercentile_clipping_g32(get_ptr(grad), get_ptr(gnorm_vec), ct.c_int32(step), ct.c_int32(grad.numel())) + lib.cpercentile_clipping_g32( + get_ptr(grad), + get_ptr(gnorm_vec), + ct.c_int32(step), + ct.c_int32(grad.numel()), + ) elif grad.dtype == torch.float16: - lib.cpercentile_clipping_g16(get_ptr(grad), get_ptr(gnorm_vec), ct.c_int32(step), ct.c_int32(grad.numel())) + lib.cpercentile_clipping_g16( + get_ptr(grad), + get_ptr(gnorm_vec), + ct.c_int32(step), + ct.c_int32(grad.numel()), + ) else: - raise ValueError(f'Gradient type {grad.dtype} not supported!') + raise ValueError(f"Gradient type {grad.dtype} not supported!") current_gnorm = torch.sqrt(gnorm_vec[step % 100]) vals, idx = torch.sort(gnorm_vec) @@ -690,22 +1016,24 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: gnorm_scale = 1.0 if current_gnorm > clip_value: - gnorm_scale = clip_value/current_gnorm + gnorm_scale = clip_value / current_gnorm return current_gnorm, clip_value, gnorm_scale -def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor): +def histogram_scatter_add_2d( + histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor +): assert len(histogram.shape) == 2 assert histogram.dtype == torch.float32 assert source.dtype == torch.float32 assert index1.dtype == torch.int32 assert index2.dtype == torch.int32 - assert histogram.device.type == 'cuda' - assert index1.device.type == 'cuda' - assert index2.device.type == 'cuda' - assert source.device.type == 'cuda' + assert histogram.device.type == "cuda" + assert index1.device.type == "cuda" + assert index2.device.type == "cuda" + assert source.device.type == "cuda" maxdim1 = ct.c_int32(histogram.shape[0]) n = ct.c_int32(index1.numel()) @@ -715,7 +1043,9 @@ def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8): if not torch.cuda.is_initialized(): torch.cuda.init() if A.dtype != expected_type or B.dtype != expected_type: - raise TypeError(f'Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}') + raise TypeError( + f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}" + ) sA = A.shape sB = B.shape @@ -725,64 +1055,105 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 correct = True if len(sA) == 2 and len(sB) == 2: - if not tA and not tB and A.shape[1] != B.shape[0]: correct = False - elif tA and not tB and A.shape[0] != B.shape[0]: correct = False - elif tA and tB and A.shape[0] != B.shape[1]: correct = False - elif not tA and tB and A.shape[1] != B.shape[1]: correct = False + if not tA and not tB and A.shape[1] != B.shape[0]: + correct = False + elif tA and not tB and A.shape[0] != B.shape[0]: + correct = False + elif tA and tB and A.shape[0] != B.shape[1]: + correct = False + elif not tA and tB and A.shape[1] != B.shape[1]: + correct = False elif len(sA) == 3 and len(sB) == 2: - if not tA and not tB and A.shape[2] != B.shape[0]: correct = False - elif tA and not tB and A.shape[1] != B.shape[0]: correct = False - elif tA and tB and A.shape[1] != B.shape[1]: correct = False - elif not tA and tB and A.shape[2] != B.shape[1]: correct = False + if not tA and not tB and A.shape[2] != B.shape[0]: + correct = False + elif tA and not tB and A.shape[1] != B.shape[0]: + correct = False + elif tA and tB and A.shape[1] != B.shape[1]: + correct = False + elif not tA and tB and A.shape[2] != B.shape[1]: + correct = False elif len(sA) == 3 and len(sB) == 3: - if not tA and not tB and A.shape[2] != B.shape[1]: correct = False - elif tA and not tB and A.shape[1] != B.shape[1]: correct = False - elif tA and tB and A.shape[1] != B.shape[2]: correct = False - elif not tA and tB and A.shape[2] != B.shape[2]: correct = False + if not tA and not tB and A.shape[2] != B.shape[1]: + correct = False + elif tA and not tB and A.shape[1] != B.shape[1]: + correct = False + elif tA and tB and A.shape[1] != B.shape[2]: + correct = False + elif not tA and tB and A.shape[2] != B.shape[2]: + correct = False if out is not None: sout = out.shape # special case common in backprop if not correct and len(sA) == 3 and len(sB) == 3: - if (sout[0] == sA[2] and sout[1] == sB[2] and - sA[0] == sB[0] and sA[1] == sB[1]): + if ( + sout[0] == sA[2] + and sout[1] == sB[2] + and sA[0] == sB[0] + and sA[1] == sB[1] + ): correct = True else: if len(sA) == 2 and len(sB) == 2: - if not tA and not tB: sout = (sA[0], sB[1]) - elif tA and tB: sout = (sA[1], sB[0]) - elif tA and not tB: sout = (sA[1], sB[1]) - elif not tA and tB: sout = (sA[0], sB[0]) + if not tA and not tB: + sout = (sA[0], sB[1]) + elif tA and tB: + sout = (sA[1], sB[0]) + elif tA and not tB: + sout = (sA[1], sB[1]) + elif not tA and tB: + sout = (sA[0], sB[0]) elif len(sA) == 3 and len(sB) == 2: - if not tA and not tB: sout = (sA[0], sA[1], sB[1]) - elif tA and tB: sout = (sA[0], sA[2], sB[0]) - elif tA and not tB: sout = (sA[0], sA[2], sB[1]) - elif not tA and tB: sout = (sA[0], sA[1], sB[0]) + if not tA and not tB: + sout = (sA[0], sA[1], sB[1]) + elif tA and tB: + sout = (sA[0], sA[2], sB[0]) + elif tA and not tB: + sout = (sA[0], sA[2], sB[1]) + elif not tA and tB: + sout = (sA[0], sA[1], sB[0]) elif len(sA) == 3 and len(sB) == 3: - if not tA and not tB: sout = (sA[0], sA[1], sB[2]) - elif tA and tB: sout = (sA[0], sA[2], sB[1]) - elif tA and not tB: sout = (sA[0], sA[2], sB[2]) - elif not tA and tB: sout = (sA[0], sA[1], sB[1]) - + if not tA and not tB: + sout = (sA[0], sA[1], sB[2]) + elif tA and tB: + sout = (sA[0], sA[2], sB[1]) + elif tA and not tB: + sout = (sA[0], sA[2], sB[2]) + elif not tA and tB: + sout = (sA[0], sA[1], sB[1]) if not correct: - raise ValueError(f'Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}.') + raise ValueError( + f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}." + ) return sout -def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed_B=False): + +def igemm( + A: Tensor, + B: Tensor, + out: Tensor = None, + transposed_A=False, + transposed_B=False, +): sout = check_matmul(A, B, out, transposed_A, transposed_B) - if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) + if out is None: + out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) if len(A.shape) == 3 and len(B.shape) == 3: if A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1]: return batched_igemm(A, B, out) sA = A.shape sB = B.shape - if transposed_A and len(sA) == 2: sA = (sA[1], sA[0]) - elif transposed_A and len(sA) == 3: sA = (sA[0], sA[2], sA[0]) - if transposed_B and len(sB) == 2: sB = (sB[1], sB[0]) - elif transposed_B and len(sB) == 3: sB = (sB[0], sB[2], sB[0]) + if transposed_A and len(sA) == 2: + sA = (sA[1], sA[0]) + elif transposed_A and len(sA) == 3: + sA = (sA[0], sA[2], sA[0]) + if transposed_B and len(sB) == 2: + sB = (sB[1], sB[0]) + elif transposed_B and len(sB) == 3: + sB = (sB[0], sB[2], sB[0]) # this is a mess: cuBLAS expect column major, but PyTorch is row major. # So to perform the matrix multiplication, we have to treat A, B, and C matrices # (transpose of row major is column major) @@ -793,23 +1164,28 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n] # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m] if len(sB) == 2: - if B.stride()[0] == B.shape[1]: transposed_B = False - elif B.stride()[1] == B.shape[0]: transposed_B = True + if B.stride()[0] == B.shape[1]: + transposed_B = False + elif B.stride()[1] == B.shape[0]: + transposed_B = True if len(A.shape) == 2: - if A.stride()[0] == A.shape[1]: transposed_A = False - elif A.stride()[1] == A.shape[0]: transposed_A = True + if A.stride()[0] == A.shape[1]: + transposed_A = False + elif A.stride()[1] == A.shape[0]: + transposed_A = True else: - if A.stride()[1] == A.shape[2]: transposed_A = False - elif A.stride()[2] == A.shape[1]: transposed_A = True + if A.stride()[1] == A.shape[2]: + transposed_A = False + elif A.stride()[2] == A.shape[1]: + transposed_A = True if len(sA) == 2: n = sA[0] ldb = A.stride()[1 if transposed_A else 0] elif len(sA) == 3 and len(sB) == 2: - n = sA[0]*sA[1] + n = sA[0] * sA[1] ldb = sA[2] - m = sB[1] k = sB[0] lda = B.stride()[(1 if transposed_B else 0)] @@ -818,20 +1194,21 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed # special case assert len(sA) == 3 if not (sA[0] == sB[0] and sA[1] == sB[1]): - raise ValueError(f'Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}') + raise ValueError( + f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}" + ) transposed_A = True transposed_B = False m = sB[2] n = sA[2] - k = sB[0]*sB[1] + k = sB[0] * sB[1] lda = m ldb = sA[2] ldc = m - ptr = CUBLAS_Context.get_instance().get_context(A.device) # B^T @ A^T = C^T @@ -842,11 +1219,20 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed return out -def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed_B=False): +def batched_igemm( + A: Tensor, + B: Tensor, + out: Tensor = None, + transposed_A=False, + transposed_B=False, +): if not len(A.shape) == 3 or not len(B.shape) == 3: - raise ValueError(f'Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}') + raise ValueError( + f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}" + ) sout = check_matmul(A, B, out, transposed_A, transposed_B) - if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) + if out is None: + out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) if B.is_contiguous(): lda = B.stride()[1] @@ -903,9 +1289,9 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr ldc = m - strideA = B.shape[1]*B.shape[2] - strideB = A.shape[1]*A.shape[2] - strideC = A.shape[1]*B.shape[2] + strideA = B.shape[1] * B.shape[2] + strideB = A.shape[1] * A.shape[2] + strideC = A.shape[1] * B.shape[2] ptr = CUBLAS_Context.get_instance().get_context(A.device) @@ -915,6 +1301,7 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch)) return out + def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): shapeA = SA[0] shapeB = SB[0] @@ -924,7 +1311,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): if dimsA == 2: m = shapeA[0] elif dimsA == 3: - m = shapeA[0]*shapeA[1] + m = shapeA[0] * shapeA[1] rows = n = shapeB[0] assert math.prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}' @@ -936,20 +1323,26 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16) if dimsA == 2 and out is None: - out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, 'col32', 'row') + out, Sout = get_transform_buffer( + (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row" + ) elif dimsA == 3 and out is None: - out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, 'col32', 'row') + out, Sout = get_transform_buffer( + (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row" + ) - assert dimsB != 3, 'len(B.shape)==3 not supported' - assert A.device.type == 'cuda' - assert B.device.type == 'cuda' + assert dimsB != 3, "len(B.shape)==3 not supported" + assert A.device.type == "cuda" + assert B.device.type == "cuda" assert A.dtype == torch.int8 assert B.dtype == torch.int8 assert out.dtype == dtype - assert SA[1] == 'col32' - assert SB[1] in ['col_turing', 'col_ampere'] - assert Sout[1] == 'col32' - assert shapeA[-1] == shapeB[-1], f'Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}' + assert SA[1] == "col32" + assert SB[1] in ["col_turing", "col_ampere"] + assert Sout[1] == "col32" + assert ( + shapeA[-1] == shapeB[-1] + ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}" formatB = SB[1] prev_device = A.device torch.cuda.set_device(A.device) @@ -960,17 +1353,17 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): ptrC = get_ptr(out) k = shapeA[-1] - lda = ct.c_int32(m*32) - if formatB == 'col_turing': + lda = ct.c_int32(m * 32) + if formatB == "col_turing": # turing: tiles with rows filled up to multiple of 8 rows by 32 columns # n = rows - ldb = ct.c_int32(((rows+7)//8)*8*32) + ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32) else: # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns # n = rows - ldb = ct.c_int32(((rows+31)//32)*32*32) + ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32) - ldc = ct.c_int32(m*32) + ldc = ct.c_int32(m * 32) m = ct.c_int32(m) n = ct.c_int32(n) k = ct.c_int32(k) @@ -980,14 +1373,22 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): is_on_gpu([A, B, out]) if formatB == 'col_turing': if dtype == torch.int32: - has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + has_error = lib.cigemmlt_turing_32( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) else: - has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) - elif formatB == 'col_ampere': + has_error = lib.cigemmlt_turing_8( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) + elif formatB == "col_ampere": if dtype == torch.int32: - has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + has_error = lib.cigemmlt_ampere_32( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) else: - has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + has_error = lib.cigemmlt_ampere_8( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) if has_error == 1: print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}') @@ -995,20 +1396,39 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): torch.cuda.set_device(prev_device) - return out, Sout -def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None): +def mm_dequant( + A, + quant_state, + row_stats, + col_stats, + out=None, + new_row_stats=None, + new_col_stats=None, +): assert A.dtype == torch.int32 out_shape = quant_state[0] - if len(out_shape) == 3: out_shape = (out_shape[0]*out_shape[1], out_shape[2]) - - if out is None: out = torch.empty(out_shape, dtype=torch.float16, device=A.device) - if new_row_stats is None: new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device) - if new_col_stats is None: new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device) - assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}" - assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}" + if len(out_shape) == 3: + out_shape = (out_shape[0] * out_shape[1], out_shape[2]) + + if out is None: + out = torch.empty(out_shape, dtype=torch.float16, device=A.device) + if new_row_stats is None: + new_row_stats = torch.empty( + out_shape[0], dtype=torch.float32, device=A.device + ) + if new_col_stats is None: + new_col_stats = torch.empty( + out_shape[1], dtype=torch.float32, device=A.device + ) + assert ( + new_row_stats.shape[0] == row_stats.shape[0] + ), f"{new_row_stats.shape} vs {row_stats.shape}" + assert ( + new_col_stats.shape[0] == col_stats.shape[0] + ), f"{new_col_stats.shape} vs {col_stats.shape}" ptrA = get_ptr(A) ptrOut = get_ptr(out) @@ -1025,22 +1445,33 @@ def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=Non return out -def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0): +def get_colrow_absmax( + A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0 +): assert A.dtype == torch.float16 device = A.device cols = A.shape[-1] if len(A.shape) == 3: - rows = A.shape[0]*A.shape[1] + rows = A.shape[0] * A.shape[1] else: rows = A.shape[0] - col_tiles = (cols+255)//256 - tiled_rows = ((rows+15)//16)*16 - if row_stats is None: row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0) - if col_stats is None: col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0) - - if nnz_block_ptr is None and threshold > 0.0: nnz_block_ptr = torch.zeros(((tiled_rows*col_tiles)+1,), dtype=torch.int32, device=device) + col_tiles = (cols + 255) // 256 + tiled_rows = ((rows + 15) // 16) * 16 + if row_stats is None: + row_stats = torch.empty( + (rows,), dtype=torch.float32, device=device + ).fill_(-50000.0) + if col_stats is None: + col_stats = torch.empty( + (cols,), dtype=torch.float32, device=device + ).fill_(-50000.0) + + if nnz_block_ptr is None and threshold > 0.0: + nnz_block_ptr = torch.zeros( + ((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device + ) ptrA = get_ptr(A) ptrRowStats = get_ptr(row_stats) @@ -1054,13 +1485,12 @@ def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, thr lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols) post_call(prev_device) - if threshold > 0.0: nnz_block_ptr.cumsum_(0) - return row_stats, col_stats, nnz_block_ptr + class COOSparseTensor(object): def __init__(self, rows, cols, nnz, rowidx, colidx, values): assert rowidx.dtype == torch.int32 @@ -1077,6 +1507,7 @@ class COOSparseTensor(object): self.colidx = colidx self.values = values + class CSRSparseTensor(object): def __init__(self, rows, cols, nnz, rowptr, colidx, values): assert rowptr.dtype == torch.int32 @@ -1084,7 +1515,7 @@ class CSRSparseTensor(object): assert values.dtype == torch.float16 assert values.numel() == nnz assert colidx.numel() == nnz - assert rowptr.numel() == rows+1 + assert rowptr.numel() == rows + 1 self.rows = rows self.cols = cols @@ -1093,6 +1524,7 @@ class CSRSparseTensor(object): self.colidx = colidx self.values = values + class CSCSparseTensor(object): def __init__(self, rows, cols, nnz, colptr, rowidx, values): assert colptr.dtype == torch.int32 @@ -1100,7 +1532,7 @@ class CSCSparseTensor(object): assert values.dtype == torch.float16 assert values.numel() == nnz assert rowidx.numel() == nnz - assert colptr.numel() == cols+1 + assert colptr.numel() == cols + 1 self.rows = rows self.cols = cols @@ -1109,13 +1541,19 @@ class CSCSparseTensor(object): self.rowidx = rowidx self.values = values + def coo2csr(cooA): values, counts = torch.unique(cooA.rowidx, return_counts=True) values.add_(1) - rowptr = torch.zeros((cooA.rows+1, ), dtype=torch.int32, device=cooA.rowidx.device) + rowptr = torch.zeros( + (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device + ) rowptr.scatter_(index=values.long(), src=counts.int(), dim=0) rowptr.cumsum_(0) - return CSRSparseTensor(cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values) + return CSRSparseTensor( + cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values + ) + def coo2csc(cooA): val, col2rowidx = torch.sort(cooA.colidx) @@ -1123,10 +1561,15 @@ def coo2csc(cooA): values = cooA.values[col2rowidx] colvalues, counts = torch.unique(val, return_counts=True) colvalues.add_(1) - colptr = torch.zeros((cooA.cols+1, ), dtype=torch.int32, device=cooA.colidx.device) + colptr = torch.zeros( + (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device + ) colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0) colptr.cumsum_(0) - return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values) + return CSCSparseTensor( + cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values + ) + def coo_zeros(rows, cols, nnz, device, dtype=torch.half): rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device) @@ -1135,23 +1578,29 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) -def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): +def double_quant( + A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 +): device = A.device assert A.dtype == torch.half - assert device.type == 'cuda' + assert device.type == "cuda" prev_device = pre_call(A.device) cols = A.shape[-1] if len(A.shape) == 3: - rows = A.shape[0]*A.shape[1] + rows = A.shape[0] * A.shape[1] else: rows = A.shape[0] if row_stats is None or col_stats is None: - row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold) + row_stats, col_stats, nnz_row_ptr = get_colrow_absmax( + A, threshold=threshold + ) - if out_col is None: out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) - if out_row is None: out_row = torch.zeros(A.shape, device=device, dtype=torch.int8) + if out_col is None: + out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) + if out_row is None: + out_row = torch.zeros(A.shape, device=device, dtype=torch.int8) coo_tensor = None ptrA = get_ptr(A) @@ -1164,21 +1613,62 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, if threshold > 0.0: nnz = nnz_row_ptr[-1].item() if nnz > 0: - coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device) + coo_tensor = coo_zeros( + A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device + ) ptrRowIdx = get_ptr(coo_tensor.rowidx) ptrColIdx = get_ptr(coo_tensor.colidx) ptrVal = get_ptr(coo_tensor.values) ptrRowPtr = get_ptr(nnz_row_ptr) - lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, ptrRowIdx, ptrColIdx, ptrVal, ptrRowPtr, ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols)) + lib.cdouble_rowcol_quant( + ptrA, + ptrRowStats, + ptrColStats, + ptrOutCol, + ptrOutRow, + ptrRowIdx, + ptrColIdx, + ptrVal, + ptrRowPtr, + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + ) val, idx = torch.sort(coo_tensor.rowidx) coo_tensor.rowidx = val coo_tensor.colidx = coo_tensor.colidx[idx] coo_tensor.values = coo_tensor.values[idx] else: - lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, None, None, None, None, ct.c_float(0.0), ct.c_int32(rows), ct.c_int32(cols)) + lib.cdouble_rowcol_quant( + ptrA, + ptrRowStats, + ptrColStats, + ptrOutCol, + ptrOutRow, + None, + None, + None, + None, + ct.c_float(0.0), + ct.c_int32(rows), + ct.c_int32(cols), + ) else: - lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, None, None, None, None, ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols)) + lib.cdouble_rowcol_quant( + ptrA, + ptrRowStats, + ptrColStats, + ptrOutCol, + ptrOutRow, + None, + None, + None, + None, + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + ) post_call(prev_device) return out_row, out_col, row_stats, col_stats, coo_tensor @@ -1187,7 +1677,9 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, def get_special_format_str(): major, minor = torch.cuda.get_device_capability() if major < 7: - print(f'Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!') + print( + f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!" + ) assert major >= 7 if major == 7: return 'col_turing' @@ -1209,7 +1701,7 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No dim1 = ct.c_int32(shape[0]) dim2 = ct.c_int32(shape[1]) else: - dim1 = ct.c_int32(shape[0]*shape[1]) + dim1 = ct.c_int32(shape[0] * shape[1]) dim2 = ct.c_int32(shape[2]) ptrA = get_ptr(A) @@ -1220,20 +1712,20 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) else: lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == 'col_turing': + elif to_order == "col_turing": if transpose: lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2) else: lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == 'col_ampere': + elif to_order == "col_ampere": if transpose: lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2) else: lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == 'row': - if from_order == 'col_turing': + elif to_order == "row": + if from_order == "col_turing": lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2) - elif from_order == 'col_ampere': + elif from_order == "col_ampere": lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) else: raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') @@ -1242,15 +1734,19 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No return out, new_state + def spmm_coo(cooA, B, out=None): - if out is None: out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype) + if out is None: + out = torch.empty( + (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype + ) nnz = cooA.nnz assert cooA.rowidx.numel() == nnz assert cooA.colidx.numel() == nnz assert cooA.values.numel() == nnz assert cooA.cols == B.shape[0] - transposed_B = (False if B.is_contiguous() else True) + transposed_B = False if B.is_contiguous() else True ldb = B.stride()[(1 if transposed_B else 0)] ldc = B.shape[1] @@ -1274,15 +1770,19 @@ def spmm_coo(cooA, B, out=None): return out + def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): - if out is None: out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype) + if out is None: + out = torch.zeros( + (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype + ) nnz = cooA.nnz assert cooA.rowidx.numel() == nnz assert cooA.colidx.numel() == nnz assert cooA.values.numel() == nnz - assert cooA.cols == B.shape[0], f'{cooA.cols} vs {B.shape}' + assert cooA.cols == B.shape[0], f"{cooA.cols} vs {B.shape}" - transposed_B = (False if B.is_contiguous() else True) + transposed_B = False if B.is_contiguous() else True ldb = B.stride()[(1 if transposed_B else 0)] ldc = B.shape[1] @@ -1292,7 +1792,9 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): max_count, max_idx = torch.sort(counts, descending=True) max_idx = max_idx.int() max_count = max_count.int() - assert max_count[0] <= 32, f'Current max count per row is 8 but found {max_count[0]}.' + assert ( + max_count[0] <= 32 + ), f"Current max count per row is 8 but found {max_count[0]}." assert B.dtype in [torch.float16, torch.int8] ptrOffset = get_ptr(offset) ptrMaxCount = get_ptr(max_count) @@ -1312,137 +1814,188 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): ccolsB = ct.c_int32(B.shape[1]) cldb = ct.c_int32(ldb) cldc = ct.c_int32(ldc) - #print(cooA.rowidx[:64]) - #print(cooA.colidx[:64].sort()[0]) + # print(cooA.rowidx[:64]) + # print(cooA.colidx[:64].sort()[0]) is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats]) if B.dtype == torch.float16: - lib.cspmm_coo_very_sparse_naive_fp16(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB) + lib.cspmm_coo_very_sparse_naive_fp16( + ptrMaxCount, + ptrMaxIdx, + ptrOffset, + ptrRowidx, + ptrColidx, + ptrValues, + ptrB, + ptrC, + ptrDequantStats, + cnnz_rows, + cnnz, + crowsA, + crowsB, + ccolsB, + ) elif B.dtype == torch.int8: - lib.cspmm_coo_very_sparse_naive_int8(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB) - #else: assertion error + lib.cspmm_coo_very_sparse_naive_int8( + ptrMaxCount, + ptrMaxIdx, + ptrOffset, + ptrRowidx, + ptrColidx, + ptrValues, + ptrB, + ptrC, + ptrDequantStats, + cnnz_rows, + cnnz, + crowsA, + crowsB, + ccolsB, + ) + # else: assertion error return out C = 127.0 -def vectorwise_quant(x, dim=1, quant_type='vector'): - if quant_type == 'linear': + +def vectorwise_quant(x, dim=1, quant_type="vector"): + if quant_type == "linear": max1 = torch.abs(x).max().float() - xq = torch.round(x/max1*127).to(torch.int8) + xq = torch.round(x / max1 * 127).to(torch.int8) return xq, max1 - elif quant_type in ['vector', 'row']: + elif quant_type in ["vector", "row"]: max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) - xq = torch.round(x*(C/max1)).to(torch.int8) + xq = torch.round(x * (C / max1)).to(torch.int8) return xq, max1 - elif quant_type == 'zeropoint': + elif quant_type == "zeropoint": dtype = x.dtype x = x.float() dyna = x.max() - x.min() - if dyna == 0: dyna = 1 - qx = 255./dyna + if dyna == 0: + dyna = 1 + qx = 255.0 / dyna minx = x.min() - zpx = torch.round(minx* qx) - x = torch.round(qx*x - zpx) + zpx + zpx = torch.round(minx * qx) + x = torch.round(qx * x - zpx) + zpx return x, qx - elif quant_type in ['vector-zeropoint', 'row-zeropoint']: + elif quant_type in ["vector-zeropoint", "row-zeropoint"]: dtype = x.dtype x = x.float() - dyna = (torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True)) - dyna[dyna==0] = 1 - qx = 255./dyna + dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin( + x, dim=dim, keepdim=True + ) + dyna[dyna == 0] = 1 + qx = 255.0 / dyna minx = torch.amin(x, dim=dim, keepdim=True) - zpx = torch.round(minx* qx) - x = torch.round(qx*x - zpx) + zpx + zpx = torch.round(minx * qx) + x = torch.round(qx * x - zpx) + zpx return x, qx - elif quant_type == 'truncated-vector': + elif quant_type == "truncated-vector": with torch.no_grad(): absx = torch.abs(x) max1 = torch.amax(absx, dim=dim, keepdim=True) - max1 = max1*0.7 - idx = (absx > max1.expand_as(absx)) + max1 = max1 * 0.7 + idx = absx > max1.expand_as(absx) sign = torch.sign(x[idx]) - x[idx] = max1.expand_as(absx)[idx]*sign - xq = torch.round(x/max1*C).to(torch.int8) + x[idx] = max1.expand_as(absx)[idx] * sign + xq = torch.round(x / max1 * C).to(torch.int8) return xq, max1 - else: return None + else: + return None + -def vectorwise_dequant(xq, max1, quant_type='vector'): - if quant_type == 'vector': - x = (xq/C*max1).to(torch.float32) +def vectorwise_dequant(xq, max1, quant_type="vector"): + if quant_type == "vector": + x = (xq / C * max1).to(torch.float32) return x - else: return None + else: + return None + -def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type='vector'): - if quant_type == 'linear': - norm = S1*S2/(C*C) +def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"): + if quant_type == "linear": + norm = S1 * S2 / (C * C) # double cast needed to prevent overflows - return (xq.float()*norm).to(dtype) - elif quant_type == 'zeropoint': - norm = 1.0/(S1*S2) - return (xq.float()*norm).to(dtype) - elif quant_type == 'row-zeropoint': - norm = 1.0/(S1*S2) + return (xq.float() * norm).to(dtype) + elif quant_type == "zeropoint": + norm = 1.0 / (S1 * S2) + return (xq.float() * norm).to(dtype) + elif quant_type == "row-zeropoint": + norm = 1.0 / (S1 * S2) x = xq.float() - if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0) - if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0) + if len(S1.shape) == 3 and len(x.shape) == 2: + S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: + S2 = S2.squeeze(0) if len(S1.shape) == 2: x *= norm else: x *= norm return x.to(dtype) - elif quant_type == 'vector-zeropoint': + elif quant_type == "vector-zeropoint": x = xq.float() - if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0) - if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0) + if len(S1.shape) == 3 and len(x.shape) == 2: + S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: + S2 = S2.squeeze(0) if len(S1.shape) == 2: - x *= 1.0/S1 + x *= 1.0 / S1 else: - x *= 1.0/S1 - x *= 1.0/S2.t() + x *= 1.0 / S1 + x *= 1.0 / S2.t() return x.to(dtype) - elif quant_type == 'row': + elif quant_type == "row": x = xq.float() - if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0) - if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0) + if len(S1.shape) == 3 and len(x.shape) == 2: + S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: + S2 = S2.squeeze(0) if len(S1.shape) == 2: - x *= S1*S2/(C*C) + x *= S1 * S2 / (C * C) else: - x *= S1*S2/(C*C) + x *= S1 * S2 / (C * C) return x.to(dtype) - elif quant_type in ['truncated-vector', 'vector']: + elif quant_type in ["truncated-vector", "vector"]: x = xq.float() - if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0) - if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0) + if len(S1.shape) == 3 and len(x.shape) == 2: + S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: + S2 = S2.squeeze(0) if len(S1.shape) == 2: - x *= S1/C + x *= S1 / C else: - x *= S1/C - x *= S2/C + x *= S1 / C + x *= S2 / C return x.to(dtype) - else: return None + else: + return None def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): - offset = B.float().t().sum(0)*(SA[0]+SA[1]) + offset = B.float().t().sum(0) * (SA[0] + SA[1]) x = xq.float() - if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0) + if len(xq.shape) == 2 and len(SB.shape) == 3: + SB = SB.squeeze(0) if len(SB.shape) == 2: - x *= SB.t()/127 + x *= SB.t() / 127 else: - x *= SB/127 - x *= SA[1]/127 - x +=offset + x *= SB / 127 + x *= SA[1] / 127 + x += offset return x.to(dtype) + def extract_outliers(A, SA, idx): shapeA = SA[0] formatA = SA[1] - assert formatA in ['col_turing', 'col_ampere'] - assert A.device.type == 'cuda' + assert formatA in ["col_turing", "col_ampere"] + assert A.device.type == "cuda" - out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device) + out = torch.zeros( + (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device + ) idx_size = ct.c_int32(idx.numel()) rows = ct.c_int32(shapeA[0]) @@ -1454,12 +2007,8 @@ def extract_outliers(A, SA, idx): prev_device = pre_call(A.device) if formatA == 'col_turing': lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) - elif formatA == 'col_ampere': + elif formatA == "col_ampere": lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) post_call(prev_device) return out - - - - diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py index 03b4655..98d4aa0 100644 --- a/bitsandbytes/nn/__init__.py +++ b/bitsandbytes/nn/__init__.py @@ -1,5 +1,5 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .modules import StableEmbedding, Linear8bit, Linear8bitLt, Int8Params +from .modules import Int8Params, Linear8bit, Linear8bitLt, StableEmbedding diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 5013d0b..454dba5 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -1,39 +1,70 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import torch -import bitsandbytes as bnb +from typing import ( + Any, + Callable, + Dict, + Iterator, + Mapping, + Optional, + Set, + Tuple, + TypeVar, + Union, + overload, +) -from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict - -from torch import Tensor, device, dtype -from torch import nn -from torch.nn.parameter import Parameter +import torch import torch.nn.functional as F +from torch import Tensor, device, dtype, nn +from torch.nn.parameter import Parameter +import bitsandbytes as bnb from bitsandbytes.optim import GlobalOptimManager -T = TypeVar('T', bound='torch.nn.Module') +T = TypeVar("T", bound="torch.nn.Module") + class StableEmbedding(torch.nn.Embedding): - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False, - sparse: bool = False, _weight: Optional[Tensor] = None) -> None: - super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight) + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Optional[Tensor] = None, + ) -> None: + super(StableEmbedding, self).__init__( + num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + _weight, + ) self.norm = torch.nn.LayerNorm(embedding_dim) - GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32}) + GlobalOptimManager.get_instance().register_module_override( + self, "weight", {"optim_bits": 32} + ) def reset_parameters(self) -> None: torch.nn.init.xavier_uniform_(self.weight) self._fill_padding_idx_with_zero() - ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding + """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding to make the Layer compatible with Pytorch < 1.9. This means that if this changes in future PyTorch releases this need to change too which is cumbersome. However, with this we can ensure compatibility with previous PyTorch releases. - ''' + """ + def _fill_padding_idx_with_zero(self) -> None: if self.padding_idx is not None: with torch.no_grad(): @@ -41,29 +72,55 @@ class StableEmbedding(torch.nn.Embedding): def forward(self, input: Tensor) -> Tensor: emb = F.embedding( - input, self.weight, self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.sparse) + input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) return self.norm(emb) class Embedding(torch.nn.Embedding): - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False, - sparse: bool = False, _weight: Optional[Tensor] = None) -> None: - super(Embedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight) - GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32}) + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Optional[Tensor] = None, + ) -> None: + super(Embedding, self).__init__( + num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + _weight, + ) + GlobalOptimManager.get_instance().register_module_override( + self, "weight", {"optim_bits": 32} + ) def reset_parameters(self) -> None: torch.nn.init.xavier_uniform_(self.weight) self._fill_padding_idx_with_zero() - ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding + """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding to make the Layer compatible with Pytorch < 1.9. This means that if this changes in future PyTorch releases this need to change too which is cumbersome. However, with this we can ensure compatibility with previous PyTorch releases. - ''' + """ + def _fill_padding_idx_with_zero(self) -> None: if self.padding_idx is not None: with torch.no_grad(): @@ -71,13 +128,27 @@ class Embedding(torch.nn.Embedding): def forward(self, input: Tensor) -> Tensor: emb = F.embedding( - input, self.weight, self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.sparse) + input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) return emb + class Int8Params(torch.nn.Parameter): - def __new__(cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None): + def __new__( + cls, + data=None, + requires_grad=True, + has_fp16_weights=False, + CB=None, + SCB=None, + ): cls.has_fp16_weights = has_fp16_weights cls.CB = None cls.SCB = None @@ -96,14 +167,18 @@ class Int8Params(torch.nn.Parameter): del CBt del SCBt self.data = CB - setattr(self, 'CB', CB) - setattr(self, 'SCB', SCB) + setattr(self, "CB", CB) + setattr(self, "SCB", SCB) return self @overload - def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., - non_blocking: bool = ...) -> T: + def to( + self: T, + device: Optional[Union[int, device]] = ..., + dtype: Optional[Union[dtype, str]] = ..., + non_blocking: bool = ..., + ) -> T: ... @overload @@ -115,30 +190,54 @@ class Int8Params(torch.nn.Parameter): ... def to(self, *args, **kwargs): - device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - - if device is not None and device.type == 'cuda' and self.data.device.type == 'cpu': return self.cuda(device) + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to( + *args, **kwargs + ) + + if ( + device is not None + and device.type == "cuda" + and self.data.device.type == "cpu" + ): + return self.cuda(device) else: - new_param = Int8Params(super().to(device=device, dtype=dtype, non_blocking=non_blocking), requires_grad=self.requires_grad, has_fp16_weights=self.has_fp16_weights) + new_param = Int8Params( + super().to( + device=device, dtype=dtype, non_blocking=non_blocking + ), + requires_grad=self.requires_grad, + has_fp16_weights=self.has_fp16_weights, + ) new_param.CB = self.CB new_param.SCB = self.SCB return new_param - class Linear8bitLt(nn.Linear): - def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, threshold=0.0, index=None): - super(Linear8bitLt, self).__init__(input_features, output_features, bias) + def __init__( + self, + input_features, + output_features, + bias=True, + has_fp16_weights=True, + threshold=0.0, + index=None, + ): + super(Linear8bitLt, self).__init__( + input_features, output_features, bias + ) self.state = bnb.MatmulLtState() - self.index=index + self.index = index self.state.threshold = threshold self.state.has_fp16_weights = has_fp16_weights if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True - self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights) + self.weight = Int8Params( + self.weight.data, has_fp16_weights=has_fp16_weights + ) def init_8bit_state(self): self.state.CB = self.weight.CB @@ -149,9 +248,10 @@ class Linear8bitLt(nn.Linear): def forward(self, x): self.state.is_training = self.training - if self.weight.CB is not None: self.init_8bit_state() - #assert not self.state.has_fp16_weights - #if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None + if self.weight.CB is not None: + self.init_8bit_state() + # assert not self.state.has_fp16_weights + # if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None out = bnb.matmul(x, self.weight, state=self.state) @@ -166,8 +266,18 @@ class Linear8bitLt(nn.Linear): return out + class Linear8bit(nn.Linear): - def __init__(self, input_features, output_features, bias=True, quant_type='vector', index=None, args=None, sparse_decomp=False): + def __init__( + self, + input_features, + output_features, + bias=True, + quant_type="vector", + index=None, + args=None, + sparse_decomp=False, + ): super(Linear8bit, self).__init__(input_features, output_features, bias) self.quant_type = quant_type self.index = index @@ -178,15 +288,24 @@ class Linear8bit(nn.Linear): self.iter += 1 if self.iter % self.args.clip_freq == 0: with torch.no_grad(): - maxval, maxidx = torch.topk(torch.abs(self.weight.flatten()), k=self.args.clip_idx) + maxval, maxidx = torch.topk( + torch.abs(self.weight.flatten()), k=self.args.clip_idx + ) if not dist.is_initialized() or dist.get_rank() == 0: - print('clip', maxval[-1].item()) + print("clip", maxval[-1].item()) self.weight.clip_(-maxval[-1], maxval[-1]) - if self.args is not None: - out = bnb.nn.functional.sparse_decomposed_linear8bit(x, self.weight, self.bias, qval=self.args.sparse_decomp_val, quant_type=self.args.quant_type) + out = bnb.nn.functional.sparse_decomposed_linear8bit( + x, + self.weight, + self.bias, + qval=self.args.sparse_decomp_val, + quant_type=self.args.quant_type, + ) else: - out = bnb.nn.functional.linear8bit(x, self.weight, self.bias, quant_type=self.args.quant_type) + out = bnb.nn.functional.linear8bit( + x, self.weight, self.bias, quant_type=self.args.quant_type + ) return out diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py index 42b5bc0..a76d717 100644 --- a/bitsandbytes/optim/__init__.py +++ b/bitsandbytes/optim/__init__.py @@ -1,6 +1,6 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from bitsandbytes.cextension import COMPILED_WITH_CUDA diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py index 4f51250..7e2f566 100644 --- a/bitsandbytes/optim/adagrad.py +++ b/bitsandbytes/optim/adagrad.py @@ -1,54 +1,132 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from bitsandbytes.optim.optimizer import Optimizer1State + class Adagrad(Optimizer1State): - def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10, - optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-2, + lr_decay=0, + weight_decay=0, + initial_accumulator_value=0, + eps=1e-10, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if initial_accumulator_value != 0.0: - raise ValueError('Initial accumulator value != 0.0 not supported!') + raise ValueError("Initial accumulator value != 0.0 not supported!") if lr_decay != 0.0: - raise ValueError('Lr Decay != 0.0 not supported!') - super(Adagrad, self).__init__('adagrad', params, lr, (0.0, 0.0), eps, - weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise) + raise ValueError("Lr Decay != 0.0 not supported!") + super(Adagrad, self).__init__( + "adagrad", + params, + lr, + (0.0, 0.0), + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class Adagrad8bit(Optimizer1State): - def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10, - optim_bits=8, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-2, + lr_decay=0, + weight_decay=0, + initial_accumulator_value=0, + eps=1e-10, + optim_bits=8, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if initial_accumulator_value != 0.0: - raise ValueError('Initial accumulator value != 0.0 not supported!') + raise ValueError("Initial accumulator value != 0.0 not supported!") if lr_decay != 0.0: - raise ValueError('Lr Decay != 0.0 not supported!') + raise ValueError("Lr Decay != 0.0 not supported!") assert block_wise - super(Adagrad8bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps, - weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise) + super(Adagrad8bit, self).__init__( + "adagrad", + params, + lr, + (0.0, 0.0), + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class Adagrad32bit(Optimizer1State): - def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10, - optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-2, + lr_decay=0, + weight_decay=0, + initial_accumulator_value=0, + eps=1e-10, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if initial_accumulator_value != 0.0: - raise ValueError('Initial accumulator value != 0.0 not supported!') + raise ValueError("Initial accumulator value != 0.0 not supported!") if lr_decay != 0.0: - raise ValueError('Lr Decay != 0.0 not supported!') - super(Adagrad32bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps, - weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise) + raise ValueError("Lr Decay != 0.0 not supported!") + super(Adagrad32bit, self).__init__( + "adagrad", + params, + lr, + (0.0, 0.0), + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py index ed1b9f0..3634971 100644 --- a/bitsandbytes/optim/adam.py +++ b/bitsandbytes/optim/adam.py @@ -1,6 +1,6 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import math @@ -8,29 +8,97 @@ import os import torch import torch.distributed as dist -from bitsandbytes.optim.optimizer import Optimizer2State + import bitsandbytes.functional as F +from bitsandbytes.optim.optimizer import Optimizer2State + class Adam(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super(Adam, self).__init__('adam', params, lr, betas, eps, - weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super(Adam, self).__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class Adam8bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super(Adam8bit, self).__init__('adam', params, lr, betas, eps, - weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super(Adam8bit, self).__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class Adam32bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super(Adam32bit, self).__init__('adam', params, lr, betas, eps, - weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super(Adam32bit, self).__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) class AnalysisAdam(torch.optim.Optimizer): @@ -68,11 +136,15 @@ class AnalysisAdam(torch.optim.Optimizer): eps=1e-8, weight_decay=0, amsgrad=False, - bnb_analysis='dynamic-blockwise', - savedir=None + bnb_analysis="dynamic-blockwise", + savedir=None, ): defaults = dict( - lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, ) super(AnalysisAdam, self).__init__(params, defaults) self.analysis = bnb_analysis @@ -124,9 +196,15 @@ class AnalysisAdam(torch.optim.Optimizer): state["exp_avg"] = torch.zeros_like(p_data_fp32) # Exponential moving average of squared gradient values state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) - state['abserrors'] = torch.zeros((256, 256), device=p_data_fp32.device) - state['relerrors'] = torch.zeros((256, 256), device=p_data_fp32.device) - state['counts'] = torch.zeros((256, 256), device=p_data_fp32.device) + state["abserrors"] = torch.zeros( + (256, 256), device=p_data_fp32.device + ) + state["relerrors"] = torch.zeros( + (256, 256), device=p_data_fp32.device + ) + state["counts"] = torch.zeros( + (256, 256), device=p_data_fp32.device + ) if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. values state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32) @@ -142,10 +220,12 @@ class AnalysisAdam(torch.optim.Optimizer): beta1, beta2 = group["betas"] bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] - step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 - e = state['abserrors'] - rele = state['relerrors'] - counts = state['counts'] + step_size = ( + group["lr"] * math.sqrt(bias_correction2) / bias_correction1 + ) + e = state["abserrors"] + rele = state["relerrors"] + counts = state["counts"] if group["weight_decay"] != 0: p_data_fp32.add_( @@ -156,77 +236,91 @@ class AnalysisAdam(torch.optim.Optimizer): if amsgrad: max_exp_avg_sq = state["max_exp_avg_sq"] - # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) denom = exp_avg_sq.sqrt().add_(group["eps"]) - update_fp32 = exp_avg/denom + update_fp32 = exp_avg / denom - if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000*1000: + if ( + p_data_fp32.numel() <= 8192 + or p_data_fp32.numel() > 50000 * 1000 + ): # embedding layer or too small - p_data_fp32 += -step_size*update_fp32 + p_data_fp32 += -step_size * update_fp32 else: - if self.analysis == 'dynamic-blockwise': + if self.analysis == "dynamic-blockwise": code1 = F.create_dynamic_map(signed=True).to(p.device) code2 = F.create_dynamic_map(signed=False).to(p.device) C1, S1 = F.quantize_blockwise(exp_avg, code=code1) state1 = F.dequantize_blockwise(C1, S1) C2, S2 = F.quantize_blockwise(exp_avg_sq, code=code2) state2 = F.dequantize_blockwise(C2, S2) - elif self.analysis == 'dynamic': + elif self.analysis == "dynamic": code1 = F.create_dynamic_map(signed=True).to(p.device) code2 = F.create_dynamic_map(signed=False).to(p.device) C1, S1 = F.quantize(exp_avg, code=code1) state1 = F.dequantize(C1, S1) C2, S2 = F.quantize(exp_avg_sq, code=code2) state2 = F.dequantize(C2, S2) - elif self.analysis == 'linear': + elif self.analysis == "linear": code1 = F.create_linear_map(signed=True).to(p.device) code2 = F.create_linear_map(signed=False).to(p.device) C1, S1 = F.quantize(exp_avg, code=code1) state1 = F.dequantize(C1, S1) C2, S2 = F.quantize(exp_avg_sq, code=code2) state2 = F.dequantize(C2, S2) - elif self.analysis == 'quantile': + elif self.analysis == "quantile": code1 = F.estimate_quantiles(exp_avg) code2 = F.estimate_quantiles(exp_avg_sq) C1 = F.quantize_no_absmax(exp_avg, code=code1) state1 = F.dequantize_no_absmax(C1, code1) C2 = F.quantize_no_absmax(exp_avg_sq, code=code2) state2 = F.dequantize_no_absmax(C2, code2) - elif self.analysis == 'my-quantization-routine': + elif self.analysis == "my-quantization-routine": pass # 1. get code # 2. quantize # 3. dequantize # Error will be calculated automatically! else: - raise ValueError(f'Invalid analysis value: {self.analysis}!') + raise ValueError( + f"Invalid analysis value: {self.analysis}!" + ) denom = state2.sqrt().add_(group["eps"]) - update_8bit = state1/denom + update_8bit = state1 / denom - abserr = torch.abs(update_8bit-update_fp32) - relerr = abserr/torch.abs(update_fp32+1e-6) + abserr = torch.abs(update_8bit - update_fp32) + relerr = abserr / torch.abs(update_fp32 + 1e-6) C1, C2 = C1.int(), C2.int() F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr) F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr) - F.histogram_scatter_add_2d(counts, C1.int(), C2.int(), torch.ones_like(abserr)) - - p_data_fp32 += -step_size*update_fp32 + F.histogram_scatter_add_2d( + counts, C1.int(), C2.int(), torch.ones_like(abserr) + ) + p_data_fp32 += -step_size * update_fp32 if not dist.is_initialized() or dist.get_rank() == 0: - if self.savedir != '' and state['step'] % 100 == 0: - if not os.path.exists(self.savedir): os.makedirs(self.savedir) - shapestr = '_'.join([str(dim) for dim in p_data_fp32.shape]) - pathe = os.path.join(self.savedir, f'{p_id}_{shapestr}_abserr.pkl') - pathrele = os.path.join(self.savedir, f'{p_id}_{shapestr}_relerr.pkl') - pathcounts = os.path.join(self.savedir, f'{p_id}_{shapestr}_counts.pkl') + if self.savedir != "" and state["step"] % 100 == 0: + if not os.path.exists(self.savedir): + os.makedirs(self.savedir) + shapestr = "_".join( + [str(dim) for dim in p_data_fp32.shape] + ) + pathe = os.path.join( + self.savedir, f"{p_id}_{shapestr}_abserr.pkl" + ) + pathrele = os.path.join( + self.savedir, f"{p_id}_{shapestr}_relerr.pkl" + ) + pathcounts = os.path.join( + self.savedir, f"{p_id}_{shapestr}_counts.pkl" + ) torch.save(e, pathe) torch.save(rele, pathrele) torch.save(counts, pathcounts) @@ -234,6 +328,4 @@ class AnalysisAdam(torch.optim.Optimizer): if p.data.dtype in {torch.float16, torch.bfloat16}: p.data.copy_(p_data_fp32) - - return loss diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py index c4f0355..d0b3bde 100644 --- a/bitsandbytes/optim/adamw.py +++ b/bitsandbytes/optim/adamw.py @@ -1,27 +1,93 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from bitsandbytes.optim.optimizer import Optimizer2State + class AdamW(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=1e-2, amsgrad=False, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super(AdamW, self).__init__('adam', params, lr, betas, eps, - weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super(AdamW, self).__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class AdamW8bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=1e-2, amsgrad=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super(AdamW8bit, self).__init__('adam', params, lr, betas, eps, - weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super(AdamW8bit, self).__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) -class AdamW32bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=1e-2, amsgrad=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super(AdamW32bit, self).__init__('adam', params, lr, betas, eps, - weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise) +class AdamW32bit(Optimizer2State): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super(AdamW32bit, self).__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) diff --git a/bitsandbytes/optim/lamb.py b/bitsandbytes/optim/lamb.py index 58cc13d..8f365f7 100644 --- a/bitsandbytes/optim/lamb.py +++ b/bitsandbytes/optim/lamb.py @@ -1,28 +1,105 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from bitsandbytes.optim.optimizer import Optimizer2State + class LAMB(Optimizer2State): - def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False, adam_w_mode=True, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0): - super(LAMB, self).__init__('lamb', params, lr, betas, eps, - weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0) + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + adam_w_mode=True, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=False, + max_unorm=1.0, + ): + super(LAMB, self).__init__( + "lamb", + params, + lr, + betas, + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + max_unorm=1.0, + ) -class LAMB8bit(Optimizer2State): - def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False, adam_w_mode=True, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0): - super(LAMB8bit, self).__init__('lamb', params, lr, betas, eps, - weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0) -class LAMB32bit(Optimizer2State): - def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False, adam_w_mode=True, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0): - super(LAMB32bit, self).__init__('lamb', params, lr, betas, eps, - weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0) +class LAMB8bit(Optimizer2State): + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + adam_w_mode=True, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=False, + max_unorm=1.0, + ): + super(LAMB8bit, self).__init__( + "lamb", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + max_unorm=1.0, + ) +class LAMB32bit(Optimizer2State): + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + adam_w_mode=True, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=False, + max_unorm=1.0, + ): + super(LAMB32bit, self).__init__( + "lamb", + params, + lr, + betas, + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + max_unorm=1.0, + ) diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py index 912520d..8a89fb0 100644 --- a/bitsandbytes/optim/lars.py +++ b/bitsandbytes/optim/lars.py @@ -1,60 +1,154 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import torch - from torch.optim import Optimizer + from bitsandbytes.optim.optimizer import Optimizer1State + class LARS(Optimizer1State): - def __init__(self, params, lr, momentum=0, dampening=0, - weight_decay=0, nesterov=False, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + max_unorm=0.02, + ): if momentum == 0: - raise NotImplementedError(f'LARS without momentum is not supported!') - super(LARS, self).__init__('lars', params, lr, (momentum, dampening), 0.0, - weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False) + raise NotImplementedError( + f"LARS without momentum is not supported!" + ) + super(LARS, self).__init__( + "lars", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + max_unorm=max_unorm, + block_wise=False, + ) + class LARS8bit(Optimizer1State): - def __init__(self, params, lr, momentum=0, dampening=0, - weight_decay=0, nesterov=False, args=None, - min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + max_unorm=0.02, + ): if momentum == 0: - raise NotImplementedError(f'LARS without momentum is not supported!') - super(LARS8bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0, - weight_decay, 8, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False) + raise NotImplementedError( + f"LARS without momentum is not supported!" + ) + super(LARS8bit, self).__init__( + "lars", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + max_unorm=max_unorm, + block_wise=False, + ) + class LARS32bit(Optimizer1State): - def __init__(self, params, lr, momentum=0, dampening=0, - weight_decay=0, nesterov=False, args=None, - min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + max_unorm=0.02, + ): if momentum == 0: - raise NotImplementedError(f'LARS without momentum is not supported!') - super(LARS32bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0, - weight_decay, 32, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False) + raise NotImplementedError( + f"LARS without momentum is not supported!" + ) + super(LARS32bit, self).__init__( + "lars", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + max_unorm=max_unorm, + block_wise=False, + ) class PytorchLARS(Optimizer): - def __init__(self, params, lr=0.01, momentum=0, dampening=0, - weight_decay=0, nesterov=False, max_unorm=0.02): + def __init__( + self, + params, + lr=0.01, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + max_unorm=0.02, + ): if lr < 0.0: raise ValueError("Invalid learning rate: {}".format(lr)) if momentum < 0.0: raise ValueError("Invalid momentum value: {}".format(momentum)) if weight_decay < 0.0: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - - defaults = dict(lr=lr, momentum=momentum, dampening=dampening, - weight_decay=weight_decay, nesterov=nesterov, max_unorm=max_unorm) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) + + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + max_unorm=max_unorm, + ) if nesterov and (momentum <= 0 or dampening != 0): - raise ValueError("Nesterov momentum requires a momentum and zero dampening") + raise ValueError( + "Nesterov momentum requires a momentum and zero dampening" + ) super(PytorchLARS, self).__init__(params, defaults) def __setstate__(self, state): super(PytorchLARS, self).__setstate__(state) for group in self.param_groups: - group.setdefault('nesterov', False) + group.setdefault("nesterov", False) @torch.no_grad() def step(self, closure=None): @@ -73,15 +167,16 @@ class PytorchLARS(Optimizer): params_with_grad = [] d_p_list = [] momentum_buffer_list = [] - weight_decay = group['weight_decay'] - momentum = group['momentum'] - dampening = group['dampening'] - nesterov = group['nesterov'] - max_unorm = group['max_unorm'] - lr = group['lr'] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + dampening = group["dampening"] + nesterov = group["nesterov"] + max_unorm = group["max_unorm"] + lr = group["lr"] - for p in group['params']: - if p.grad is None: continue + for p in group["params"]: + if p.grad is None: + continue state = self.state[p] d_p = p.grad @@ -89,16 +184,16 @@ class PytorchLARS(Optimizer): d_p = d_p.add(param, alpha=weight_decay) if momentum != 0: - buf = state.get('momentum_buffer', None) + buf = state.get("momentum_buffer", None) if buf is None: buf = torch.clone(d_p).detach() - state['momentum_buffer']= buf + state["momentum_buffer"] = buf else: buf.mul_(momentum).add_(d_p, alpha=1 - dampening) if nesterov: - update = d_p + buf*momentum + update = d_p + buf * momentum else: update = buf @@ -107,9 +202,9 @@ class PytorchLARS(Optimizer): assert p.dtype == torch.float32 pnorm = torch.norm(p.detach()) unorm = torch.norm(update) - if unorm > max_unorm*pnorm: - update_scale = max_unorm*pnorm/unorm + if unorm > max_unorm * pnorm: + update_scale = max_unorm * pnorm / unorm - p.add_(update, alpha=-lr*update_scale) + p.add_(update, alpha=-lr * update_scale) return loss diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index 5a5bb1e..4fb30cd 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -1,13 +1,16 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from collections import abc as container_abcs +from collections import defaultdict +from copy import deepcopy +from itertools import chain + import torch + import bitsandbytes.functional as F -from copy import deepcopy -from itertools import chain -from collections import defaultdict, abc as container_abcs class MockArgs(object): def __init__(self, initial_data): @@ -19,7 +22,7 @@ class GlobalOptimManager(object): _instance = None def __init__(self): - raise RuntimeError('Call get_instance() instead') + raise RuntimeError("Call get_instance() instead") def initialize(self): self.pid2config = {} @@ -38,15 +41,19 @@ class GlobalOptimManager(object): def register_parameters(self, params): param_groups = list(params) if not isinstance(param_groups[0], dict): - param_groups = [{'params': param_groups}] + param_groups = [{"params": param_groups}] for group_index, group in enumerate(param_groups): - for p_index, p in enumerate(group['params']): + for p_index, p in enumerate(group["params"]): if id(p) in self.pid2config: - self.index2config[(group_index, p_index)] = self.pid2config[id(p)] + self.index2config[(group_index, p_index)] = self.pid2config[ + id(p) + ] - def override_config(self, parameters, key=None, value=None, key_value_dict=None): - ''' + def override_config( + self, parameters, key=None, value=None, key_value_dict=None + ): + """ Overrides initial optimizer config for specific parameters. The key-values of the optimizer config for the input parameters are overidden @@ -63,7 +70,7 @@ class GlobalOptimManager(object): The value for the hyperparamters. key_value_dict : dict A dictionary with multiple key-values to override. - ''' + """ self.uses_config_override = True if isinstance(parameters, torch.nn.Parameter): parameters = [parameters] @@ -75,16 +82,16 @@ class GlobalOptimManager(object): if key_value_dict is not None: for p in parameters: - if id(p) in self.pid2config:self.pid2config[id(p)].update(key_value_dict) - else: self.pid2config[id(p)] = key_value_dict + if id(p) in self.pid2config: + self.pid2config[id(p)].update(key_value_dict) + else: + self.pid2config[id(p)] = key_value_dict def register_module_override(self, module, param_name, config): self.module_weight_config_triple.append((module, param_name, config)) - class Optimizer8bit(torch.optim.Optimizer): - def __init__(self, params, defaults, optim_bits=32): super(Optimizer8bit, self).__init__(params, defaults) self.initialized = False @@ -92,23 +99,32 @@ class Optimizer8bit(torch.optim.Optimizer): self.mng = GlobalOptimManager.get_instance() self.non_castable_tensor_keys = set( - ['qmap1', 'qmap2', - 'max1', 'max2', - 'new_max1', 'new_max2', - 'state1', 'state2', - 'gnorm_vec', 'absmax1', 'absmax2', - 'unorm_vec']) - - if optim_bits == 8: self.fill_qmap() + [ + "qmap1", + "qmap2", + "max1", + "max2", + "new_max1", + "new_max2", + "state1", + "state2", + "gnorm_vec", + "absmax1", + "absmax2", + "unorm_vec", + ] + ) + + if optim_bits == 8: + self.fill_qmap() def fill_qmap(self): - self.name2qmap['dynamic'] = F.create_dynamic_map(signed=True) - self.name2qmap['udynamic'] = F.create_dynamic_map(signed=False) + self.name2qmap["dynamic"] = F.create_dynamic_map(signed=True) + self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False) def __setstate__(self, state): super(Optimizer8bit, self).__setstate__(state) - def load_state_dict(self, state_dict): r"""Loads the optimizer state. @@ -120,21 +136,29 @@ class Optimizer8bit(torch.optim.Optimizer): state_dict = deepcopy(state_dict) # Validate the state_dict groups = self.param_groups - saved_groups = state_dict['param_groups'] + saved_groups = state_dict["param_groups"] if len(groups) != len(saved_groups): - raise ValueError("loaded state dict has a different number of " - "parameter groups") - param_lens = (len(g['params']) for g in groups) - saved_lens = (len(g['params']) for g in saved_groups) + raise ValueError( + "loaded state dict has a different number of " + "parameter groups" + ) + param_lens = (len(g["params"]) for g in groups) + saved_lens = (len(g["params"]) for g in saved_groups) if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): - raise ValueError("loaded state dict contains a parameter group " - "that doesn't match the size of optimizer's group") + raise ValueError( + "loaded state dict contains a parameter group " + "that doesn't match the size of optimizer's group" + ) # Update the state - id_map = {old_id: p for old_id, p in - zip(chain.from_iterable((g['params'] for g in saved_groups)), - chain.from_iterable((g['params'] for g in groups)))} + id_map = { + old_id: p + for old_id, p in zip( + chain.from_iterable((g["params"] for g in saved_groups)), + chain.from_iterable((g["params"] for g in groups)), + ) + } def cast(param, value): r"""Make a deep copy of value, casting all tensors to device of param.""" @@ -161,7 +185,7 @@ class Optimizer8bit(torch.optim.Optimizer): # State that is not assigned to params is copied as is (needed for # backward compatibility). state = defaultdict(dict) - for k, v in state_dict['state'].items(): + for k, v in state_dict["state"].items(): if k in id_map: param = id_map[k] state[param] = cast(param, v) @@ -170,15 +194,17 @@ class Optimizer8bit(torch.optim.Optimizer): # Update parameter groups, setting their 'params' value def update_group(group, new_group): - new_group['params'] = group['params'] + new_group["params"] = group["params"] return new_group + param_groups = [ - update_group(g, ng) for g, ng in zip(groups, saved_groups)] - self.__setstate__({'state': state, 'param_groups': param_groups}) + update_group(g, ng) for g, ng in zip(groups, saved_groups) + ] + self.__setstate__({"state": state, "param_groups": param_groups}) def to_gpu(self): for gindex, group in enumerate(self.param_groups): - for pindex, p in enumerate(group['params']): + for pindex, p in enumerate(group["params"]): if p in self.state: values = self.state[p] for k, v in values.items(): @@ -189,17 +215,23 @@ class Optimizer8bit(torch.optim.Optimizer): for module, attr, config in self.mng.module_weight_config_triple: pmodule = getattr(module, attr) assert pmodule is not None - assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter) + assert isinstance(pmodule, torch.Tensor) or isinstance( + pmodule, torch.Parameter + ) found = False for gindex, group in enumerate(self.param_groups): - if found: break - for pindex, p in enumerate(group['params']): - if found: break + if found: + break + for pindex, p in enumerate(group["params"]): + if found: + break if id(p) == id(pmodule): # found the matching parameter # init override self.mng.pid2config[id(p)] = config - self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)] + self.mng.index2config[ + (gindex, pindex) + ] = self.mng.pid2config[id(p)] found = True @torch.no_grad() @@ -219,11 +251,11 @@ class Optimizer8bit(torch.optim.Optimizer): if not self.initialized: self.check_overrides() - self.to_gpu() # needed for fairseq pure fp16 training + self.to_gpu() # needed for fairseq pure fp16 training self.initialized = True for gindex, group in enumerate(self.param_groups): - for pindex, p in enumerate(group['params']): + for pindex, p in enumerate(group["params"]): if p.grad is None: continue state = self.state[p] @@ -236,58 +268,76 @@ class Optimizer8bit(torch.optim.Optimizer): def get_config(self, gindex, pindex, group): config = {} - config['betas'] = group['betas'] - config['eps'] = group['eps'] - config['weight_decay'] = group['weight_decay'] - config['lr'] = group['lr'] - config['optim_bits'] = self.args.optim_bits - config['min_8bit_size'] = self.args.min_8bit_size - config['percentile_clipping'] = self.args.percentile_clipping - config['block_wise'] = self.args.block_wise - config['max_unorm'] = self.args.max_unorm - config['skip_zeros'] = self.args.skip_zeros + config["betas"] = group["betas"] + config["eps"] = group["eps"] + config["weight_decay"] = group["weight_decay"] + config["lr"] = group["lr"] + config["optim_bits"] = self.args.optim_bits + config["min_8bit_size"] = self.args.min_8bit_size + config["percentile_clipping"] = self.args.percentile_clipping + config["block_wise"] = self.args.block_wise + config["max_unorm"] = self.args.max_unorm + config["skip_zeros"] = self.args.skip_zeros if (gindex, pindex) in self.mng.index2config: config.update(self.mng.index2config[(gindex, pindex)]) return config def init_state(self, group, p, gindex, pindex): - raise NotImplementedError(f'init_state method needs to be overidden') + raise NotImplementedError(f"init_state method needs to be overidden") def update_step(self, group, p, gindex, pindex): - raise NotImplementedError(f'The update_step method needs to be overidden') + raise NotImplementedError( + f"The update_step method needs to be overidden" + ) + class Optimizer2State(Optimizer8bit): - def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0.0, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0, - skip_zeros=False): + def __init__( + self, + optimizer_name, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0.0, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + max_unorm=0.0, + skip_zeros=False, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if isinstance(betas, str): # format: '(beta1, beta2)' - betas = betas.replace('(', '').replace(')', '').strip().split(',') + betas = betas.replace("(", "").replace(")", "").strip().split(",") betas = [float(b) for b in betas] for i in range(len(betas)): if not 0.0 <= betas[i] < 1.0: - raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") + raise ValueError( + f"Invalid beta parameter at index {i}: {betas[i]}" + ) if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super(Optimizer2State, self).__init__(params, defaults, optim_bits) if args is None: args = {} - args['optim_bits'] = optim_bits - args['percentile_clipping'] = 100 - args['min_8bit_size'] = min_8bit_size - args['percentile_clipping'] = percentile_clipping - args['block_wise'] = block_wise - args['max_unorm'] = max_unorm - args['skip_zeros'] = skip_zeros + args["optim_bits"] = optim_bits + args["percentile_clipping"] = 100 + args["min_8bit_size"] = min_8bit_size + args["percentile_clipping"] = percentile_clipping + args["block_wise"] = block_wise + args["max_unorm"] = max_unorm + args["skip_zeros"] = skip_zeros self.args = MockArgs(args) else: @@ -299,50 +349,93 @@ class Optimizer2State(Optimizer8bit): def init_state(self, group, p, gindex, pindex): config = self.get_config(gindex, pindex, group) - if config['optim_bits'] == 32: + if config["optim_bits"] == 32: dtype = torch.float32 - elif config['optim_bits'] == 8: + elif config["optim_bits"] == 8: dtype = torch.uint8 - else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}') + else: + raise NotImplementedError( + f'Amount of optimizer bits not supported: {config["optim_bits"]}' + ) - if p.numel() < config['min_8bit_size']: dtype = torch.float32 + if p.numel() < config["min_8bit_size"]: + dtype = torch.float32 state = self.state[p] - state['step'] = 0 - - if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): - state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device) - state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device) + state["step"] = 0 + + if dtype == torch.float32 or ( + dtype == torch.uint8 and p.numel() < 4096 + ): + state["state1"] = torch.zeros_like( + p, + memory_format=torch.preserve_format, + dtype=torch.float32, + device=p.device, + ) + state["state2"] = torch.zeros_like( + p, + memory_format=torch.preserve_format, + dtype=torch.float32, + device=p.device, + ) elif dtype == torch.uint8: - if state['step'] == 0: - if 'dynamic' not in self.name2qmap: self.fill_qmap() - self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device) - self.name2qmap['udynamic'] = self.name2qmap['udynamic'].to(p.device) - - state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device) - state['qmap1'] = self.name2qmap['dynamic'] - - state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device) - state['qmap2'] = self.name2qmap['udynamic'] - - if config['block_wise']: + if state["step"] == 0: + if "dynamic" not in self.name2qmap: + self.fill_qmap() + self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to( + p.device + ) + self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to( + p.device + ) + + state["state1"] = torch.zeros_like( + p, + memory_format=torch.preserve_format, + dtype=torch.uint8, + device=p.device, + ) + state["qmap1"] = self.name2qmap["dynamic"] + + state["state2"] = torch.zeros_like( + p, + memory_format=torch.preserve_format, + dtype=torch.uint8, + device=p.device, + ) + state["qmap2"] = self.name2qmap["udynamic"] + + if config["block_wise"]: n = p.numel() - blocks = n//2048 + blocks = n // 2048 blocks += 1 if n % 2048 > 0 else 0 - state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) - state['absmax2'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) + state["absmax1"] = torch.zeros( + (blocks,), dtype=torch.float32, device=p.device + ) + state["absmax2"] = torch.zeros( + (blocks,), dtype=torch.float32, device=p.device + ) else: - state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device) - state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device) - state['max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device) - state['new_max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device) - - if config['percentile_clipping'] < 100: - state['gnorm_vec'] = torch.zeros((100,), device=p.device) - - if config['max_unorm'] > 0.0: - state['unorm_vec'] = torch.zeros((1,), device=p.device) + state["max1"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) + state["new_max1"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) + state["max2"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) + state["new_max2"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) + + if config["percentile_clipping"] < 100: + state["gnorm_vec"] = torch.zeros((100,), device=p.device) + + if config["max_unorm"] > 0.0: + state["unorm_vec"] = torch.zeros((1,), device=p.device) @torch.no_grad() def update_step(self, group, p, gindex, pindex): @@ -351,63 +444,128 @@ class Optimizer2State(Optimizer8bit): config = self.get_config(gindex, pindex, group) - state['step'] += 1 - step = state['step'] + state["step"] += 1 + step = state["step"] - if config['percentile_clipping'] < 100: - current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping']) + if config["percentile_clipping"] < 100: + current_gnorm, clip_value, gnorm_scale = F.percentile_clipping( + grad, state["gnorm_vec"], step, config["percentile_clipping"] + ) else: gnorm_scale = 1.0 - if state['state1'].dtype == torch.float: - F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'], - state['state2'], config['betas'][1], config['weight_decay'], gnorm_scale, - state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'], skip_zeros=config['skip_zeros']) - - elif state['state1'].dtype == torch.uint8 and not config['block_wise']: - F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1], - config['eps'], step, config['lr'], - state['qmap1'], state['qmap2'], state['max1'], state['max2'], state['new_max1'], state['new_max2'], - config['weight_decay'], gnorm_scale=gnorm_scale, - unorm_vec=state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm']) + if state["state1"].dtype == torch.float: + F.optimizer_update_32bit( + self.optimizer_name, + grad, + p, + state["state1"], + config["betas"][0], + config["eps"], + step, + config["lr"], + state["state2"], + config["betas"][1], + config["weight_decay"], + gnorm_scale, + state["unorm_vec"] if config["max_unorm"] > 0.0 else None, + max_unorm=config["max_unorm"], + skip_zeros=config["skip_zeros"], + ) + + elif state["state1"].dtype == torch.uint8 and not config["block_wise"]: + F.optimizer_update_8bit( + self.optimizer_name, + grad, + p, + state["state1"], + state["state2"], + config["betas"][0], + config["betas"][1], + config["eps"], + step, + config["lr"], + state["qmap1"], + state["qmap2"], + state["max1"], + state["max2"], + state["new_max1"], + state["new_max2"], + config["weight_decay"], + gnorm_scale=gnorm_scale, + unorm_vec=state["unorm_vec"] + if config["max_unorm"] > 0.0 + else None, + max_unorm=config["max_unorm"], + ) # swap maxes - state['max1'], state['new_max1'] = state['new_max1'], state['max1'] - state['max2'], state['new_max2'] = state['new_max2'], state['max2'] - elif state['state1'].dtype == torch.uint8 and config['block_wise']: - F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1], - config['eps'], step, config['lr'], - state['qmap1'], state['qmap2'], state['absmax1'], state['absmax2'], - config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros']) + state["max1"], state["new_max1"] = state["new_max1"], state["max1"] + state["max2"], state["new_max2"] = state["new_max2"], state["max2"] + elif state["state1"].dtype == torch.uint8 and config["block_wise"]: + F.optimizer_update_8bit_blockwise( + self.optimizer_name, + grad, + p, + state["state1"], + state["state2"], + config["betas"][0], + config["betas"][1], + config["eps"], + step, + config["lr"], + state["qmap1"], + state["qmap2"], + state["absmax1"], + state["absmax2"], + config["weight_decay"], + gnorm_scale=gnorm_scale, + skip_zeros=config["skip_zeros"], + ) class Optimizer1State(Optimizer8bit): - def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.0), eps=1e-8, - weight_decay=0.0, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0, - skip_zeros=False): + def __init__( + self, + optimizer_name, + params, + lr=1e-3, + betas=(0.9, 0.0), + eps=1e-8, + weight_decay=0.0, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + max_unorm=0.0, + skip_zeros=False, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) for i in range(len(betas)): if not 0.0 <= betas[i] < 1.0: - raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") + raise ValueError( + f"Invalid beta parameter at index {i}: {betas[i]}" + ) if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super(Optimizer1State, self).__init__(params, defaults, optim_bits) if args is None: args = {} - args['optim_bits'] = optim_bits - args['percentile_clipping'] = 100 - args['min_8bit_size'] = min_8bit_size - args['percentile_clipping'] = percentile_clipping - args['block_wise'] = block_wise - args['max_unorm'] = max_unorm - args['skip_zeros'] = skip_zeros + args["optim_bits"] = optim_bits + args["percentile_clipping"] = 100 + args["min_8bit_size"] = min_8bit_size + args["percentile_clipping"] = percentile_clipping + args["block_wise"] = block_wise + args["max_unorm"] = max_unorm + args["skip_zeros"] = skip_zeros self.args = MockArgs(args) else: @@ -419,43 +577,67 @@ class Optimizer1State(Optimizer8bit): def init_state(self, group, p, gindex, pindex): config = self.get_config(gindex, pindex, group) - if config['optim_bits'] == 32: + if config["optim_bits"] == 32: dtype = torch.float32 - elif config['optim_bits'] == 8: + elif config["optim_bits"] == 8: dtype = torch.uint8 - else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}') + else: + raise NotImplementedError( + f'Amount of optimizer bits not supported: {config["optim_bits"]}' + ) - if p.numel() < config['min_8bit_size']: dtype = torch.float32 + if p.numel() < config["min_8bit_size"]: + dtype = torch.float32 state = self.state[p] - state['step'] = 0 - - if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): - state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device) + state["step"] = 0 + + if dtype == torch.float32 or ( + dtype == torch.uint8 and p.numel() < 4096 + ): + state["state1"] = torch.zeros_like( + p, + memory_format=torch.preserve_format, + dtype=torch.float32, + device=p.device, + ) elif dtype == torch.uint8: - if state['step'] == 0: - if 'dynamic' not in self.name2qmap: self.fill_qmap() - self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device) - - state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device) - state['qmap1'] = self.name2qmap['dynamic'] - - if config['block_wise']: + if state["step"] == 0: + if "dynamic" not in self.name2qmap: + self.fill_qmap() + self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to( + p.device + ) + + state["state1"] = torch.zeros_like( + p, + memory_format=torch.preserve_format, + dtype=torch.uint8, + device=p.device, + ) + state["qmap1"] = self.name2qmap["dynamic"] + + if config["block_wise"]: n = p.numel() - blocks = n//2048 + blocks = n // 2048 blocks += 1 if n % 2048 > 0 else 0 - state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) + state["absmax1"] = torch.zeros( + (blocks,), dtype=torch.float32, device=p.device + ) else: - state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device) - state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device) - - if config['percentile_clipping'] < 100: - state['gnorm_vec'] = torch.zeros((100,), device=p.device) + state["max1"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) + state["new_max1"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) - if config['max_unorm'] > 0.0: - state['unorm_vec'] = torch.zeros((1,), device=p.device) + if config["percentile_clipping"] < 100: + state["gnorm_vec"] = torch.zeros((100,), device=p.device) + if config["max_unorm"] > 0.0: + state["unorm_vec"] = torch.zeros((1,), device=p.device) @torch.no_grad() def update_step(self, group, p, gindex, pindex): @@ -464,29 +646,77 @@ class Optimizer1State(Optimizer8bit): config = self.get_config(gindex, pindex, group) - state['step'] += 1 - step = state['step'] + state["step"] += 1 + step = state["step"] - if config['percentile_clipping'] < 100: - current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping']) + if config["percentile_clipping"] < 100: + current_gnorm, clip_value, gnorm_scale = F.percentile_clipping( + grad, state["gnorm_vec"], step, config["percentile_clipping"] + ) else: gnorm_scale = 1.0 - if state['state1'].dtype == torch.float: - F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'], - None, 0.0, config['weight_decay'], gnorm_scale, - state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'], - skip_zeros=config['skip_zeros']) - - elif state['state1'].dtype == torch.uint8 and not config['block_wise']: - F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1], - config['eps'], step, config['lr'], state['qmap1'], None, state['max1'], None, state['new_max1'], None, - config['weight_decay'], gnorm_scale, - state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm']) - - state['max1'], state['new_max1'] = state['new_max1'], state['max1'] - elif state['state1'].dtype == torch.uint8 and config['block_wise']: - F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1], - config['eps'], step, config['lr'], - state['qmap1'], None, state['absmax1'], None, - config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros']) + if state["state1"].dtype == torch.float: + F.optimizer_update_32bit( + self.optimizer_name, + grad, + p, + state["state1"], + config["betas"][0], + config["eps"], + step, + config["lr"], + None, + 0.0, + config["weight_decay"], + gnorm_scale, + state["unorm_vec"] if config["max_unorm"] > 0.0 else None, + max_unorm=config["max_unorm"], + skip_zeros=config["skip_zeros"], + ) + + elif state["state1"].dtype == torch.uint8 and not config["block_wise"]: + F.optimizer_update_8bit( + self.optimizer_name, + grad, + p, + state["state1"], + None, + config["betas"][0], + config["betas"][1], + config["eps"], + step, + config["lr"], + state["qmap1"], + None, + state["max1"], + None, + state["new_max1"], + None, + config["weight_decay"], + gnorm_scale, + state["unorm_vec"] if config["max_unorm"] > 0.0 else None, + max_unorm=config["max_unorm"], + ) + + state["max1"], state["new_max1"] = state["new_max1"], state["max1"] + elif state["state1"].dtype == torch.uint8 and config["block_wise"]: + F.optimizer_update_8bit_blockwise( + self.optimizer_name, + grad, + p, + state["state1"], + None, + config["betas"][0], + config["betas"][1], + config["eps"], + step, + config["lr"], + state["qmap1"], + None, + state["absmax1"], + None, + config["weight_decay"], + gnorm_scale=gnorm_scale, + skip_zeros=config["skip_zeros"], + ) diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py index 0f1ffaa..7ddb12c 100644 --- a/bitsandbytes/optim/rmsprop.py +++ b/bitsandbytes/optim/rmsprop.py @@ -1,36 +1,115 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from bitsandbytes.optim.optimizer import Optimizer1State + class RMSprop(Optimizer1State): - def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-2, + alpha=0.99, + eps=1e-8, + weight_decay=0, + momentum=0, + centered=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if alpha == 0: - raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!') + raise NotImplementedError( + f"RMSprop with alpha==0.0 is not supported!" + ) if centered: - raise NotImplementedError(f'Centered RMSprop is not supported!') - super(RMSprop, self).__init__('rmsprop', params, lr, (alpha, momentum), eps, - weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise) + raise NotImplementedError(f"Centered RMSprop is not supported!") + super(RMSprop, self).__init__( + "rmsprop", + params, + lr, + (alpha, momentum), + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class RMSprop8bit(Optimizer1State): - def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-2, + alpha=0.99, + eps=1e-8, + weight_decay=0, + momentum=0, + centered=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if alpha == 0: - raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!') + raise NotImplementedError( + f"RMSprop with alpha==0.0 is not supported!" + ) if centered: - raise NotImplementedError(f'Centered RMSprop is not supported!') - super(RMSprop8bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps, - weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise) + raise NotImplementedError(f"Centered RMSprop is not supported!") + super(RMSprop8bit, self).__init__( + "rmsprop", + params, + lr, + (alpha, momentum), + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class RMSprop32bit(Optimizer1State): - def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-2, + alpha=0.99, + eps=1e-8, + weight_decay=0, + momentum=0, + centered=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if alpha == 0: - raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!') + raise NotImplementedError( + f"RMSprop with alpha==0.0 is not supported!" + ) if centered: - raise NotImplementedError(f'Centered RMSprop is not supported!') - super(RMSprop32bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps, - weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise) + raise NotImplementedError(f"Centered RMSprop is not supported!") + super(RMSprop32bit, self).__init__( + "rmsprop", + params, + lr, + (alpha, momentum), + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) diff --git a/bitsandbytes/optim/sgd.py b/bitsandbytes/optim/sgd.py index 0529879..f7b8934 100644 --- a/bitsandbytes/optim/sgd.py +++ b/bitsandbytes/optim/sgd.py @@ -1,32 +1,99 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from bitsandbytes.optim.optimizer import Optimizer1State + class SGD(Optimizer1State): - def __init__(self, params, lr, momentum=0, dampening=0, - weight_decay=0, nesterov=False, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if momentum == 0: - raise NotImplementedError(f'SGD without momentum is not supported!') - super(SGD, self).__init__('momentum', params, lr, (momentum, dampening), 0.0, - weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise) + raise NotImplementedError(f"SGD without momentum is not supported!") + super(SGD, self).__init__( + "momentum", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class SGD8bit(Optimizer1State): - def __init__(self, params, lr, momentum=0, dampening=0, - weight_decay=0, nesterov=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if momentum == 0: - raise NotImplementedError(f'SGD without momentum is not supported!') - super(SGD8bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0, - weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise) + raise NotImplementedError(f"SGD without momentum is not supported!") + super(SGD8bit, self).__init__( + "momentum", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class SGD32bit(Optimizer1State): - def __init__(self, params, lr, momentum=0, dampening=0, - weight_decay=0, nesterov=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if momentum == 0: - raise NotImplementedError(f'SGD without momentum is not supported!') - super(SGD32bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0, - weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise) + raise NotImplementedError(f"SGD without momentum is not supported!") + super(SGD32bit, self).__init__( + "momentum", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py new file mode 100644 index 0000000..4256a87 --- /dev/null +++ b/bitsandbytes/utils.py @@ -0,0 +1,32 @@ +import shlex +import subprocess +import sys +from typing import Tuple + + +def execute_and_return(command_string: str) -> Tuple[str, str]: + def _decode(subprocess_err_out_tuple): + return tuple( + to_decode.decode("UTF-8").strip() + for to_decode in subprocess_err_out_tuple + ) + + def execute_and_return_decoded_std_streams(command_string): + return _decode( + subprocess.Popen( + shlex.split(command_string), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ).communicate() + ) + + std_out, std_err = execute_and_return_decoded_std_streams(command_string) + return std_out, std_err + + +def print_stderr(s: str) -> None: + print(s, file=sys.stderr) + + +def warn_of_missing_prerequisite(s: str) -> None: + print_stderr("WARNING, missing pre-requisite: " + s) |