From ea7c14f8ef64924f2d0ff80df3cdabf2c7299848 Mon Sep 17 00:00:00 2001 From: Titus von Koeller Date: Mon, 1 Aug 2022 09:32:47 -0700 Subject: reran black with linelength 80 for greater readability --- bitsandbytes/__init__.py | 9 ++- bitsandbytes/autograd/_functions.py | 45 ++++++++++--- bitsandbytes/cuda_setup.py | 45 +++++++++++-- bitsandbytes/functional.py | 128 ++++++++++++++++++++++++++++-------- bitsandbytes/nn/modules.py | 34 ++++++++-- bitsandbytes/optim/adagrad.py | 12 +++- bitsandbytes/optim/adam.py | 27 ++++++-- bitsandbytes/optim/lars.py | 20 ++++-- bitsandbytes/optim/optimizer.py | 77 ++++++++++++++++------ bitsandbytes/optim/rmsprop.py | 12 +++- bitsandbytes/utils.py | 2 +- 11 files changed, 319 insertions(+), 92 deletions(-) (limited to 'bitsandbytes') diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 7ca017d..6e5b6ac 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -3,8 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .autograd._functions import (MatmulLtState, bmm_cublas, matmul, - matmul_cublas, mm_cublas) +from .autograd._functions import ( + MatmulLtState, + bmm_cublas, + matmul, + matmul_cublas, + mm_cublas, +) from .cextension import COMPILED_WITH_CUDA from .nn import modules diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index a08b560..b56b2ee 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -111,7 +111,9 @@ class MatMul8bit(torch.autograd.Function): qgrad_output, S1 = F.vectorwise_quant( grad_output, dim=dims, quant_type=quant_type ) - qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type) + qA, S2 = F.vectorwise_quant( + A, dim=dims, quant_type=quant_type + ) igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output) grad_B = F.vectorwise_mm_dequant( igrad_B, @@ -146,7 +148,11 @@ class MatMul8bit(torch.autograd.Function): qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type) igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim)) grad_A = F.vectorwise_mm_dequant( - igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type + igrad_A, + S1, + S3.permute(permute_dim), + grad_output.dtype, + quant_type, ) return grad_A, grad_B, None, None, None @@ -211,7 +217,9 @@ class MatMul8bitLt(torch.autograd.Function): # 1. Quantize A if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous() - CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold) + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant( + A, threshold=state.threshold + ) if state.threshold > 0.0 and coo_tensorA is not None: if state.has_fp16_weights: @@ -225,7 +233,9 @@ class MatMul8bitLt(torch.autograd.Function): if state.CxB is None: # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions # we also need to convert it to the turing/ampere format - state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + state.CxB, state.SB = F.transform( + state.CB, to_order=formatB + ) # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half() # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None: # # generate outlier index and subB @@ -259,7 +269,13 @@ class MatMul8bitLt(torch.autograd.Function): if (state.is_training and not has_grad) or state.CxB is None: state.reset_grads() - CB, state.CBt, state.SCB, state.SCBt, coo_tensorB = F.double_quant(B) + ( + CB, + state.CBt, + state.SCB, + state.SCBt, + coo_tensorB, + ) = F.double_quant(B) state.CxB, state.SB = F.transform(CB, to_order=formatB) else: has_grad = False @@ -277,7 +293,10 @@ class MatMul8bitLt(torch.autograd.Function): # state.idx = outlier_idx outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) state.subB = ( - (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half() + (outliers * state.SCB.view(-1, 1) / 127.0) + .t() + .contiguous() + .half() ) CA[:, state.idx.long()] = 0 CAt[:, state.idx.long()] = 0 @@ -325,10 +344,14 @@ class MatMul8bitLt(torch.autograd.Function): SCAt, idx = ctx.tensor_states formatB = ctx.formatB state = ctx.state - assert state.has_fp16_weights, "Backprop only supported for fp16 weights." + assert ( + state.has_fp16_weights + ), "Backprop only supported for fp16 weights." if len(grad_output.shape) == 3: - grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous() + grad_output = grad_output.view( + -1, grad_output.shape[-1] + ).contiguous() grad_A = grad_B = None @@ -359,7 +382,11 @@ matmul = MatMul8bitLt.apply def matmul( - A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0 + A: tensor, + B: tensor, + out: tensor = None, + state: MatmulLtState = None, + threshold=0.0, ): state = state or MatmulLtState() if threshold > 0.0: diff --git a/bitsandbytes/cuda_setup.py b/bitsandbytes/cuda_setup.py index 8cc2c03..6e37606 100644 --- a/bitsandbytes/cuda_setup.py +++ b/bitsandbytes/cuda_setup.py @@ -1,7 +1,7 @@ """ -build is dependent on -- compute capability - - dependent on GPU family +extract factors the build is dependent on: +[X] compute capability + [ ] TODO: Q - What if we have multiple GPUs of different makes? - CUDA version - Software: - CPU-only: only CPU quantization functions (no optimizer, no matrix multipl) @@ -19,6 +19,8 @@ evaluation: """ import ctypes +import shlex +import subprocess from os import environ as env from pathlib import Path from typing import Set, Union @@ -26,10 +28,31 @@ from typing import Set, Union from .utils import print_err, warn_of_missing_prerequisite +def execute_and_return(command_string: str) -> Tuple[str, str]: + def _decode(subprocess_err_out_tuple): + return tuple( + to_decode.decode("UTF-8").strip() + for to_decode in subprocess_err_out_tuple + ) + + def execute_and_return_decoded_std_streams(command_string): + return _decode( + subprocess.Popen( + shlex.split(command_string), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ).communicate() + ) + + std_out, std_err = execute_and_return_decoded_std_streams() + return std_out, std_err + + def check_cuda_result(cuda, result_val): if result_val != 0: + # TODO: undefined name 'error_str' cuda.cuGetErrorString(result_val, ctypes.byref(error_str)) - print(f"Count not initialize CUDA - failure!") + print("Count not initialize CUDA - failure!") raise Exception("CUDA exception!") return result_val @@ -53,7 +76,9 @@ def get_compute_capability(): result = ctypes.c_int() device = ctypes.c_int() + # TODO: local variable 'context' is assigned to but never used context = ctypes.c_void_p() + # TODO: local variable 'error_str' is assigned to but never used error_str = ctypes.c_char_p() result = check_cuda_result(cuda, cuda.cuInit(0)) @@ -61,7 +86,9 @@ def get_compute_capability(): result = check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus))) ccs = [] for i in range(nGpus.value): - result = check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i)) + result = check_cuda_result( + cuda, cuda.cuDeviceGet(ctypes.byref(device), i) + ) result = check_cuda_result( cuda, cuda.cuDeviceComputeCapability( @@ -114,11 +141,15 @@ def get_cuda_runtime_lib_path( } - non_existent_directories if len(cuda_runtime_libs) > 1: - err_msg = f"Found duplicate {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.." + err_msg = ( + f"Found duplicate {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.." + ) raise FileNotFoundError(err_msg) elif len(cuda_runtime_libs) < 1: - err_msg = f"Did not find {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.." + err_msg = ( + f"Did not find {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.." + ) raise FileNotFoundError(err_msg) single_cuda_runtime_lib_dir = next(iter(cuda_runtime_libs)) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 2e86958..236ef39 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -17,14 +17,29 @@ if COMPILED_WITH_CUDA: """C FUNCTIONS FOR OPTIMIZERS""" str2optimizer32bit = {} str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16) - str2optimizer32bit["momentum"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16) - str2optimizer32bit["rmsprop"] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16) - str2optimizer32bit["adagrad"] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16) - str2optimizer32bit["lars"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16) + str2optimizer32bit["momentum"] = ( + lib.cmomentum32bit_g32, + lib.cmomentum32bit_g16, + ) + str2optimizer32bit["rmsprop"] = ( + lib.crmsprop32bit_g32, + lib.crmsprop32bit_g16, + ) + str2optimizer32bit["adagrad"] = ( + lib.cadagrad32bit_g32, + lib.cadagrad32bit_g16, + ) + str2optimizer32bit["lars"] = ( + lib.cmomentum32bit_g32, + lib.cmomentum32bit_g16, + ) str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16) str2optimizer8bit = {} - str2optimizer8bit["adam"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16) + str2optimizer8bit["adam"] = ( + lib.cadam_static_8bit_g32, + lib.cadam_static_8bit_g16, + ) str2optimizer8bit["momentum"] = ( lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16, @@ -33,7 +48,10 @@ if COMPILED_WITH_CUDA: lib.crmsprop_static_8bit_g32, lib.crmsprop_static_8bit_g16, ) - str2optimizer8bit["lamb"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16) + str2optimizer8bit["lamb"] = ( + lib.cadam_static_8bit_g32, + lib.cadam_static_8bit_g16, + ) str2optimizer8bit["lars"] = ( lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16, @@ -137,7 +155,9 @@ def create_dynamic_map(signed=True, n=7): if not signed: additional_items = 2 * additional_items for i in range(n): - fraction_items = 2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1 + fraction_items = ( + 2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1 + ) boundaries = torch.linspace(0.1, 1, fraction_items) means = (boundaries[:-1] + boundaries[1:]) / 2.0 data += ((10 ** (-(n - 1) + i)) * means).tolist() @@ -272,7 +292,13 @@ def get_transform_buffer( def nvidia_transform( - A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None + A, + to_order, + from_order="row", + out=None, + transpose=False, + state=None, + ld=None, ): if state is None: state = (A.shape, from_order) @@ -352,7 +378,11 @@ def estimate_quantiles( def quantize_blockwise( - A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None + A: Tensor, + code: Tensor = None, + absmax: Tensor = None, + rand=None, + out: Tensor = None, ) -> Tensor: """ Quantize tensor A in blocks of size 4096 values. @@ -629,7 +659,9 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: """ if out is None: out = torch.zeros_like(A, dtype=torch.float32) - lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) + lib.cdequantize( + get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()) + ) return out @@ -1005,7 +1037,9 @@ def histogram_scatter_add_2d( ) -def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8): +def check_matmul( + A, B, out, transposed_A, transposed_B, expected_type=torch.int8 +): if not torch.cuda.is_initialized(): torch.cuda.init() if A.dtype != expected_type or B.dtype != expected_type: @@ -1097,7 +1131,11 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 def igemm( - A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False + A: Tensor, + B: Tensor, + out: Tensor = None, + transposed_A=False, + transposed_B=False, ): sout = check_matmul(A, B, out, transposed_A, transposed_B) if out is None: @@ -1193,7 +1231,11 @@ def igemm( def batched_igemm( - A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False + A: Tensor, + B: Tensor, + out: Tensor = None, + transposed_A=False, + transposed_B=False, ): if not len(A.shape) == 3 or not len(B.shape) == 3: raise ValueError( @@ -1392,9 +1434,13 @@ def mm_dequant( if out is None: out = torch.empty(out_shape, dtype=torch.float16, device=A.device) if new_row_stats is None: - new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device) + new_row_stats = torch.empty( + out_shape[0], dtype=torch.float32, device=A.device + ) if new_col_stats is None: - new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device) + new_col_stats = torch.empty( + out_shape[1], dtype=torch.float32, device=A.device + ) assert ( new_row_stats.shape[0] == row_stats.shape[0] ), f"{new_row_stats.shape} vs {row_stats.shape}" @@ -1440,13 +1486,13 @@ def get_colrow_absmax( col_tiles = (cols + 255) // 256 tiled_rows = ((rows + 15) // 16) * 16 if row_stats is None: - row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_( - -50000.0 - ) + row_stats = torch.empty( + (rows,), dtype=torch.float32, device=device + ).fill_(-50000.0) if col_stats is None: - col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_( - -50000.0 - ) + col_stats = torch.empty( + (cols,), dtype=torch.float32, device=device + ).fill_(-50000.0) if nnz_block_ptr is None and threshold > 0.0: nnz_block_ptr = torch.zeros( @@ -1462,7 +1508,13 @@ def get_colrow_absmax( prev_device = pre_call(A.device) lib.cget_col_row_stats( - ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols + ptrA, + ptrRowStats, + ptrColStats, + ptrNnzrows, + ct.c_float(threshold), + rows, + cols, ) post_call(prev_device) @@ -1526,7 +1578,9 @@ class CSCSparseTensor(object): def coo2csr(cooA): values, counts = torch.unique(cooA.rowidx, return_counts=True) values.add_(1) - rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device) + rowptr = torch.zeros( + (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device + ) rowptr.scatter_(index=values.long(), src=counts.int(), dim=0) rowptr.cumsum_(0) return CSRSparseTensor( @@ -1540,10 +1594,14 @@ def coo2csc(cooA): values = cooA.values[col2rowidx] colvalues, counts = torch.unique(val, return_counts=True) colvalues.add_(1) - colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device) + colptr = torch.zeros( + (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device + ) colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0) colptr.cumsum_(0) - return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values) + return CSCSparseTensor( + cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values + ) def coo_zeros(rows, cols, nnz, device, dtype=torch.half): @@ -1568,7 +1626,9 @@ def double_quant( rows = A.shape[0] if row_stats is None or col_stats is None: - row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold) + row_stats, col_stats, nnz_row_ptr = get_colrow_absmax( + A, threshold=threshold + ) if out_col is None: out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) @@ -1663,7 +1723,13 @@ def get_special_format_str(): def transform( - A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None + A, + to_order, + from_order="row", + out=None, + transpose=False, + state=None, + ld=None, ): if state is None: state = (A.shape, from_order) @@ -1716,7 +1782,9 @@ def transform( def spmm_coo(cooA, B, out=None): if out is None: - out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype) + out = torch.empty( + (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype + ) nnz = cooA.nnz assert cooA.rowidx.numel() == nnz assert cooA.colidx.numel() == nnz @@ -1982,7 +2050,9 @@ def extract_outliers(A, SA, idx): assert formatA in ["col_turing", "col_ampere"] assert A.device.type == "cuda" - out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device) + out = torch.zeros( + (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device + ) idx_size = ct.c_int32(idx.numel()) rows = ct.c_int32(shapeA[0]) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 9ce3ac8..454dba5 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -2,8 +2,19 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import (Any, Callable, Dict, Iterator, Mapping, Optional, Set, - Tuple, TypeVar, Union, overload) +from typing import ( + Any, + Callable, + Dict, + Iterator, + Mapping, + Optional, + Set, + Tuple, + TypeVar, + Union, + overload, +) import torch import torch.nn.functional as F @@ -131,7 +142,12 @@ class Embedding(torch.nn.Embedding): class Int8Params(torch.nn.Parameter): def __new__( - cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None + cls, + data=None, + requires_grad=True, + has_fp16_weights=False, + CB=None, + SCB=None, ): cls.has_fp16_weights = has_fp16_weights cls.CB = None @@ -186,7 +202,9 @@ class Int8Params(torch.nn.Parameter): return self.cuda(device) else: new_param = Int8Params( - super().to(device=device, dtype=dtype, non_blocking=non_blocking), + super().to( + device=device, dtype=dtype, non_blocking=non_blocking + ), requires_grad=self.requires_grad, has_fp16_weights=self.has_fp16_weights, ) @@ -206,7 +224,9 @@ class Linear8bitLt(nn.Linear): threshold=0.0, index=None, ): - super(Linear8bitLt, self).__init__(input_features, output_features, bias) + super(Linear8bitLt, self).__init__( + input_features, output_features, bias + ) self.state = bnb.MatmulLtState() self.index = index @@ -215,7 +235,9 @@ class Linear8bitLt(nn.Linear): if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True - self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights) + self.weight = Int8Params( + self.weight.data, has_fp16_weights=has_fp16_weights + ) def init_8bit_state(self): self.state.CB = self.weight.CB diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py index 43e3973..7e2f566 100644 --- a/bitsandbytes/optim/adagrad.py +++ b/bitsandbytes/optim/adagrad.py @@ -23,7 +23,9 @@ class Adagrad(Optimizer1State): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if initial_accumulator_value != 0.0: @@ -63,7 +65,9 @@ class Adagrad8bit(Optimizer1State): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if initial_accumulator_value != 0.0: @@ -104,7 +108,9 @@ class Adagrad32bit(Optimizer1State): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if initial_accumulator_value != 0.0: diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py index 5cfaa28..3634971 100644 --- a/bitsandbytes/optim/adam.py +++ b/bitsandbytes/optim/adam.py @@ -140,7 +140,11 @@ class AnalysisAdam(torch.optim.Optimizer): savedir=None, ): defaults = dict( - lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, ) super(AnalysisAdam, self).__init__(params, defaults) self.analysis = bnb_analysis @@ -198,7 +202,9 @@ class AnalysisAdam(torch.optim.Optimizer): state["relerrors"] = torch.zeros( (256, 256), device=p_data_fp32.device ) - state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device) + state["counts"] = torch.zeros( + (256, 256), device=p_data_fp32.device + ) if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. values state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32) @@ -214,7 +220,9 @@ class AnalysisAdam(torch.optim.Optimizer): beta1, beta2 = group["betas"] bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] - step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 + step_size = ( + group["lr"] * math.sqrt(bias_correction2) / bias_correction1 + ) e = state["abserrors"] rele = state["relerrors"] counts = state["counts"] @@ -235,7 +243,10 @@ class AnalysisAdam(torch.optim.Optimizer): denom = exp_avg_sq.sqrt().add_(group["eps"]) update_fp32 = exp_avg / denom - if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000: + if ( + p_data_fp32.numel() <= 8192 + or p_data_fp32.numel() > 50000 * 1000 + ): # embedding layer or too small p_data_fp32 += -step_size * update_fp32 else: @@ -274,7 +285,9 @@ class AnalysisAdam(torch.optim.Optimizer): # 3. dequantize # Error will be calculated automatically! else: - raise ValueError(f"Invalid analysis value: {self.analysis}!") + raise ValueError( + f"Invalid analysis value: {self.analysis}!" + ) denom = state2.sqrt().add_(group["eps"]) update_8bit = state1 / denom @@ -296,7 +309,9 @@ class AnalysisAdam(torch.optim.Optimizer): if self.savedir != "" and state["step"] % 100 == 0: if not os.path.exists(self.savedir): os.makedirs(self.savedir) - shapestr = "_".join([str(dim) for dim in p_data_fp32.shape]) + shapestr = "_".join( + [str(dim) for dim in p_data_fp32.shape] + ) pathe = os.path.join( self.savedir, f"{p_id}_{shapestr}_abserr.pkl" ) diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py index c6cf5c6..8a89fb0 100644 --- a/bitsandbytes/optim/lars.py +++ b/bitsandbytes/optim/lars.py @@ -24,7 +24,9 @@ class LARS(Optimizer1State): max_unorm=0.02, ): if momentum == 0: - raise NotImplementedError(f"LARS without momentum is not supported!") + raise NotImplementedError( + f"LARS without momentum is not supported!" + ) super(LARS, self).__init__( "lars", params, @@ -56,7 +58,9 @@ class LARS8bit(Optimizer1State): max_unorm=0.02, ): if momentum == 0: - raise NotImplementedError(f"LARS without momentum is not supported!") + raise NotImplementedError( + f"LARS without momentum is not supported!" + ) super(LARS8bit, self).__init__( "lars", params, @@ -88,7 +92,9 @@ class LARS32bit(Optimizer1State): max_unorm=0.02, ): if momentum == 0: - raise NotImplementedError(f"LARS without momentum is not supported!") + raise NotImplementedError( + f"LARS without momentum is not supported!" + ) super(LARS32bit, self).__init__( "lars", params, @@ -121,7 +127,9 @@ class PytorchLARS(Optimizer): if momentum < 0.0: raise ValueError("Invalid momentum value: {}".format(momentum)) if weight_decay < 0.0: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) defaults = dict( lr=lr, @@ -132,7 +140,9 @@ class PytorchLARS(Optimizer): max_unorm=max_unorm, ) if nesterov and (momentum <= 0 or dampening != 0): - raise ValueError("Nesterov momentum requires a momentum and zero dampening") + raise ValueError( + "Nesterov momentum requires a momentum and zero dampening" + ) super(PytorchLARS, self).__init__(params, defaults) def __setstate__(self, state): diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index b942e34..4fb30cd 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -46,9 +46,13 @@ class GlobalOptimManager(object): for group_index, group in enumerate(param_groups): for p_index, p in enumerate(group["params"]): if id(p) in self.pid2config: - self.index2config[(group_index, p_index)] = self.pid2config[id(p)] + self.index2config[(group_index, p_index)] = self.pid2config[ + id(p) + ] - def override_config(self, parameters, key=None, value=None, key_value_dict=None): + def override_config( + self, parameters, key=None, value=None, key_value_dict=None + ): """ Overrides initial optimizer config for specific parameters. @@ -136,7 +140,8 @@ class Optimizer8bit(torch.optim.Optimizer): if len(groups) != len(saved_groups): raise ValueError( - "loaded state dict has a different number of " "parameter groups" + "loaded state dict has a different number of " + "parameter groups" ) param_lens = (len(g["params"]) for g in groups) saved_lens = (len(g["params"]) for g in saved_groups) @@ -192,7 +197,9 @@ class Optimizer8bit(torch.optim.Optimizer): new_group["params"] = group["params"] return new_group - param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] + param_groups = [ + update_group(g, ng) for g, ng in zip(groups, saved_groups) + ] self.__setstate__({"state": state, "param_groups": param_groups}) def to_gpu(self): @@ -222,9 +229,9 @@ class Optimizer8bit(torch.optim.Optimizer): # found the matching parameter # init override self.mng.pid2config[id(p)] = config - self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[ - id(p) - ] + self.mng.index2config[ + (gindex, pindex) + ] = self.mng.pid2config[id(p)] found = True @torch.no_grad() @@ -280,7 +287,9 @@ class Optimizer8bit(torch.optim.Optimizer): raise NotImplementedError(f"init_state method needs to be overidden") def update_step(self, group, p, gindex, pindex): - raise NotImplementedError(f"The update_step method needs to be overidden") + raise NotImplementedError( + f"The update_step method needs to be overidden" + ) class Optimizer2State(Optimizer8bit): @@ -310,9 +319,13 @@ class Optimizer2State(Optimizer8bit): betas = [float(b) for b in betas] for i in range(len(betas)): if not 0.0 <= betas[i] < 1.0: - raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") + raise ValueError( + f"Invalid beta parameter at index {i}: {betas[i]}" + ) if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super(Optimizer2State, self).__init__(params, defaults, optim_bits) @@ -351,7 +364,9 @@ class Optimizer2State(Optimizer8bit): state = self.state[p] state["step"] = 0 - if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): + if dtype == torch.float32 or ( + dtype == torch.uint8 and p.numel() < 4096 + ): state["state1"] = torch.zeros_like( p, memory_format=torch.preserve_format, @@ -368,8 +383,12 @@ class Optimizer2State(Optimizer8bit): if state["step"] == 0: if "dynamic" not in self.name2qmap: self.fill_qmap() - self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device) - self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device) + self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to( + p.device + ) + self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to( + p.device + ) state["state1"] = torch.zeros_like( p, @@ -399,11 +418,15 @@ class Optimizer2State(Optimizer8bit): (blocks,), dtype=torch.float32, device=p.device ) else: - state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["max1"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) state["new_max1"] = torch.zeros( (1,), dtype=torch.float32, device=p.device ) - state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["max2"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) state["new_max2"] = torch.zeros( (1,), dtype=torch.float32, device=p.device ) @@ -470,7 +493,9 @@ class Optimizer2State(Optimizer8bit): state["new_max2"], config["weight_decay"], gnorm_scale=gnorm_scale, - unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None, + unorm_vec=state["unorm_vec"] + if config["max_unorm"] > 0.0 + else None, max_unorm=config["max_unorm"], ) @@ -522,9 +547,13 @@ class Optimizer1State(Optimizer8bit): raise ValueError("Invalid epsilon value: {}".format(eps)) for i in range(len(betas)): if not 0.0 <= betas[i] < 1.0: - raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") + raise ValueError( + f"Invalid beta parameter at index {i}: {betas[i]}" + ) if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super(Optimizer1State, self).__init__(params, defaults, optim_bits) @@ -563,7 +592,9 @@ class Optimizer1State(Optimizer8bit): state = self.state[p] state["step"] = 0 - if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): + if dtype == torch.float32 or ( + dtype == torch.uint8 and p.numel() < 4096 + ): state["state1"] = torch.zeros_like( p, memory_format=torch.preserve_format, @@ -574,7 +605,9 @@ class Optimizer1State(Optimizer8bit): if state["step"] == 0: if "dynamic" not in self.name2qmap: self.fill_qmap() - self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device) + self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to( + p.device + ) state["state1"] = torch.zeros_like( p, @@ -593,7 +626,9 @@ class Optimizer1State(Optimizer8bit): (blocks,), dtype=torch.float32, device=p.device ) else: - state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["max1"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) state["new_max1"] = torch.zeros( (1,), dtype=torch.float32, device=p.device ) diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py index 679f783..7ddb12c 100644 --- a/bitsandbytes/optim/rmsprop.py +++ b/bitsandbytes/optim/rmsprop.py @@ -22,7 +22,9 @@ class RMSprop(Optimizer1State): block_wise=True, ): if alpha == 0: - raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!") + raise NotImplementedError( + f"RMSprop with alpha==0.0 is not supported!" + ) if centered: raise NotImplementedError(f"Centered RMSprop is not supported!") super(RMSprop, self).__init__( @@ -56,7 +58,9 @@ class RMSprop8bit(Optimizer1State): block_wise=True, ): if alpha == 0: - raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!") + raise NotImplementedError( + f"RMSprop with alpha==0.0 is not supported!" + ) if centered: raise NotImplementedError(f"Centered RMSprop is not supported!") super(RMSprop8bit, self).__init__( @@ -91,7 +95,9 @@ class RMSprop32bit(Optimizer1State): ): if alpha == 0: - raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!") + raise NotImplementedError( + f"RMSprop with alpha==0.0 is not supported!" + ) if centered: raise NotImplementedError(f"Centered RMSprop is not supported!") super(RMSprop32bit, self).__init__( diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 6797407..8a9fc0e 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -1,6 +1,6 @@ +import sys import shlex import subprocess -import sys def execute_and_return(command_string: str) -> Tuple[str, str]: -- cgit v1.2.3