From ea7c14f8ef64924f2d0ff80df3cdabf2c7299848 Mon Sep 17 00:00:00 2001
From: Titus von Koeller
Date: Mon, 1 Aug 2022 09:32:47 -0700
Subject: reran black with line length 80 for greater readability

---
 bitsandbytes/functional.py | 128 +++++++++++++++++++++++++++++++++++----------
 1 file changed, 99 insertions(+), 29 deletions(-)

diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 2e86958..236ef39 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -17,14 +17,29 @@ if COMPILED_WITH_CUDA:
     """C FUNCTIONS FOR OPTIMIZERS"""
     str2optimizer32bit = {}
     str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
-    str2optimizer32bit["momentum"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
-    str2optimizer32bit["rmsprop"] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16)
-    str2optimizer32bit["adagrad"] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16)
-    str2optimizer32bit["lars"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
+    str2optimizer32bit["momentum"] = (
+        lib.cmomentum32bit_g32,
+        lib.cmomentum32bit_g16,
+    )
+    str2optimizer32bit["rmsprop"] = (
+        lib.crmsprop32bit_g32,
+        lib.crmsprop32bit_g16,
+    )
+    str2optimizer32bit["adagrad"] = (
+        lib.cadagrad32bit_g32,
+        lib.cadagrad32bit_g16,
+    )
+    str2optimizer32bit["lars"] = (
+        lib.cmomentum32bit_g32,
+        lib.cmomentum32bit_g16,
+    )
     str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)

     str2optimizer8bit = {}
-    str2optimizer8bit["adam"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
+    str2optimizer8bit["adam"] = (
+        lib.cadam_static_8bit_g32,
+        lib.cadam_static_8bit_g16,
+    )
     str2optimizer8bit["momentum"] = (
         lib.cmomentum_static_8bit_g32,
         lib.cmomentum_static_8bit_g16,
@@ -33,7 +48,10 @@ if COMPILED_WITH_CUDA:
         lib.crmsprop_static_8bit_g32,
         lib.crmsprop_static_8bit_g16,
     )
-    str2optimizer8bit["lamb"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
+    str2optimizer8bit["lamb"] = (
+        lib.cadam_static_8bit_g32,
+        lib.cadam_static_8bit_g16,
+    )
     str2optimizer8bit["lars"] = (
         lib.cmomentum_static_8bit_g32,
         lib.cmomentum_static_8bit_g16,
@@ -137,7 +155,9 @@ def create_dynamic_map(signed=True, n=7):
     if not signed:
         additional_items = 2 * additional_items
     for i in range(n):
-        fraction_items = 2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
+        fraction_items = (
+            2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
+        )
         boundaries = torch.linspace(0.1, 1, fraction_items)
         means = (boundaries[:-1] + boundaries[1:]) / 2.0
         data += ((10 ** (-(n - 1) + i)) * means).tolist()
@@ -272,7 +292,13 @@ def get_transform_buffer(


 def nvidia_transform(
-    A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None
+    A,
+    to_order,
+    from_order="row",
+    out=None,
+    transpose=False,
+    state=None,
+    ld=None,
 ):
     if state is None:
         state = (A.shape, from_order)
@@ -352,7 +378,11 @@ def estimate_quantiles(


 def quantize_blockwise(
-    A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None
+    A: Tensor,
+    code: Tensor = None,
+    absmax: Tensor = None,
+    rand=None,
+    out: Tensor = None,
 ) -> Tensor:
     """
     Quantize tensor A in blocks of size 4096 values.
@@ -629,7 +659,9 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
     """
     if out is None:
         out = torch.zeros_like(A, dtype=torch.float32)
-    lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
+    lib.cdequantize(
+        get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())
+    )
     return out


@@ -1005,7 +1037,9 @@ def histogram_scatter_add_2d(
     )


-def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
+def check_matmul(
+    A, B, out, transposed_A, transposed_B, expected_type=torch.int8
+):
     if not torch.cuda.is_initialized():
         torch.cuda.init()
     if A.dtype != expected_type or B.dtype != expected_type:
@@ -1097,7 +1131,11 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8


 def igemm(
-    A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False
+    A: Tensor,
+    B: Tensor,
+    out: Tensor = None,
+    transposed_A=False,
+    transposed_B=False,
 ):
     sout = check_matmul(A, B, out, transposed_A, transposed_B)
     if out is None:
@@ -1193,7 +1231,11 @@ def igemm(


 def batched_igemm(
-    A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False
+    A: Tensor,
+    B: Tensor,
+    out: Tensor = None,
+    transposed_A=False,
+    transposed_B=False,
 ):
     if not len(A.shape) == 3 or not len(B.shape) == 3:
         raise ValueError(
@@ -1392,9 +1434,13 @@ def mm_dequant(
     if out is None:
         out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
     if new_row_stats is None:
-        new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
+        new_row_stats = torch.empty(
+            out_shape[0], dtype=torch.float32, device=A.device
+        )
     if new_col_stats is None:
-        new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
+        new_col_stats = torch.empty(
+            out_shape[1], dtype=torch.float32, device=A.device
+        )
     assert (
         new_row_stats.shape[0] == row_stats.shape[0]
     ), f"{new_row_stats.shape} vs {row_stats.shape}"
@@ -1440,13 +1486,13 @@ def get_colrow_absmax(
     col_tiles = (cols + 255) // 256
     tiled_rows = ((rows + 15) // 16) * 16
     if row_stats is None:
-        row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(
-            -50000.0
-        )
+        row_stats = torch.empty(
+            (rows,), dtype=torch.float32, device=device
+        ).fill_(-50000.0)
     if col_stats is None:
-        col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(
-            -50000.0
-        )
+        col_stats = torch.empty(
+            (cols,), dtype=torch.float32, device=device
+        ).fill_(-50000.0)

     if nnz_block_ptr is None and threshold > 0.0:
         nnz_block_ptr = torch.zeros(
@@ -1462,7 +1508,13 @@ def get_colrow_absmax(

     prev_device = pre_call(A.device)
     lib.cget_col_row_stats(
-        ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols
+        ptrA,
+        ptrRowStats,
+        ptrColStats,
+        ptrNnzrows,
+        ct.c_float(threshold),
+        rows,
+        cols,
     )
     post_call(prev_device)
@@ -1526,7 +1578,9 @@ class CSCSparseTensor(object):
 def coo2csr(cooA):
     values, counts = torch.unique(cooA.rowidx, return_counts=True)
     values.add_(1)
-    rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device)
+    rowptr = torch.zeros(
+        (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device
+    )
     rowptr.scatter_(index=values.long(), src=counts.int(), dim=0)
     rowptr.cumsum_(0)
     return CSRSparseTensor(
@@ -1540,10 +1594,14 @@ def coo2csc(cooA):
     values = cooA.values[col2rowidx]
     colvalues, counts = torch.unique(val, return_counts=True)
     colvalues.add_(1)
-    colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device)
+    colptr = torch.zeros(
+        (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device
+    )
     colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0)
     colptr.cumsum_(0)
-    return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values)
+    return CSCSparseTensor(
+        cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values
+    )


 def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
@@ -1568,7 +1626,9 @@ def double_quant(
     rows = A.shape[0]

     if row_stats is None or col_stats is None:
-        row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)
+        row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(
+            A, threshold=threshold
+        )

     if out_col is None:
         out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
@@ -1663,7 +1723,13 @@ def get_special_format_str():


 def transform(
-    A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None
+    A,
+    to_order,
+    from_order="row",
+    out=None,
+    transpose=False,
+    state=None,
+    ld=None,
 ):
     if state is None:
         state = (A.shape, from_order)
@@ -1716,7 +1782,9 @@ def transform(

 def spmm_coo(cooA, B, out=None):
     if out is None:
-        out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype)
+        out = torch.empty(
+            (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype
+        )
     nnz = cooA.nnz
     assert cooA.rowidx.numel() == nnz
     assert cooA.colidx.numel() == nnz
@@ -1982,7 +2050,9 @@ def extract_outliers(A, SA, idx):
     assert formatA in ["col_turing", "col_ampere"]
     assert A.device.type == "cuda"

-    out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
+    out = torch.zeros(
+        (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device
+    )

     idx_size = ct.c_int32(idx.numel())
     rows = ct.c_int32(shapeA[0])
-- 
cgit v1.2.3
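
Reviewer note: the create_dynamic_map hunk only re-wraps the fraction_items
expression, but once the conditional is split across lines the operator
precedence is easy to misread. The snippet below re-runs that loop body
standalone with the function's defaults (signed=True, n=7); everything
outside the loop is illustrative scaffolding, not code from this patch:

    import torch

    n, signed = 7, True
    data = []
    for i in range(n):
        # Same expression as the reformatted hunk: the conditional picks
        # how many linearly spaced fraction points each exponent level gets.
        fraction_items = (
            2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
        )
        boundaries = torch.linspace(0.1, 1, fraction_items)
        means = (boundaries[:-1] + boundaries[1:]) / 2.0
        # Scale each level by a decreasing power of ten before collecting.
        data += ((10 ** (-(n - 1) + i)) * means).tolist()
    print(len(data))  # 127 map entries from the loop for signed=True, n=7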
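
Reviewer note: in the coo2csr/coo2csc hunks the rowptr construction is now
spread over several lines, which makes the scatter/cumsum trick harder to
follow in diff form. Here is the same offset-building logic isolated in
plain PyTorch; coo_rowptr is a hypothetical helper name used for
illustration only, not a function introduced by this patch:

    import torch

    def coo_rowptr(rowidx: torch.Tensor, rows: int) -> torch.Tensor:
        # torch.unique returns the sorted occupied row indices and how
        # many nonzeros each of those rows holds.
        values, counts = torch.unique(rowidx, return_counts=True)
        # Shift each index by one so its count lands one slot after the
        # row, then a cumulative sum turns counts into CSR row offsets.
        rowptr = torch.zeros(rows + 1, dtype=torch.int32, device=rowidx.device)
        rowptr.scatter_(0, (values + 1).long(), counts.int())
        rowptr.cumsum_(0)
        return rowptr

    # 3x4 matrix with 5 nonzeros in rows [0, 0, 1, 2, 2]:
    rowidx = torch.tensor([0, 0, 1, 2, 2], dtype=torch.int32)
    print(coo_rowptr(rowidx, rows=3))  # tensor([0, 2, 3, 5], dtype=torch.int32)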