Diffstat (limited to 'bitsandbytes/functional.py')
-rw-r--r--  bitsandbytes/functional.py  128
1 file changed, 99 insertions(+), 29 deletions(-)
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 2e86958..236ef39 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -17,14 +17,29 @@ if COMPILED_WITH_CUDA:
"""C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit = {}
str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
- str2optimizer32bit["momentum"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
- str2optimizer32bit["rmsprop"] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16)
- str2optimizer32bit["adagrad"] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16)
- str2optimizer32bit["lars"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
+ str2optimizer32bit["momentum"] = (
+ lib.cmomentum32bit_g32,
+ lib.cmomentum32bit_g16,
+ )
+ str2optimizer32bit["rmsprop"] = (
+ lib.crmsprop32bit_g32,
+ lib.crmsprop32bit_g16,
+ )
+ str2optimizer32bit["adagrad"] = (
+ lib.cadagrad32bit_g32,
+ lib.cadagrad32bit_g16,
+ )
+ str2optimizer32bit["lars"] = (
+ lib.cmomentum32bit_g32,
+ lib.cmomentum32bit_g16,
+ )
str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
str2optimizer8bit = {}
- str2optimizer8bit["adam"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
+ str2optimizer8bit["adam"] = (
+ lib.cadam_static_8bit_g32,
+ lib.cadam_static_8bit_g16,
+ )
str2optimizer8bit["momentum"] = (
lib.cmomentum_static_8bit_g32,
lib.cmomentum_static_8bit_g16,
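[Note] This hunk only reflows the optimizer dispatch tables. Each entry pairs a float32-gradient kernel with a float16-gradient one; a hedged sketch of the selection such a pair supports (the dtype dispatch is inferred from the g32/g16 suffixes and is not part of this diff):

    import torch

    def pick_kernel(name, grad):
        # e.g. str2optimizer32bit["momentum"] == (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
        fn_g32, fn_g16 = str2optimizer32bit[name]
        if grad.dtype == torch.float32:
            return fn_g32
        if grad.dtype == torch.float16:
            return fn_g16
        raise ValueError(f"Gradient dtype not supported: {grad.dtype}")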
@@ -33,7 +48,10 @@ if COMPILED_WITH_CUDA:
lib.crmsprop_static_8bit_g32,
lib.crmsprop_static_8bit_g16,
)
- str2optimizer8bit["lamb"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
+ str2optimizer8bit["lamb"] = (
+ lib.cadam_static_8bit_g32,
+ lib.cadam_static_8bit_g16,
+ )
str2optimizer8bit["lars"] = (
lib.cmomentum_static_8bit_g32,
lib.cmomentum_static_8bit_g16,
@@ -137,7 +155,9 @@ def create_dynamic_map(signed=True, n=7):
if not signed:
additional_items = 2 * additional_items
for i in range(n):
- fraction_items = 2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
+ fraction_items = (
+ 2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
+ )
boundaries = torch.linspace(0.1, 1, fraction_items)
means = (boundaries[:-1] + boundaries[1:]) / 2.0
data += ((10 ** (-(n - 1) + i)) * means).tolist()
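[Note] A quick worked check of the wrapped fraction_items expression (the arithmetic is mine, not in the diff): with signed=True and n=7, decade i contributes 2**i + 1 linspace boundaries, whose midpoints yield 2**i values, so the seven decades contribute 1 + 2 + 4 + ... + 64 = 127 nonzero magnitudes of the dynamic map:

    total = 0
    for i in range(7):                        # n = 7, signed = True
        fraction_items = 2 ** (i + 7 - 7) + 1 # boundaries per decade
        total += fraction_items - 1           # midpoints per decade
    print(total)                              # 127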
@@ -272,7 +292,13 @@ def get_transform_buffer(
def nvidia_transform(
- A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None
+ A,
+ to_order,
+ from_order="row",
+ out=None,
+ transpose=False,
+ state=None,
+ ld=None,
):
if state is None:
state = (A.shape, from_order)
@@ -352,7 +378,11 @@ def estimate_quantiles(
def quantize_blockwise(
- A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None
+ A: Tensor,
+ code: Tensor = None,
+ absmax: Tensor = None,
+ rand=None,
+ out: Tensor = None,
) -> Tensor:
"""
Quantize tensor A in blocks of size 4096 values.
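[Note] A minimal round-trip sketch built only from the parameters visible in this signature; the dequantize_blockwise counterpart and its keyword names are assumptions about this version of the library, not shown in the hunk:

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(16384, device="cuda")
    blocks = (A.numel() + 4095) // 4096                 # one absmax per 4096-value block
    absmax = torch.zeros(blocks, dtype=torch.float32, device=A.device)
    out = torch.zeros_like(A, dtype=torch.uint8)
    F.quantize_blockwise(A, absmax=absmax, out=out)
    A_hat = F.dequantize_blockwise(out, absmax=absmax)  # assumed counterpart API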
@@ -629,7 +659,9 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
"""
if out is None:
out = torch.zeros_like(A, dtype=torch.float32)
- lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
+ lib.cdequantize(
+ get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())
+ )
return out
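[Note] For reference, cdequantize here is a plain 256-entry codebook lookup; a CPU equivalent (illustrative, not from this diff):

    # out[i] = code[A[i]] for a uint8-coded tensor A and a 256-entry float codebook
    out_ref = code[A.long()]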
@@ -1005,7 +1037,9 @@ def histogram_scatter_add_2d(
)
-def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
+def check_matmul(
+ A, B, out, transposed_A, transposed_B, expected_type=torch.int8
+):
if not torch.cuda.is_initialized():
torch.cuda.init()
if A.dtype != expected_type or B.dtype != expected_type:
@@ -1097,7 +1131,11 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
def igemm(
- A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False
+ A: Tensor,
+ B: Tensor,
+ out: Tensor = None,
+ transposed_A=False,
+ transposed_B=False,
):
sout = check_matmul(A, B, out, transposed_A, transposed_B)
if out is None:
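[Note] igemm usage under the check_matmul contract above (shapes are my example; expected_type=torch.int8 comes from check_matmul's signature a few hunks up):

    import torch
    import bitsandbytes.functional as F

    A = torch.randint(-128, 128, (32, 64), dtype=torch.int8, device="cuda")
    B = torch.randint(-128, 128, (64, 16), dtype=torch.int8, device="cuda")
    C = F.igemm(A, B)            # out is allocated as int32 when None
    assert C.shape == (32, 16) and C.dtype == torch.int32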
@@ -1193,7 +1231,11 @@ def igemm(
def batched_igemm(
- A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False
+ A: Tensor,
+ B: Tensor,
+ out: Tensor = None,
+ transposed_A=False,
+ transposed_B=False,
):
if not len(A.shape) == 3 or not len(B.shape) == 3:
raise ValueError(
@@ -1392,9 +1434,13 @@ def mm_dequant(
if out is None:
out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
if new_row_stats is None:
- new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
+ new_row_stats = torch.empty(
+ out_shape[0], dtype=torch.float32, device=A.device
+ )
if new_col_stats is None:
- new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
+ new_col_stats = torch.empty(
+ out_shape[1], dtype=torch.float32, device=A.device
+ )
assert (
new_row_stats.shape[0] == row_stats.shape[0]
), f"{new_row_stats.shape} vs {row_stats.shape}"
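[Note] A CPU reference for what mm_dequant's rescaling computes, per my reading of the symmetric int8 scaling used elsewhere in this file (each input was scaled by 127/absmax, so the int32 product is undone by row_absmax * col_absmax / 127**2); a sketch, not the kernel's exact math:

    def mm_dequant_ref(C_int32, row_stats, col_stats):
        scale = row_stats[:, None] * col_stats[None, :] / (127.0 * 127.0)
        return (C_int32.float() * scale).half()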
@@ -1440,13 +1486,13 @@ def get_colrow_absmax(
col_tiles = (cols + 255) // 256
tiled_rows = ((rows + 15) // 16) * 16
if row_stats is None:
- row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(
- -50000.0
- )
+ row_stats = torch.empty(
+ (rows,), dtype=torch.float32, device=device
+ ).fill_(-50000.0)
if col_stats is None:
- col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(
- -50000.0
- )
+ col_stats = torch.empty(
+ (cols,), dtype=torch.float32, device=device
+ ).fill_(-50000.0)
if nnz_block_ptr is None and threshold > 0.0:
nnz_block_ptr = torch.zeros(
@@ -1462,7 +1508,13 @@ def get_colrow_absmax(
prev_device = pre_call(A.device)
lib.cget_col_row_stats(
- ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols
+ ptrA,
+ ptrRowStats,
+ ptrColStats,
+ ptrNnzrows,
+ ct.c_float(threshold),
+ rows,
+ cols,
)
post_call(prev_device)
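[Note] The -50000.0 fills above are sentinels for the running absmax reduction performed by cget_col_row_stats: any observed |value| is >= 0 and replaces the sentinel on first update. A CPU stand-in for the dense part of that kernel (my illustration, ignoring the threshold/outlier path):

    row_stats_ref = A.float().abs().amax(dim=1)   # per-row absmax
    col_stats_ref = A.float().abs().amax(dim=0)   # per-column absmax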
@@ -1526,7 +1578,9 @@ class CSCSparseTensor(object):
def coo2csr(cooA):
values, counts = torch.unique(cooA.rowidx, return_counts=True)
values.add_(1)
- rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device)
+ rowptr = torch.zeros(
+ (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device
+ )
rowptr.scatter_(index=values.long(), src=counts.int(), dim=0)
rowptr.cumsum_(0)
return CSRSparseTensor(
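[Note] The reflowed rowptr construction is a counting-plus-prefix-sum (counting-sort style) build of the CSR row pointer. A worked example with my own data:

    import torch

    rowidx = torch.tensor([0, 0, 2], dtype=torch.int32)  # 3 nonzeros in a 3-row matrix
    values, counts = torch.unique(rowidx, return_counts=True)
    rowptr = torch.zeros(3 + 1, dtype=torch.int32)
    rowptr.scatter_(index=(values + 1).long(), src=counts.int(), dim=0)
    rowptr.cumsum_(0)
    print(rowptr)                                        # tensor([0, 2, 2, 3], dtype=torch.int32)

coo2csc in the next hunk is the same construction applied to the column indices after sorting the COO entries by column.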
@@ -1540,10 +1594,14 @@ def coo2csc(cooA):
values = cooA.values[col2rowidx]
colvalues, counts = torch.unique(val, return_counts=True)
colvalues.add_(1)
- colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device)
+ colptr = torch.zeros(
+ (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device
+ )
colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0)
colptr.cumsum_(0)
- return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values)
+ return CSCSparseTensor(
+ cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values
+ )
def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
@@ -1568,7 +1626,9 @@ def double_quant(
rows = A.shape[0]
if row_stats is None or col_stats is None:
- row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)
+ row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(
+ A, threshold=threshold
+ )
if out_col is None:
out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
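[Note] A hedged usage sketch for double_quant; the 5-tuple return (row-quantized int8, column-quantized int8, row stats, column stats, outlier COO tensor) is my recollection of this API era and should be verified against the full file:

    import torch
    import bitsandbytes.functional as F

    A_fp16 = torch.randn(32, 64, dtype=torch.float16, device="cuda")
    CA, CAt, row_stats, col_stats, coo = F.double_quant(A_fp16, threshold=6.0)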
@@ -1663,7 +1723,13 @@ def get_special_format_str():
def transform(
- A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None
+ A,
+ to_order,
+ from_order="row",
+ out=None,
+ transpose=False,
+ state=None,
+ ld=None,
):
if state is None:
state = (A.shape, from_order)
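[Note] transform (like nvidia_transform above) converts between memory layouts and threads a (shape, order) state alongside the data; a hedged call, with the 2-tuple return assumed from how state is handled here:

    C_col32, S_col32 = F.transform(CA, to_order="col32")  # CA: row-major int8, e.g. from double_quant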
@@ -1716,7 +1782,9 @@ def transform(
def spmm_coo(cooA, B, out=None):
if out is None:
- out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype)
+ out = torch.empty(
+ (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype
+ )
nnz = cooA.nnz
assert cooA.rowidx.numel() == nnz
assert cooA.colidx.numel() == nnz
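[Note] spmm_coo multiplies a COO sparse matrix by a dense matrix; per the allocation above, the result is (cooA.rows, B.shape[1]) in B's dtype. A hedged usage sketch using coo_zeros from the earlier hunk:

    import torch
    import bitsandbytes.functional as F

    cooA = F.coo_zeros(256, 1024, 16, "cuda")  # signature shown in the coo2csc hunk above
    B = torch.randn(1024, 128, dtype=torch.float16, device="cuda")
    out = F.spmm_coo(cooA, B)                  # (256, 128), dtype of B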
@@ -1982,7 +2050,9 @@ def extract_outliers(A, SA, idx):
assert formatA in ["col_turing", "col_ampere"]
assert A.device.type == "cuda"
- out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
+ out = torch.zeros(
+ (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device
+ )
idx_size = ct.c_int32(idx.numel())
rows = ct.c_int32(shapeA[0])
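[Note] extract_outliers gathers the int8 columns listed in idx out of a tiled (col_turing / col_ampere) matrix back into a dense (rows, idx.numel()) int8 tensor, per the allocation above. A hedged end-to-end sketch; the names and the idx dtype are assumptions carried over from the earlier notes:

    import torch
    import bitsandbytes.functional as F

    C_tiled, S_tiled = F.transform(CA, to_order="col_turing")  # CA: int8 row-major
    idx = torch.tensor([3, 17], device="cuda")                 # outlier column indices
    outliers = F.extract_outliers(C_tiled, S_tiled, idx)       # (rows, idx.numel()), int8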