author     Titus von Koeller <titus@vonkoeller.com>   2022-08-01 09:32:47 -0700
committer  Titus von Koeller <titus@vonkoeller.com>   2022-08-01 09:32:47 -0700
commit     ea7c14f8ef64924f2d0ff80df3cdabf2c7299848 (patch)
tree       3b9ec443a259cf36d87627a8e2cc7d13513f6a21 /bitsandbytes
parent     3fd06fb6206f46b6d18fbb8a512da63832dea98b (diff)
reran black with linelength 80 for greater readability
Diffstat (limited to 'bitsandbytes')
-rw-r--r--  bitsandbytes/__init__.py             |   9
-rw-r--r--  bitsandbytes/autograd/_functions.py  |  45
-rw-r--r--  bitsandbytes/cuda_setup.py           |  45
-rw-r--r--  bitsandbytes/functional.py           | 128
-rw-r--r--  bitsandbytes/nn/modules.py           |  34
-rw-r--r--  bitsandbytes/optim/adagrad.py        |  12
-rw-r--r--  bitsandbytes/optim/adam.py           |  27
-rw-r--r--  bitsandbytes/optim/lars.py           |  20
-rw-r--r--  bitsandbytes/optim/optimizer.py      |  77
-rw-r--r--  bitsandbytes/optim/rmsprop.py        |  12
-rw-r--r--  bitsandbytes/utils.py                |   2
11 files changed, 319 insertions(+), 92 deletions(-)
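
This commit is a pure formatting pass: every hunk below is black re-wrapping existing code to a line length of 80, with no functional change. To reproduce or verify the formatting locally, a sketch along these lines should work (the black version and any project-level configuration are assumptions; neither is part of this commit):

    # check_black_80.py -- verify a file already conforms to black at line length 80
    import black

    def is_black_formatted(path: str, line_length: int = 80) -> bool:
        with open(path) as f:
            src = f.read()
        # black.format_str returns the reformatted source; if it matches the
        # input exactly, the file already follows the requested style.
        mode = black.Mode(line_length=line_length)
        return black.format_str(src, mode=mode) == src

    print(is_black_formatted("bitsandbytes/functional.py"))

Running black --line-length 80 bitsandbytes/ from the repository root applies the same style in place.
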
diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py
index 7ca017d..6e5b6ac 100644
--- a/bitsandbytes/__init__.py
+++ b/bitsandbytes/__init__.py
@@ -3,8 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from .autograd._functions import (MatmulLtState, bmm_cublas, matmul,
- matmul_cublas, mm_cublas)
+from .autograd._functions import (
+ MatmulLtState,
+ bmm_cublas,
+ matmul,
+ matmul_cublas,
+ mm_cublas,
+)
from .cextension import COMPILED_WITH_CUDA
from .nn import modules
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index a08b560..b56b2ee 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -111,7 +111,9 @@ class MatMul8bit(torch.autograd.Function):
qgrad_output, S1 = F.vectorwise_quant(
grad_output, dim=dims, quant_type=quant_type
)
- qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
+ qA, S2 = F.vectorwise_quant(
+ A, dim=dims, quant_type=quant_type
+ )
igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
grad_B = F.vectorwise_mm_dequant(
igrad_B,
@@ -146,7 +148,11 @@ class MatMul8bit(torch.autograd.Function):
qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
grad_A = F.vectorwise_mm_dequant(
- igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type
+ igrad_A,
+ S1,
+ S3.permute(permute_dim),
+ grad_output.dtype,
+ quant_type,
)
return grad_A, grad_B, None, None, None
@@ -211,7 +217,9 @@ class MatMul8bitLt(torch.autograd.Function):
# 1. Quantize A
if len(A.shape) == 3:
A = A.view(-1, A.shape[-1]).contiguous()
- CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)
+ CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
+ A, threshold=state.threshold
+ )
if state.threshold > 0.0 and coo_tensorA is not None:
if state.has_fp16_weights:
@@ -225,7 +233,9 @@ class MatMul8bitLt(torch.autograd.Function):
if state.CxB is None:
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
# we also need to convert it to the turing/ampere format
- state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+ state.CxB, state.SB = F.transform(
+ state.CB, to_order=formatB
+ )
# state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
# if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
# # generate outlier index and subB
@@ -259,7 +269,13 @@ class MatMul8bitLt(torch.autograd.Function):
if (state.is_training and not has_grad) or state.CxB is None:
state.reset_grads()
- CB, state.CBt, state.SCB, state.SCBt, coo_tensorB = F.double_quant(B)
+ (
+ CB,
+ state.CBt,
+ state.SCB,
+ state.SCBt,
+ coo_tensorB,
+ ) = F.double_quant(B)
state.CxB, state.SB = F.transform(CB, to_order=formatB)
else:
has_grad = False
@@ -277,7 +293,10 @@ class MatMul8bitLt(torch.autograd.Function):
# state.idx = outlier_idx
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
state.subB = (
- (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half()
+ (outliers * state.SCB.view(-1, 1) / 127.0)
+ .t()
+ .contiguous()
+ .half()
)
CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0
@@ -325,10 +344,14 @@ class MatMul8bitLt(torch.autograd.Function):
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
state = ctx.state
- assert state.has_fp16_weights, "Backprop only supported for fp16 weights."
+ assert (
+ state.has_fp16_weights
+ ), "Backprop only supported for fp16 weights."
if len(grad_output.shape) == 3:
- grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()
+ grad_output = grad_output.view(
+ -1, grad_output.shape[-1]
+ ).contiguous()
grad_A = grad_B = None
@@ -359,7 +382,11 @@ matmul = MatMul8bitLt.apply
def matmul(
- A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0
+ A: tensor,
+ B: tensor,
+ out: tensor = None,
+ state: MatmulLtState = None,
+ threshold=0.0,
):
state = state or MatmulLtState()
if threshold > 0.0:
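
The signature re-wrapped above is the public wrapper around MatMul8bitLt. A minimal usage sketch follows; it assumes a CUDA device, fp16 inputs, and the (out_features, in_features) weight layout that Linear8bitLt uses when it calls this function, so treat the shapes and the threshold value as illustrative rather than prescriptive:

    import torch
    import bitsandbytes as bnb

    A = torch.randn(16, 512, dtype=torch.float16, device="cuda")   # activations
    W = torch.randn(256, 512, dtype=torch.float16, device="cuda")  # weight (out, in)

    # threshold > 0 routes outlier feature dimensions through fp16 while the
    # bulk of the multiplication runs in int8 (the mixed-precision decomposition).
    out = bnb.matmul(A, W, threshold=6.0)   # -> shape (16, 256)
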
diff --git a/bitsandbytes/cuda_setup.py b/bitsandbytes/cuda_setup.py
index 8cc2c03..6e37606 100644
--- a/bitsandbytes/cuda_setup.py
+++ b/bitsandbytes/cuda_setup.py
@@ -1,7 +1,7 @@
"""
-build is dependent on
-- compute capability
- - dependent on GPU family
+extract factors the build is dependent on:
+[X] compute capability
+ [ ] TODO: Q - What if we have multiple GPUs of different makes?
- CUDA version
- Software:
- CPU-only: only CPU quantization functions (no optimizer, no matrix multipl)
@@ -19,6 +19,8 @@ evaluation:
"""
import ctypes
+import shlex
+import subprocess
from os import environ as env
from pathlib import Path
from typing import Set, Union
@@ -26,10 +28,31 @@ from typing import Set, Union
from .utils import print_err, warn_of_missing_prerequisite
+def execute_and_return(command_string: str) -> Tuple[str, str]:
+ def _decode(subprocess_err_out_tuple):
+ return tuple(
+ to_decode.decode("UTF-8").strip()
+ for to_decode in subprocess_err_out_tuple
+ )
+
+ def execute_and_return_decoded_std_streams(command_string):
+ return _decode(
+ subprocess.Popen(
+ shlex.split(command_string),
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ ).communicate()
+ )
+
+ std_out, std_err = execute_and_return_decoded_std_streams()
+ return std_out, std_err
+
+
def check_cuda_result(cuda, result_val):
if result_val != 0:
+ # TODO: undefined name 'error_str'
cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
- print(f"Count not initialize CUDA - failure!")
+ print("Count not initialize CUDA - failure!")
raise Exception("CUDA exception!")
return result_val
@@ -53,7 +76,9 @@ def get_compute_capability():
result = ctypes.c_int()
device = ctypes.c_int()
+ # TODO: local variable 'context' is assigned to but never used
context = ctypes.c_void_p()
+ # TODO: local variable 'error_str' is assigned to but never used
error_str = ctypes.c_char_p()
result = check_cuda_result(cuda, cuda.cuInit(0))
@@ -61,7 +86,9 @@ def get_compute_capability():
result = check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
ccs = []
for i in range(nGpus.value):
- result = check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
+ result = check_cuda_result(
+ cuda, cuda.cuDeviceGet(ctypes.byref(device), i)
+ )
result = check_cuda_result(
cuda,
cuda.cuDeviceComputeCapability(
@@ -114,11 +141,15 @@ def get_cuda_runtime_lib_path(
} - non_existent_directories
if len(cuda_runtime_libs) > 1:
- err_msg = f"Found duplicate {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
+ err_msg = (
+ f"Found duplicate {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
+ )
raise FileNotFoundError(err_msg)
elif len(cuda_runtime_libs) < 1:
- err_msg = f"Did not find {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
+ err_msg = (
+ f"Did not find {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
+ )
raise FileNotFoundError(err_msg)
single_cuda_runtime_lib_dir = next(iter(cuda_runtime_libs))
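
Two loose ends are visible in the new execute_and_return helper above: the inner execute_and_return_decoded_std_streams() is invoked without the command_string argument, and the Tuple annotation is not covered by the `from typing import Set, Union` import shown earlier in the hunk (the TODO comments flag similar issues around error_str and context). A self-contained variant that sidesteps both problems could look like this; it is a sketch, not the committed code:

    import shlex
    import subprocess
    from typing import Tuple

    def execute_and_return(command_string: str) -> Tuple[str, str]:
        # Run the command, capture both streams, and return decoded stdout/stderr.
        proc = subprocess.Popen(
            shlex.split(command_string),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        std_out, std_err = proc.communicate()
        return std_out.decode("UTF-8").strip(), std_err.decode("UTF-8").strip()
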
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 2e86958..236ef39 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -17,14 +17,29 @@ if COMPILED_WITH_CUDA:
"""C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit = {}
str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
- str2optimizer32bit["momentum"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
- str2optimizer32bit["rmsprop"] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16)
- str2optimizer32bit["adagrad"] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16)
- str2optimizer32bit["lars"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
+ str2optimizer32bit["momentum"] = (
+ lib.cmomentum32bit_g32,
+ lib.cmomentum32bit_g16,
+ )
+ str2optimizer32bit["rmsprop"] = (
+ lib.crmsprop32bit_g32,
+ lib.crmsprop32bit_g16,
+ )
+ str2optimizer32bit["adagrad"] = (
+ lib.cadagrad32bit_g32,
+ lib.cadagrad32bit_g16,
+ )
+ str2optimizer32bit["lars"] = (
+ lib.cmomentum32bit_g32,
+ lib.cmomentum32bit_g16,
+ )
str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
str2optimizer8bit = {}
- str2optimizer8bit["adam"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
+ str2optimizer8bit["adam"] = (
+ lib.cadam_static_8bit_g32,
+ lib.cadam_static_8bit_g16,
+ )
str2optimizer8bit["momentum"] = (
lib.cmomentum_static_8bit_g32,
lib.cmomentum_static_8bit_g16,
@@ -33,7 +48,10 @@ if COMPILED_WITH_CUDA:
lib.crmsprop_static_8bit_g32,
lib.crmsprop_static_8bit_g16,
)
- str2optimizer8bit["lamb"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
+ str2optimizer8bit["lamb"] = (
+ lib.cadam_static_8bit_g32,
+ lib.cadam_static_8bit_g16,
+ )
str2optimizer8bit["lars"] = (
lib.cmomentum_static_8bit_g32,
lib.cmomentum_static_8bit_g16,
@@ -137,7 +155,9 @@ def create_dynamic_map(signed=True, n=7):
if not signed:
additional_items = 2 * additional_items
for i in range(n):
- fraction_items = 2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
+ fraction_items = (
+ 2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
+ )
boundaries = torch.linspace(0.1, 1, fraction_items)
means = (boundaries[:-1] + boundaries[1:]) / 2.0
data += ((10 ** (-(n - 1) + i)) * means).tolist()
@@ -272,7 +292,13 @@ def get_transform_buffer(
def nvidia_transform(
- A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None
+ A,
+ to_order,
+ from_order="row",
+ out=None,
+ transpose=False,
+ state=None,
+ ld=None,
):
if state is None:
state = (A.shape, from_order)
@@ -352,7 +378,11 @@ def estimate_quantiles(
def quantize_blockwise(
- A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None
+ A: Tensor,
+ code: Tensor = None,
+ absmax: Tensor = None,
+ rand=None,
+ out: Tensor = None,
) -> Tensor:
"""
Quantize tensor A in blocks of size 4096 values.
@@ -629,7 +659,9 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
"""
if out is None:
out = torch.zeros_like(A, dtype=torch.float32)
- lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
+ lib.cdequantize(
+ get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())
+ )
return out
@@ -1005,7 +1037,9 @@ def histogram_scatter_add_2d(
)
-def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
+def check_matmul(
+ A, B, out, transposed_A, transposed_B, expected_type=torch.int8
+):
if not torch.cuda.is_initialized():
torch.cuda.init()
if A.dtype != expected_type or B.dtype != expected_type:
@@ -1097,7 +1131,11 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
def igemm(
- A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False
+ A: Tensor,
+ B: Tensor,
+ out: Tensor = None,
+ transposed_A=False,
+ transposed_B=False,
):
sout = check_matmul(A, B, out, transposed_A, transposed_B)
if out is None:
@@ -1193,7 +1231,11 @@ def igemm(
def batched_igemm(
- A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False
+ A: Tensor,
+ B: Tensor,
+ out: Tensor = None,
+ transposed_A=False,
+ transposed_B=False,
):
if not len(A.shape) == 3 or not len(B.shape) == 3:
raise ValueError(
@@ -1392,9 +1434,13 @@ def mm_dequant(
if out is None:
out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
if new_row_stats is None:
- new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
+ new_row_stats = torch.empty(
+ out_shape[0], dtype=torch.float32, device=A.device
+ )
if new_col_stats is None:
- new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
+ new_col_stats = torch.empty(
+ out_shape[1], dtype=torch.float32, device=A.device
+ )
assert (
new_row_stats.shape[0] == row_stats.shape[0]
), f"{new_row_stats.shape} vs {row_stats.shape}"
@@ -1440,13 +1486,13 @@ def get_colrow_absmax(
col_tiles = (cols + 255) // 256
tiled_rows = ((rows + 15) // 16) * 16
if row_stats is None:
- row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(
- -50000.0
- )
+ row_stats = torch.empty(
+ (rows,), dtype=torch.float32, device=device
+ ).fill_(-50000.0)
if col_stats is None:
- col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(
- -50000.0
- )
+ col_stats = torch.empty(
+ (cols,), dtype=torch.float32, device=device
+ ).fill_(-50000.0)
if nnz_block_ptr is None and threshold > 0.0:
nnz_block_ptr = torch.zeros(
@@ -1462,7 +1508,13 @@ def get_colrow_absmax(
prev_device = pre_call(A.device)
lib.cget_col_row_stats(
- ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols
+ ptrA,
+ ptrRowStats,
+ ptrColStats,
+ ptrNnzrows,
+ ct.c_float(threshold),
+ rows,
+ cols,
)
post_call(prev_device)
@@ -1526,7 +1578,9 @@ class CSCSparseTensor(object):
def coo2csr(cooA):
values, counts = torch.unique(cooA.rowidx, return_counts=True)
values.add_(1)
- rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device)
+ rowptr = torch.zeros(
+ (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device
+ )
rowptr.scatter_(index=values.long(), src=counts.int(), dim=0)
rowptr.cumsum_(0)
return CSRSparseTensor(
@@ -1540,10 +1594,14 @@ def coo2csc(cooA):
values = cooA.values[col2rowidx]
colvalues, counts = torch.unique(val, return_counts=True)
colvalues.add_(1)
- colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device)
+ colptr = torch.zeros(
+ (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device
+ )
colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0)
colptr.cumsum_(0)
- return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values)
+ return CSCSparseTensor(
+ cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values
+ )
def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
@@ -1568,7 +1626,9 @@ def double_quant(
rows = A.shape[0]
if row_stats is None or col_stats is None:
- row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)
+ row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(
+ A, threshold=threshold
+ )
if out_col is None:
out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
@@ -1663,7 +1723,13 @@ def get_special_format_str():
def transform(
- A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None
+ A,
+ to_order,
+ from_order="row",
+ out=None,
+ transpose=False,
+ state=None,
+ ld=None,
):
if state is None:
state = (A.shape, from_order)
@@ -1716,7 +1782,9 @@ def transform(
def spmm_coo(cooA, B, out=None):
if out is None:
- out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype)
+ out = torch.empty(
+ (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype
+ )
nnz = cooA.nnz
assert cooA.rowidx.numel() == nnz
assert cooA.colidx.numel() == nnz
@@ -1982,7 +2050,9 @@ def extract_outliers(A, SA, idx):
assert formatA in ["col_turing", "col_ampere"]
assert A.device.type == "cuda"
- out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
+ out = torch.zeros(
+ (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device
+ )
idx_size = ct.c_int32(idx.numel())
rows = ct.c_int32(shapeA[0])
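
Several of the calls re-wrapped above spell out the int8 pipeline: double_quant produces row-wise and column-wise int8 tensors plus their scaling statistics (and, when a threshold is set, a sparse COO tensor of outliers), and transform re-lays the row-major int8 data into the turing/ampere tiled format the matmul kernels expect. A rough illustration using only signatures that appear in this diff (shapes are illustrative; a CUDA device and a CUDA-compiled build are assumed):

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(16, 512, dtype=torch.float16, device="cuda")

    # Row/column int8 quantization plus scaling stats; values above the
    # threshold come back separately as a COO tensor (or None if there are none).
    CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=6.0)

    # Convert the row-major int8 tensor to the GPU-specific tiled layout.
    formatB = F.get_special_format_str()   # "col_turing" or "col_ampere"
    CxA, SA = F.transform(CA, to_order=formatB)
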
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 9ce3ac8..454dba5 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -2,8 +2,19 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from typing import (Any, Callable, Dict, Iterator, Mapping, Optional, Set,
- Tuple, TypeVar, Union, overload)
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ Mapping,
+ Optional,
+ Set,
+ Tuple,
+ TypeVar,
+ Union,
+ overload,
+)
import torch
import torch.nn.functional as F
@@ -131,7 +142,12 @@ class Embedding(torch.nn.Embedding):
class Int8Params(torch.nn.Parameter):
def __new__(
- cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None
+ cls,
+ data=None,
+ requires_grad=True,
+ has_fp16_weights=False,
+ CB=None,
+ SCB=None,
):
cls.has_fp16_weights = has_fp16_weights
cls.CB = None
@@ -186,7 +202,9 @@ class Int8Params(torch.nn.Parameter):
return self.cuda(device)
else:
new_param = Int8Params(
- super().to(device=device, dtype=dtype, non_blocking=non_blocking),
+ super().to(
+ device=device, dtype=dtype, non_blocking=non_blocking
+ ),
requires_grad=self.requires_grad,
has_fp16_weights=self.has_fp16_weights,
)
@@ -206,7 +224,9 @@ class Linear8bitLt(nn.Linear):
threshold=0.0,
index=None,
):
- super(Linear8bitLt, self).__init__(input_features, output_features, bias)
+ super(Linear8bitLt, self).__init__(
+ input_features, output_features, bias
+ )
self.state = bnb.MatmulLtState()
self.index = index
@@ -215,7 +235,9 @@ class Linear8bitLt(nn.Linear):
if threshold > 0.0 and not has_fp16_weights:
self.state.use_pool = True
- self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
+ self.weight = Int8Params(
+ self.weight.data, has_fp16_weights=has_fp16_weights
+ )
def init_8bit_state(self):
self.state.CB = self.weight.CB
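
Linear8bitLt, whose constructor is re-wrapped above, is the drop-in replacement for torch.nn.Linear backed by Int8Params. A minimal inference-style sketch (dimensions and the threshold are illustrative; bias is omitted so the whole forward stays in fp16, and a CUDA device is assumed):

    import torch
    import bitsandbytes as bnb

    # has_fp16_weights=False quantizes the weight to int8 when it is moved to
    # the GPU; threshold=6.0 enables the mixed-precision outlier decomposition.
    linear = bnb.nn.Linear8bitLt(
        512, 256, bias=False, has_fp16_weights=False, threshold=6.0
    ).cuda()

    x = torch.randn(8, 512, dtype=torch.float16, device="cuda")
    y = linear(x)   # -> shape (8, 256)
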
diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py
index 43e3973..7e2f566 100644
--- a/bitsandbytes/optim/adagrad.py
+++ b/bitsandbytes/optim/adagrad.py
@@ -23,7 +23,9 @@ class Adagrad(Optimizer1State):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if initial_accumulator_value != 0.0:
@@ -63,7 +65,9 @@ class Adagrad8bit(Optimizer1State):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if initial_accumulator_value != 0.0:
@@ -104,7 +108,9 @@ class Adagrad32bit(Optimizer1State):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if initial_accumulator_value != 0.0:
diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
index 5cfaa28..3634971 100644
--- a/bitsandbytes/optim/adam.py
+++ b/bitsandbytes/optim/adam.py
@@ -140,7 +140,11 @@ class AnalysisAdam(torch.optim.Optimizer):
savedir=None,
):
defaults = dict(
- lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
+ lr=lr,
+ betas=betas,
+ eps=eps,
+ weight_decay=weight_decay,
+ amsgrad=amsgrad,
)
super(AnalysisAdam, self).__init__(params, defaults)
self.analysis = bnb_analysis
@@ -198,7 +202,9 @@ class AnalysisAdam(torch.optim.Optimizer):
state["relerrors"] = torch.zeros(
(256, 256), device=p_data_fp32.device
)
- state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device)
+ state["counts"] = torch.zeros(
+ (256, 256), device=p_data_fp32.device
+ )
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
@@ -214,7 +220,9 @@ class AnalysisAdam(torch.optim.Optimizer):
beta1, beta2 = group["betas"]
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
- step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+ step_size = (
+ group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+ )
e = state["abserrors"]
rele = state["relerrors"]
counts = state["counts"]
@@ -235,7 +243,10 @@ class AnalysisAdam(torch.optim.Optimizer):
denom = exp_avg_sq.sqrt().add_(group["eps"])
update_fp32 = exp_avg / denom
- if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000:
+ if (
+ p_data_fp32.numel() <= 8192
+ or p_data_fp32.numel() > 50000 * 1000
+ ):
# embedding layer or too small
p_data_fp32 += -step_size * update_fp32
else:
@@ -274,7 +285,9 @@ class AnalysisAdam(torch.optim.Optimizer):
# 3. dequantize
# Error will be calculated automatically!
else:
- raise ValueError(f"Invalid analysis value: {self.analysis}!")
+ raise ValueError(
+ f"Invalid analysis value: {self.analysis}!"
+ )
denom = state2.sqrt().add_(group["eps"])
update_8bit = state1 / denom
@@ -296,7 +309,9 @@ class AnalysisAdam(torch.optim.Optimizer):
if self.savedir != "" and state["step"] % 100 == 0:
if not os.path.exists(self.savedir):
os.makedirs(self.savedir)
- shapestr = "_".join([str(dim) for dim in p_data_fp32.shape])
+ shapestr = "_".join(
+ [str(dim) for dim in p_data_fp32.shape]
+ )
pathe = os.path.join(
self.savedir, f"{p_id}_{shapestr}_abserr.pkl"
)
diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py
index c6cf5c6..8a89fb0 100644
--- a/bitsandbytes/optim/lars.py
+++ b/bitsandbytes/optim/lars.py
@@ -24,7 +24,9 @@ class LARS(Optimizer1State):
max_unorm=0.02,
):
if momentum == 0:
- raise NotImplementedError(f"LARS without momentum is not supported!")
+ raise NotImplementedError(
+ f"LARS without momentum is not supported!"
+ )
super(LARS, self).__init__(
"lars",
params,
@@ -56,7 +58,9 @@ class LARS8bit(Optimizer1State):
max_unorm=0.02,
):
if momentum == 0:
- raise NotImplementedError(f"LARS without momentum is not supported!")
+ raise NotImplementedError(
+ f"LARS without momentum is not supported!"
+ )
super(LARS8bit, self).__init__(
"lars",
params,
@@ -88,7 +92,9 @@ class LARS32bit(Optimizer1State):
max_unorm=0.02,
):
if momentum == 0:
- raise NotImplementedError(f"LARS without momentum is not supported!")
+ raise NotImplementedError(
+ f"LARS without momentum is not supported!"
+ )
super(LARS32bit, self).__init__(
"lars",
params,
@@ -121,7 +127,9 @@ class PytorchLARS(Optimizer):
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
defaults = dict(
lr=lr,
@@ -132,7 +140,9 @@ class PytorchLARS(Optimizer):
max_unorm=max_unorm,
)
if nesterov and (momentum <= 0 or dampening != 0):
- raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+ raise ValueError(
+ "Nesterov momentum requires a momentum and zero dampening"
+ )
super(PytorchLARS, self).__init__(params, defaults)
def __setstate__(self, state):
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index b942e34..4fb30cd 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -46,9 +46,13 @@ class GlobalOptimManager(object):
for group_index, group in enumerate(param_groups):
for p_index, p in enumerate(group["params"]):
if id(p) in self.pid2config:
- self.index2config[(group_index, p_index)] = self.pid2config[id(p)]
+ self.index2config[(group_index, p_index)] = self.pid2config[
+ id(p)
+ ]
- def override_config(self, parameters, key=None, value=None, key_value_dict=None):
+ def override_config(
+ self, parameters, key=None, value=None, key_value_dict=None
+ ):
"""
Overrides initial optimizer config for specific parameters.
@@ -136,7 +140,8 @@ class Optimizer8bit(torch.optim.Optimizer):
if len(groups) != len(saved_groups):
raise ValueError(
- "loaded state dict has a different number of " "parameter groups"
+ "loaded state dict has a different number of "
+ "parameter groups"
)
param_lens = (len(g["params"]) for g in groups)
saved_lens = (len(g["params"]) for g in saved_groups)
@@ -192,7 +197,9 @@ class Optimizer8bit(torch.optim.Optimizer):
new_group["params"] = group["params"]
return new_group
- param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
+ param_groups = [
+ update_group(g, ng) for g, ng in zip(groups, saved_groups)
+ ]
self.__setstate__({"state": state, "param_groups": param_groups})
def to_gpu(self):
@@ -222,9 +229,9 @@ class Optimizer8bit(torch.optim.Optimizer):
# found the matching parameter
# init override
self.mng.pid2config[id(p)] = config
- self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[
- id(p)
- ]
+ self.mng.index2config[
+ (gindex, pindex)
+ ] = self.mng.pid2config[id(p)]
found = True
@torch.no_grad()
@@ -280,7 +287,9 @@ class Optimizer8bit(torch.optim.Optimizer):
raise NotImplementedError(f"init_state method needs to be overidden")
def update_step(self, group, p, gindex, pindex):
- raise NotImplementedError(f"The update_step method needs to be overidden")
+ raise NotImplementedError(
+ f"The update_step method needs to be overidden"
+ )
class Optimizer2State(Optimizer8bit):
@@ -310,9 +319,13 @@ class Optimizer2State(Optimizer8bit):
betas = [float(b) for b in betas]
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
- raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
+ raise ValueError(
+ f"Invalid beta parameter at index {i}: {betas[i]}"
+ )
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer2State, self).__init__(params, defaults, optim_bits)
@@ -351,7 +364,9 @@ class Optimizer2State(Optimizer8bit):
state = self.state[p]
state["step"] = 0
- if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+ if dtype == torch.float32 or (
+ dtype == torch.uint8 and p.numel() < 4096
+ ):
state["state1"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
@@ -368,8 +383,12 @@ class Optimizer2State(Optimizer8bit):
if state["step"] == 0:
if "dynamic" not in self.name2qmap:
self.fill_qmap()
- self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
- self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device)
+ self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
+ p.device
+ )
+ self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(
+ p.device
+ )
state["state1"] = torch.zeros_like(
p,
@@ -399,11 +418,15 @@ class Optimizer2State(Optimizer8bit):
(blocks,), dtype=torch.float32, device=p.device
)
else:
- state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state["max1"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
state["new_max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
- state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state["max2"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
state["new_max2"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
@@ -470,7 +493,9 @@ class Optimizer2State(Optimizer8bit):
state["new_max2"],
config["weight_decay"],
gnorm_scale=gnorm_scale,
- unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
+ unorm_vec=state["unorm_vec"]
+ if config["max_unorm"] > 0.0
+ else None,
max_unorm=config["max_unorm"],
)
@@ -522,9 +547,13 @@ class Optimizer1State(Optimizer8bit):
raise ValueError("Invalid epsilon value: {}".format(eps))
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
- raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
+ raise ValueError(
+ f"Invalid beta parameter at index {i}: {betas[i]}"
+ )
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer1State, self).__init__(params, defaults, optim_bits)
@@ -563,7 +592,9 @@ class Optimizer1State(Optimizer8bit):
state = self.state[p]
state["step"] = 0
- if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+ if dtype == torch.float32 or (
+ dtype == torch.uint8 and p.numel() < 4096
+ ):
state["state1"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
@@ -574,7 +605,9 @@ class Optimizer1State(Optimizer8bit):
if state["step"] == 0:
if "dynamic" not in self.name2qmap:
self.fill_qmap()
- self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
+ self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
+ p.device
+ )
state["state1"] = torch.zeros_like(
p,
@@ -593,7 +626,9 @@ class Optimizer1State(Optimizer8bit):
(blocks,), dtype=torch.float32, device=p.device
)
else:
- state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state["max1"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
state["new_max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py
index 679f783..7ddb12c 100644
--- a/bitsandbytes/optim/rmsprop.py
+++ b/bitsandbytes/optim/rmsprop.py
@@ -22,7 +22,9 @@ class RMSprop(Optimizer1State):
block_wise=True,
):
if alpha == 0:
- raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
+ raise NotImplementedError(
+ f"RMSprop with alpha==0.0 is not supported!"
+ )
if centered:
raise NotImplementedError(f"Centered RMSprop is not supported!")
super(RMSprop, self).__init__(
@@ -56,7 +58,9 @@ class RMSprop8bit(Optimizer1State):
block_wise=True,
):
if alpha == 0:
- raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
+ raise NotImplementedError(
+ f"RMSprop with alpha==0.0 is not supported!"
+ )
if centered:
raise NotImplementedError(f"Centered RMSprop is not supported!")
super(RMSprop8bit, self).__init__(
@@ -91,7 +95,9 @@ class RMSprop32bit(Optimizer1State):
):
if alpha == 0:
- raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
+ raise NotImplementedError(
+ f"RMSprop with alpha==0.0 is not supported!"
+ )
if centered:
raise NotImplementedError(f"Centered RMSprop is not supported!")
super(RMSprop32bit, self).__init__(
diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py
index 6797407..8a9fc0e 100644
--- a/bitsandbytes/utils.py
+++ b/bitsandbytes/utils.py
@@ -1,6 +1,6 @@
+import sys
import shlex
import subprocess
-import sys
def execute_and_return(command_string: str) -> Tuple[str, str]: