From bfa0e33294f2b1dc25e65a33be2397f989824298 Mon Sep 17 00:00:00 2001 From: Titus von Koeller Date: Mon, 1 Aug 2022 03:31:48 -0700 Subject: ran black and isort for coherent code formatting --- bitsandbytes/functional.py | 1395 +++++++++++++++++++++++++++++++------------- 1 file changed, 981 insertions(+), 414 deletions(-) (limited to 'bitsandbytes/functional.py') diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index ac85f88..2e86958 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1,6 +1,6 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import ctypes as ct import random @@ -9,47 +9,68 @@ from typing import Tuple import torch from torch import Tensor -from .cextension import lib, COMPILED_WITH_CUDA +from .cextension import COMPILED_WITH_CUDA, lib name2qmap = {} if COMPILED_WITH_CUDA: - ''' C FUNCTIONS FOR OPTIMIZERS ''' + """C FUNCTIONS FOR OPTIMIZERS""" str2optimizer32bit = {} - str2optimizer32bit['adam'] = (lib.cadam32bit_g32, lib.cadam32bit_g16) - str2optimizer32bit['momentum'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16) - str2optimizer32bit['rmsprop'] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16) - str2optimizer32bit['adagrad'] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16) - str2optimizer32bit['lars'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16) - str2optimizer32bit['lamb'] = (lib.cadam32bit_g32, lib.cadam32bit_g16) + str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16) + str2optimizer32bit["momentum"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16) + str2optimizer32bit["rmsprop"] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16) + str2optimizer32bit["adagrad"] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16) + str2optimizer32bit["lars"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16) + str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16) str2optimizer8bit = {} - str2optimizer8bit['adam'] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16) - str2optimizer8bit['momentum'] = (lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16) - str2optimizer8bit['rmsprop'] = (lib.crmsprop_static_8bit_g32, lib.crmsprop_static_8bit_g16) - str2optimizer8bit['lamb'] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16) - str2optimizer8bit['lars'] = (lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16) + str2optimizer8bit["adam"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16) + str2optimizer8bit["momentum"] = ( + lib.cmomentum_static_8bit_g32, + lib.cmomentum_static_8bit_g16, + ) + str2optimizer8bit["rmsprop"] = ( + lib.crmsprop_static_8bit_g32, + lib.crmsprop_static_8bit_g16, + ) + str2optimizer8bit["lamb"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16) + str2optimizer8bit["lars"] = ( + lib.cmomentum_static_8bit_g32, + lib.cmomentum_static_8bit_g16, + ) str2optimizer8bit_blockwise = {} - str2optimizer8bit_blockwise['adam'] = (lib.cadam_8bit_blockwise_fp32, lib.cadam_8bit_blockwise_fp16) - str2optimizer8bit_blockwise['momentum'] = (lib.cmomentum_8bit_blockwise_fp32, lib.cmomentum_8bit_blockwise_fp16) - str2optimizer8bit_blockwise['rmsprop'] = (lib.crmsprop_8bit_blockwise_fp32, lib.crmsprop_8bit_blockwise_fp16) - str2optimizer8bit_blockwise['adagrad'] = (lib.cadagrad_8bit_blockwise_fp32, lib.cadagrad_8bit_blockwise_fp16) + str2optimizer8bit_blockwise["adam"] = ( + lib.cadam_8bit_blockwise_fp32, + lib.cadam_8bit_blockwise_fp16, + ) + str2optimizer8bit_blockwise["momentum"] = ( + lib.cmomentum_8bit_blockwise_fp32, + lib.cmomentum_8bit_blockwise_fp16, + ) + str2optimizer8bit_blockwise["rmsprop"] = ( + lib.crmsprop_8bit_blockwise_fp32, + lib.crmsprop_8bit_blockwise_fp16, + ) + str2optimizer8bit_blockwise["adagrad"] = ( + lib.cadagrad_8bit_blockwise_fp32, + lib.cadagrad_8bit_blockwise_fp16, + ) class CUBLAS_Context(object): _instance = None def __init__(self): - raise RuntimeError('Call get_instance() instead') + raise RuntimeError("Call get_instance() instead") def initialize(self): self.context = {} - #prev_device = torch.cuda.current_device() - #for i in range(torch.cuda.device_count()): + # prev_device = torch.cuda.current_device() + # for i in range(torch.cuda.device_count()): # torch.cuda.set_device(torch.device('cuda', i)) # self.context.append(ct.c_void_p(lib.get_context())) - #torch.cuda.set_device(prev_device) + # torch.cuda.set_device(prev_device) @classmethod def get_instance(cls): @@ -66,11 +87,12 @@ class CUBLAS_Context(object): torch.cuda.set_device(prev_device) return self.context[device.index] + class Cusparse_Context(object): _instance = None def __init__(self): - raise RuntimeError('Call get_instance() instead') + raise RuntimeError("Call get_instance() instead") def initialize(self): self.context = ct.c_void_p(lib.get_cusparse()) @@ -82,14 +104,16 @@ class Cusparse_Context(object): cls._instance.initialize() return cls._instance + def create_linear_map(signed=True): if signed: return torch.linspace(-1.0, 1.0, 256) else: return torch.linspace(0.0, 1.0, 256) + def create_dynamic_map(signed=True, n=7): - ''' + """ Creates the dynamic quantiztion map. The dynamic data type is made up of a dynamic exponent and @@ -103,46 +127,54 @@ def create_dynamic_map(signed=True, n=7): For more details see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561] - ''' + """ data = [] # these are additional items that come from the case # where all the exponent bits are zero and no # indicator bit is present - additional_items = 2**(7-n)-1 - if not signed: additional_items = 2*additional_items + additional_items = 2 ** (7 - n) - 1 + if not signed: + additional_items = 2 * additional_items for i in range(n): - fraction_items = 2**(i+7-n)+1 if signed else 2**(i+7-n+1)+1 + fraction_items = 2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1 boundaries = torch.linspace(0.1, 1, fraction_items) - means = (boundaries[:-1]+boundaries[1:])/2.0 - data += ((10**(-(n-1)+i))*means).tolist() + means = (boundaries[:-1] + boundaries[1:]) / 2.0 + data += ((10 ** (-(n - 1) + i)) * means).tolist() if signed: - data += (-(10**(-(n-1)+i))*means).tolist() + data += (-(10 ** (-(n - 1) + i)) * means).tolist() if additional_items > 0: - boundaries = torch.linspace(0.1, 1, additional_items+1) - means = (boundaries[:-1]+boundaries[1:])/2.0 - data += ((10**(-(n-1)+i))*means).tolist() + boundaries = torch.linspace(0.1, 1, additional_items + 1) + means = (boundaries[:-1] + boundaries[1:]) / 2.0 + data += ((10 ** (-(n - 1) + i)) * means).tolist() if signed: - data += (-(10**(-(n-1)+i))*means).tolist() + data += (-(10 ** (-(n - 1) + i)) * means).tolist() data.append(0) data.append(1.0) data.sort() return Tensor(data) + def get_special_format_str(): major, minor = torch.cuda.get_device_capability() if major < 7: - print(f'Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!') + print( + f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!" + ) assert major >= 7 - if major == 7: return 'col_turing' - elif major == 8: return 'col_ampere' - else: return 'col_turing' + if major == 7: + return "col_turing" + elif major == 8: + return "col_ampere" + else: + return "col_turing" + def get_ptr(A: Tensor) -> ct.c_void_p: - ''' + """ Get the ctypes pointer from a PyTorch Tensor. Parameters @@ -153,31 +185,39 @@ def get_ptr(A: Tensor) -> ct.c_void_p: Returns ------- ctypes.c_void_p - ''' - if A is None: return None - else: return ct.c_void_p(A.data.storage().data_ptr()) + """ + if A is None: + return None + else: + return ct.c_void_p(A.data.storage().data_ptr()) + def pre_call(device): prev_device = torch.cuda.current_device() torch.cuda.set_device(device) return prev_device + def post_call(prev_device): torch.cuda.set_device(prev_device) + def get_transform_func(dtype, orderA, orderOut, transpose=False): name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}' if not hasattr(lib, name): print(name) - raise ValueError(f'Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}') + raise ValueError( + f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}" + ) else: return getattr(lib, name) + class GlobalData(object): _instance = None def __init__(self): - raise RuntimeError('Call get_instance() instead') + raise RuntimeError("Call get_instance() instead") def initialize(self): self.data = {} @@ -190,15 +230,17 @@ class GlobalData(object): return cls._instance -def get_transform_buffer(shape, dtype, device, to_order, from_order='row', transpose=False): - #init_func = torch.empty +def get_transform_buffer( + shape, dtype, device, to_order, from_order="row", transpose=False +): + # init_func = torch.empty init_func = torch.zeros dims = len(shape) if dims == 2: rows = shape[0] elif dims == 3: - rows = shape[0]*shape[1] + rows = shape[0] * shape[1] cols = shape[-1] state = (shape, to_order) @@ -209,30 +251,39 @@ def get_transform_buffer(shape, dtype, device, to_order, from_order='row', trans cols = tmp state = (shape[::-1], to_order) - if to_order == 'row' or to_order == 'col': + if to_order == "row" or to_order == "col": return init_func(shape, dtype=dtype, device=device), state - elif to_order == 'col32': + elif to_order == "col32": # blocks of 32 columns (padded) - cols = 32*((cols+31)//32) + cols = 32 * ((cols + 31) // 32) return init_func((rows, cols), dtype=dtype, device=device), state - elif to_order == 'col_turing': + elif to_order == "col_turing": # blocks of 32 columns and 8 rows - cols = 32*((cols+31)//32) - rows = 8*((rows+7)//8) + cols = 32 * ((cols + 31) // 32) + rows = 8 * ((rows + 7) // 8) return init_func((rows, cols), dtype=dtype, device=device), state - elif to_order == 'col_ampere': + elif to_order == "col_ampere": # blocks of 32 columns and 32 rows - cols = 32*((cols+31)//32) - rows = 32*((rows+31)//32) + cols = 32 * ((cols + 31) // 32) + rows = 32 * ((rows + 31) // 32) return init_func((rows, cols), dtype=dtype, device=device), state else: - raise NotImplementedError(f'To_order not supported: {to_order}') + raise NotImplementedError(f"To_order not supported: {to_order}") + -def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): - if state is None: state = (A.shape, from_order) - else: from_order = state[1] - if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1]) - else: new_state = (state[1], to_order) +def nvidia_transform( + A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None +): + if state is None: + state = (A.shape, from_order) + else: + from_order = state[1] + if out is None: + out, new_state = get_transform_buffer( + state[0], A.dtype, A.device, to_order, state[1] + ) + else: + new_state = (state[1], to_order) func = get_transform_func(A.dtype, from_order, to_order, transpose) shape = state[0] @@ -242,10 +293,10 @@ def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, s elif ld is not None: n = math.prod(shape) dim1 = math.prod([shape[i] for i in ld]) - dim2 = ct.c_int32(n//dim1) + dim2 = ct.c_int32(n // dim1) dim1 = ct.c_int32(dim1) else: - dim1 = ct.c_int32(shape[0]*shape[1]) + dim1 = ct.c_int32(shape[0] * shape[1]) dim2 = ct.c_int32(shape[2]) ptr = CUBLAS_Context.get_instance().get_context(A.device) @@ -253,11 +304,13 @@ def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, s ptrOut = get_ptr(out) func(ptr, get_ptr(A), get_ptr(out), dim1, dim2) - return out, new_state -def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tensor: - ''' + +def estimate_quantiles( + A: Tensor, out: Tensor = None, offset: float = 1 / 512 +) -> Tensor: + """ Estimates 256 equidistant quantiles on the input tensor eCDF. Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles @@ -282,18 +335,26 @@ def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tens ------- torch.Tensor: The 256 quantiles in float32 datatype. - ''' - if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device) + """ + if out is None: + out = torch.zeros((256,), dtype=torch.float32, device=A.device) if A.dtype == torch.float32: - lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) + lib.cestimate_quantiles_fp32( + get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()) + ) elif A.dtype == torch.float16: - lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) + lib.cestimate_quantiles_fp16( + get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()) + ) else: - raise NotImplementedError(f'Not supported data type {A.dtype}') + raise NotImplementedError(f"Not supported data type {A.dtype}") return out -def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=None, out: Tensor=None) -> Tensor: - ''' + +def quantize_blockwise( + A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None +) -> Tensor: + """ Quantize tensor A in blocks of size 4096 values. Quantizes tensor A by dividing it into blocks of 4096 values. @@ -319,51 +380,96 @@ def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=N The 8-bit tensor. tuple(torch.Tensor, torch.Tensor): The quantization state to undo the quantization. - ''' + """ if code is None: - if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device) - code = name2qmap['dynamic'] + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] code = code.to(A.device) if absmax is None: n = A.numel() num_blocks = 4096 - blocks = n//num_blocks + blocks = n // num_blocks blocks += 1 if n % num_blocks > 0 else 0 absmax = torch.zeros((blocks,), device=A.device) - if out is None: out = torch.zeros_like(A, dtype=torch.uint8) + if out is None: + out = torch.zeros_like(A, dtype=torch.uint8) - - if A.device.type != 'cpu': + if A.device.type != "cpu": if rand is not None: assert rand.numel() >= 1024 rand_offset = random.randint(0, 1023) if A.dtype == torch.float32: - lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel())) + lib.cquantize_blockwise_stochastic_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + get_ptr(rand), + ct.c_int32(rand_offset), + ct.c_int(A.numel()), + ) elif A.dtype == torch.float16: - lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel())) + lib.cquantize_blockwise_stochastic_fp16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + get_ptr(rand), + ct.c_int32(rand_offset), + ct.c_int(A.numel()), + ) else: - raise ValueError(f'Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}') + raise ValueError( + f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" + ) else: if A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel())) + lib.cquantize_blockwise_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(A.numel()), + ) elif A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel())) + lib.cquantize_blockwise_fp16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(A.numel()), + ) else: - raise ValueError(f'Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}') + raise ValueError( + f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" + ) else: # cpu assert rand is None - lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel())) + lib.cquantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(A.numel()), + ) return out, (absmax, code) -def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None, - absmax: Tensor=None, code: Tensor=None, out: Tensor=None, - blocksize: int=4096) -> Tensor: - ''' + +def dequantize_blockwise( + A: Tensor, + quant_state: Tuple[Tensor, Tensor] = None, + absmax: Tensor = None, + code: Tensor = None, + out: Tensor = None, + blocksize: int = 4096, +) -> Tensor: + """ Dequantizes blockwise quantized values. Dequantizes the tensor A with maximum absolute values absmax in @@ -374,7 +480,7 @@ def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None, A : torch.Tensor The input 8-bit tensor. quant_state : tuple(torch.Tensor, torch.Tensor) - Tuple of code and absmax values. + Tuple of code and absmax values. absmax : torch.Tensor The absmax values. code : torch.Tensor @@ -387,57 +493,94 @@ def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None, ------- torch.Tensor: Dequantized tensor (default: float32) - ''' + """ assert quant_state is not None or absmax is not None if code is None and quant_state is None: - if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device) - code = name2qmap['dynamic'] + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] code = code.to(A.device) - if out is None: out = torch.zeros_like(A, dtype=torch.float32) - if quant_state is None: quant_state = (absmax, code) + if out is None: + out = torch.zeros_like(A, dtype=torch.float32) + if quant_state is None: + quant_state = (absmax, code) if blocksize not in [2048, 4096]: - raise ValueError(f'The blockwise of {blocksize} is not supported. Supported values: [2048 4096]') + raise ValueError( + f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]" + ) - if A.device.type != 'cpu': + if A.device.type != "cpu": if out.dtype == torch.float32: - lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_fp32( + get_ptr(quant_state[1]), + get_ptr(A), + get_ptr(quant_state[0]), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + ) elif out.dtype == torch.float16: - lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_fp16( + get_ptr(quant_state[1]), + get_ptr(A), + get_ptr(quant_state[0]), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + ) else: - raise ValueError(f'Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}') + raise ValueError( + f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" + ) else: - lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(A.numel())) - + lib.cdequantize_blockwise_cpu_fp32( + get_ptr(quant_state[1]), + get_ptr(A), + get_ptr(quant_state[0]), + get_ptr(out), + ct.c_int(A.numel()), + ) return out -def quantize(A: Tensor, code: Tensor=None, out: Tensor=None) -> Tensor: +def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: if code is None: - if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device) - code = name2qmap['dynamic'] + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] code = code.to(A.device) absmax = torch.abs(A).max() - inp = A/absmax + inp = A / absmax out = quantize_no_absmax(inp, code, out) return out, (absmax, code) -def dequantize(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None, absmax: Tensor=None, code: Tensor=None, out: Tensor=None) -> Tensor: + +def dequantize( + A: Tensor, + quant_state: Tuple[Tensor, Tensor] = None, + absmax: Tensor = None, + code: Tensor = None, + out: Tensor = None, +) -> Tensor: assert quant_state is not None or absmax is not None if code is None and quant_state is None: - if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device) - code = name2qmap['dynamic'] + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] code = code.to(A.device) - if quant_state is None: quant_state = (absmax, code) + if quant_state is None: + quant_state = (absmax, code) out = dequantize_no_absmax(A, quant_state[1], out) - return out*quant_state[0] + return out * quant_state[0] -def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor: - ''' + +def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: + """ Quantizes input tensor to 8-bit. Quantizes the 32-bit input tensor `A` to the 8-bit output tensor @@ -456,13 +599,15 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor: ------- torch.Tensor: Quantized 8-bit tensor. - ''' - if out is None: out = torch.zeros_like(A, dtype=torch.uint8) + """ + if out is None: + out = torch.zeros_like(A, dtype=torch.uint8) lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) return out -def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor: - ''' + +def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: + """ Dequantizes the 8-bit tensor to 32-bit. Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via @@ -481,17 +626,31 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor: ------- torch.Tensor: 32-bit output tensor. - ''' - if out is None: out = torch.zeros_like(A, dtype=torch.float32) + """ + if out is None: + out = torch.zeros_like(A, dtype=torch.float32) lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) return out -def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Tensor, - beta1: float, eps: float, step: int, lr: float, - state2: Tensor=None, beta2: float=0.0, - weight_decay: float=0.0, gnorm_scale: float=1.0, - unorm_vec: Tensor=None, max_unorm: float=0.0, skip_zeros=False) -> None: - ''' + +def optimizer_update_32bit( + optimizer_name: str, + g: Tensor, + p: Tensor, + state1: Tensor, + beta1: float, + eps: float, + step: int, + lr: float, + state2: Tensor = None, + beta2: float = 0.0, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Tensor = None, + max_unorm: float = 0.0, + skip_zeros=False, +) -> None: + """ Performs an inplace optimizer update with one or two optimizer states. Universal optimizer update for 32-bit state and 32/16-bit gradients/weights. @@ -528,33 +687,84 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten The maximum update norm relative to the weight norm. skip_zeros : bool Whether to skip zero-valued gradients or not (default: False). - ''' + """ param_norm = 0.0 if max_unorm > 0.0: param_norm = torch.norm(p.data.float()) if optimizer_name not in str2optimizer32bit: - raise NotImplementedError(f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}') + raise NotImplementedError( + f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}' + ) if g.dtype == torch.float32 and state1.dtype == torch.float32: - str2optimizer32bit[optimizer_name][0](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm), - ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay), - ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel())) + str2optimizer32bit[optimizer_name][0]( + get_ptr(g), + get_ptr(p), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_float(weight_decay), + ct.c_int32(step), + ct.c_float(lr), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) elif g.dtype == torch.float16 and state1.dtype == torch.float32: - str2optimizer32bit[optimizer_name][1](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm), - ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay), - ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel())) + str2optimizer32bit[optimizer_name][1]( + get_ptr(g), + get_ptr(p), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_float(weight_decay), + ct.c_int32(step), + ct.c_float(lr), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) else: - raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}') - -def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Tensor, state2: Tensor, - beta1: float, beta2: float, eps: float, - step: int, lr: float, qmap1: Tensor, qmap2: Tensor, - max1: Tensor, max2: Tensor, new_max1: Tensor, new_max2: Tensor, - weight_decay: float=0.0, gnorm_scale: float=1.0, - unorm_vec: Tensor=None, max_unorm: float=0.0) -> None: - ''' + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" + ) + + +def optimizer_update_8bit( + optimizer_name: str, + g: Tensor, + p: Tensor, + state1: Tensor, + state2: Tensor, + beta1: float, + beta2: float, + eps: float, + step: int, + lr: float, + qmap1: Tensor, + qmap2: Tensor, + max1: Tensor, + max2: Tensor, + new_max1: Tensor, + new_max2: Tensor, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Tensor = None, + max_unorm: float = 0.0, +) -> None: + """ Performs an inplace Adam update. Universal Adam update for 32/8-bit state and 32/16-bit gradients/weights. @@ -602,56 +812,135 @@ def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Ten The tensor for the update norm. max_unorm : float The maximum update norm relative to the weight norm. - ''' + """ param_norm = 0.0 if max_unorm > 0.0: param_norm = torch.norm(p.data.float()) if g.dtype == torch.float32 and state1.dtype == torch.uint8: - str2optimizer8bit[optimizer_name][0](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2), - get_ptr(unorm_vec), ct.c_float(max_unorm), ct.c_float(param_norm), - ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), - ct.c_int32(step), ct.c_float(lr), - get_ptr(qmap1), get_ptr(qmap2), - get_ptr(max1), get_ptr(max2), get_ptr(new_max1), get_ptr(new_max2), - ct.c_float(weight_decay),ct.c_float(gnorm_scale), ct.c_int32(g.numel())) + str2optimizer8bit[optimizer_name][0]( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(max1), + get_ptr(max2), + get_ptr(new_max1), + get_ptr(new_max2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_int32(g.numel()), + ) elif g.dtype == torch.float16 and state1.dtype == torch.uint8: - str2optimizer8bit[optimizer_name][1](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2), - get_ptr(unorm_vec), ct.c_float(max_unorm), ct.c_float(param_norm), - ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), - ct.c_int32(step), ct.c_float(lr), - get_ptr(qmap1), get_ptr(qmap2), - get_ptr(max1), get_ptr(max2), get_ptr(new_max1), get_ptr(new_max2), - ct.c_float(weight_decay),ct.c_float(gnorm_scale), ct.c_int32(g.numel())) + str2optimizer8bit[optimizer_name][1]( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(max1), + get_ptr(max2), + get_ptr(new_max1), + get_ptr(new_max2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_int32(g.numel()), + ) else: - raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}') - - -def optimizer_update_8bit_blockwise(optimizer_name: str, g: Tensor, p: Tensor, state1: Tensor, state2: Tensor, - beta1: float, beta2: float, eps: float, - step: int, lr: float, qmap1: Tensor, qmap2: Tensor, - absmax1: Tensor, absmax2: Tensor, weight_decay: float=0.0, gnorm_scale: float=1.0, - skip_zeros=False) -> None: - + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" + ) + + +def optimizer_update_8bit_blockwise( + optimizer_name: str, + g: Tensor, + p: Tensor, + state1: Tensor, + state2: Tensor, + beta1: float, + beta2: float, + eps: float, + step: int, + lr: float, + qmap1: Tensor, + qmap2: Tensor, + absmax1: Tensor, + absmax2: Tensor, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, +) -> None: if g.dtype == torch.float32 and state1.dtype == torch.uint8: - str2optimizer8bit_blockwise[optimizer_name][0](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2), - ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), - ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2), - get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), ct.c_int32(g.numel())) + str2optimizer8bit_blockwise[optimizer_name][0]( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(absmax1), + get_ptr(absmax2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) elif g.dtype == torch.float16 and state1.dtype == torch.uint8: - str2optimizer8bit_blockwise[optimizer_name][1](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2), - ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), - ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2), - get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), ct.c_int32(g.numel())) + str2optimizer8bit_blockwise[optimizer_name][1]( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(absmax1), + get_ptr(absmax2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) else: - raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}') + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" + ) -def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int=5): +def percentile_clipping( + grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5 +): """Applies percentile clipping grad: torch.Tensor @@ -663,11 +952,21 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: """ if grad.dtype == torch.float32: - lib.cpercentile_clipping_g32(get_ptr(grad), get_ptr(gnorm_vec), ct.c_int32(step), ct.c_int32(grad.numel())) + lib.cpercentile_clipping_g32( + get_ptr(grad), + get_ptr(gnorm_vec), + ct.c_int32(step), + ct.c_int32(grad.numel()), + ) elif grad.dtype == torch.float16: - lib.cpercentile_clipping_g16(get_ptr(grad), get_ptr(gnorm_vec), ct.c_int32(step), ct.c_int32(grad.numel())) + lib.cpercentile_clipping_g16( + get_ptr(grad), + get_ptr(gnorm_vec), + ct.c_int32(step), + ct.c_int32(grad.numel()), + ) else: - raise ValueError(f'Gradient type {grad.dtype} not supported!') + raise ValueError(f"Gradient type {grad.dtype} not supported!") current_gnorm = torch.sqrt(gnorm_vec[step % 100]) vals, idx = torch.sort(gnorm_vec) @@ -675,31 +974,44 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: gnorm_scale = 1.0 if current_gnorm > clip_value: - gnorm_scale = clip_value/current_gnorm + gnorm_scale = clip_value / current_gnorm return current_gnorm, clip_value, gnorm_scale -def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor): +def histogram_scatter_add_2d( + histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor +): assert len(histogram.shape) == 2 assert histogram.dtype == torch.float32 assert source.dtype == torch.float32 assert index1.dtype == torch.int32 assert index2.dtype == torch.int32 - assert histogram.device.type == 'cuda' - assert index1.device.type == 'cuda' - assert index2.device.type == 'cuda' - assert source.device.type == 'cuda' + assert histogram.device.type == "cuda" + assert index1.device.type == "cuda" + assert index2.device.type == "cuda" + assert source.device.type == "cuda" maxdim1 = ct.c_int32(histogram.shape[0]) n = ct.c_int32(index1.numel()) - lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n) + lib.chistogram_scatter_add_2d( + get_ptr(histogram), + get_ptr(index1), + get_ptr(index2), + get_ptr(source), + maxdim1, + n, + ) + def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8): - if not torch.cuda.is_initialized(): torch.cuda.init() + if not torch.cuda.is_initialized(): + torch.cuda.init() if A.dtype != expected_type or B.dtype != expected_type: - raise TypeError(f'Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}') + raise TypeError( + f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}" + ) sA = A.shape sB = B.shape @@ -709,64 +1021,101 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 correct = True if len(sA) == 2 and len(sB) == 2: - if not tA and not tB and A.shape[1] != B.shape[0]: correct = False - elif tA and not tB and A.shape[0] != B.shape[0]: correct = False - elif tA and tB and A.shape[0] != B.shape[1]: correct = False - elif not tA and tB and A.shape[1] != B.shape[1]: correct = False + if not tA and not tB and A.shape[1] != B.shape[0]: + correct = False + elif tA and not tB and A.shape[0] != B.shape[0]: + correct = False + elif tA and tB and A.shape[0] != B.shape[1]: + correct = False + elif not tA and tB and A.shape[1] != B.shape[1]: + correct = False elif len(sA) == 3 and len(sB) == 2: - if not tA and not tB and A.shape[2] != B.shape[0]: correct = False - elif tA and not tB and A.shape[1] != B.shape[0]: correct = False - elif tA and tB and A.shape[1] != B.shape[1]: correct = False - elif not tA and tB and A.shape[2] != B.shape[1]: correct = False + if not tA and not tB and A.shape[2] != B.shape[0]: + correct = False + elif tA and not tB and A.shape[1] != B.shape[0]: + correct = False + elif tA and tB and A.shape[1] != B.shape[1]: + correct = False + elif not tA and tB and A.shape[2] != B.shape[1]: + correct = False elif len(sA) == 3 and len(sB) == 3: - if not tA and not tB and A.shape[2] != B.shape[1]: correct = False - elif tA and not tB and A.shape[1] != B.shape[1]: correct = False - elif tA and tB and A.shape[1] != B.shape[2]: correct = False - elif not tA and tB and A.shape[2] != B.shape[2]: correct = False + if not tA and not tB and A.shape[2] != B.shape[1]: + correct = False + elif tA and not tB and A.shape[1] != B.shape[1]: + correct = False + elif tA and tB and A.shape[1] != B.shape[2]: + correct = False + elif not tA and tB and A.shape[2] != B.shape[2]: + correct = False if out is not None: sout = out.shape # special case common in backprop if not correct and len(sA) == 3 and len(sB) == 3: - if (sout[0] == sA[2] and sout[1] == sB[2] and - sA[0] == sB[0] and sA[1] == sB[1]): + if ( + sout[0] == sA[2] + and sout[1] == sB[2] + and sA[0] == sB[0] + and sA[1] == sB[1] + ): correct = True else: if len(sA) == 2 and len(sB) == 2: - if not tA and not tB: sout = (sA[0], sB[1]) - elif tA and tB: sout = (sA[1], sB[0]) - elif tA and not tB: sout = (sA[1], sB[1]) - elif not tA and tB: sout = (sA[0], sB[0]) + if not tA and not tB: + sout = (sA[0], sB[1]) + elif tA and tB: + sout = (sA[1], sB[0]) + elif tA and not tB: + sout = (sA[1], sB[1]) + elif not tA and tB: + sout = (sA[0], sB[0]) elif len(sA) == 3 and len(sB) == 2: - if not tA and not tB: sout = (sA[0], sA[1], sB[1]) - elif tA and tB: sout = (sA[0], sA[2], sB[0]) - elif tA and not tB: sout = (sA[0], sA[2], sB[1]) - elif not tA and tB: sout = (sA[0], sA[1], sB[0]) + if not tA and not tB: + sout = (sA[0], sA[1], sB[1]) + elif tA and tB: + sout = (sA[0], sA[2], sB[0]) + elif tA and not tB: + sout = (sA[0], sA[2], sB[1]) + elif not tA and tB: + sout = (sA[0], sA[1], sB[0]) elif len(sA) == 3 and len(sB) == 3: - if not tA and not tB: sout = (sA[0], sA[1], sB[2]) - elif tA and tB: sout = (sA[0], sA[2], sB[1]) - elif tA and not tB: sout = (sA[0], sA[2], sB[2]) - elif not tA and tB: sout = (sA[0], sA[1], sB[1]) - + if not tA and not tB: + sout = (sA[0], sA[1], sB[2]) + elif tA and tB: + sout = (sA[0], sA[2], sB[1]) + elif tA and not tB: + sout = (sA[0], sA[2], sB[2]) + elif not tA and tB: + sout = (sA[0], sA[1], sB[1]) if not correct: - raise ValueError(f'Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}.') + raise ValueError( + f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}." + ) return sout -def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed_B=False): + +def igemm( + A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False +): sout = check_matmul(A, B, out, transposed_A, transposed_B) - if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) + if out is None: + out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) if len(A.shape) == 3 and len(B.shape) == 3: if A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1]: return batched_igemm(A, B, out) sA = A.shape sB = B.shape - if transposed_A and len(sA) == 2: sA = (sA[1], sA[0]) - elif transposed_A and len(sA) == 3: sA = (sA[0], sA[2], sA[0]) - if transposed_B and len(sB) == 2: sB = (sB[1], sB[0]) - elif transposed_B and len(sB) == 3: sB = (sB[0], sB[2], sB[0]) + if transposed_A and len(sA) == 2: + sA = (sA[1], sA[0]) + elif transposed_A and len(sA) == 3: + sA = (sA[0], sA[2], sA[0]) + if transposed_B and len(sB) == 2: + sB = (sB[1], sB[0]) + elif transposed_B and len(sB) == 3: + sB = (sB[0], sB[2], sB[0]) # this is a mess: cuBLAS expect column major, but PyTorch is row major. # So to perform the matrix multiplication, we have to treat A, B, and C matrices # (transpose of row major is column major) @@ -777,23 +1126,28 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n] # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m] if len(sB) == 2: - if B.stride()[0] == B.shape[1]: transposed_B = False - elif B.stride()[1] == B.shape[0]: transposed_B = True + if B.stride()[0] == B.shape[1]: + transposed_B = False + elif B.stride()[1] == B.shape[0]: + transposed_B = True if len(A.shape) == 2: - if A.stride()[0] == A.shape[1]: transposed_A = False - elif A.stride()[1] == A.shape[0]: transposed_A = True + if A.stride()[0] == A.shape[1]: + transposed_A = False + elif A.stride()[1] == A.shape[0]: + transposed_A = True else: - if A.stride()[1] == A.shape[2]: transposed_A = False - elif A.stride()[2] == A.shape[1]: transposed_A = True + if A.stride()[1] == A.shape[2]: + transposed_A = False + elif A.stride()[2] == A.shape[1]: + transposed_A = True if len(sA) == 2: n = sA[0] ldb = A.stride()[1 if transposed_A else 0] elif len(sA) == 3 and len(sB) == 2: - n = sA[0]*sA[1] + n = sA[0] * sA[1] ldb = sA[2] - m = sB[1] k = sB[0] lda = B.stride()[(1 if transposed_B else 0)] @@ -802,34 +1156,52 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed # special case assert len(sA) == 3 if not (sA[0] == sB[0] and sA[1] == sB[1]): - raise ValueError(f'Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}') + raise ValueError( + f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}" + ) transposed_A = True transposed_B = False m = sB[2] n = sA[2] - k = sB[0]*sB[1] + k = sB[0] * sB[1] lda = m ldb = sA[2] ldc = m - ptr = CUBLAS_Context.get_instance().get_context(A.device) # B^T @ A^T = C^T - # [km, nk -> mn] - lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), - get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc)) + # [km, nk -> mn] + lib.cigemm( + ptr, + ct.c_bool(transposed_B), + ct.c_bool(transposed_A), + ct.c_int32(m), + ct.c_int32(n), + ct.c_int32(k), + get_ptr(B), + get_ptr(A), + get_ptr(out), + ct.c_int32(lda), + ct.c_int32(ldb), + ct.c_int32(ldc), + ) return out -def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed_B=False): +def batched_igemm( + A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False +): if not len(A.shape) == 3 or not len(B.shape) == 3: - raise ValueError(f'Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}') + raise ValueError( + f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}" + ) sout = check_matmul(A, B, out, transposed_A, transposed_B) - if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) + if out is None: + out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) if B.is_contiguous(): lda = B.stride()[1] @@ -886,17 +1258,33 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr ldc = m - strideA = B.shape[1]*B.shape[2] - strideB = A.shape[1]*A.shape[2] - strideC = A.shape[1]*B.shape[2] + strideA = B.shape[1] * B.shape[2] + strideB = A.shape[1] * A.shape[2] + strideC = A.shape[1] * B.shape[2] ptr = CUBLAS_Context.get_instance().get_context(A.device) - lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), - get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc), - ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch)) + lib.cbatched_igemm( + ptr, + ct.c_bool(transposed_B), + ct.c_bool(transposed_A), + ct.c_int32(m), + ct.c_int32(n), + ct.c_int32(k), + get_ptr(B), + get_ptr(A), + get_ptr(out), + ct.c_int32(lda), + ct.c_int32(ldb), + ct.c_int32(ldc), + ct.c_long(strideA), + ct.c_long(strideB), + ct.c_long(strideC), + ct.c_uint32(num_batch), + ) return out + def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): shapeA = SA[0] shapeB = SB[0] @@ -905,28 +1293,34 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): if dimsA == 2: m = shapeA[0] elif dimsA == 3: - m = shapeA[0]*shapeA[1] + m = shapeA[0] * shapeA[1] if dimsB == 2: rows = n = shapeB[0] elif dimsB == 3: - rows = n = shapeB[0]*shapeB[1] + rows = n = shapeB[0] * shapeB[1] if dimsA == 2 and out is None: - out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, 'col32', 'row') + out, Sout = get_transform_buffer( + (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row" + ) elif dimsA == 3 and out is None: - out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, 'col32', 'row') + out, Sout = get_transform_buffer( + (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row" + ) - assert dimsB != 3, 'len(B.shape)==3 not supported' - assert A.device.type == 'cuda' - assert B.device.type == 'cuda' + assert dimsB != 3, "len(B.shape)==3 not supported" + assert A.device.type == "cuda" + assert B.device.type == "cuda" assert A.dtype == torch.int8 assert B.dtype == torch.int8 assert out.dtype == dtype - assert SA[1] == 'col32' - assert SB[1] in ['col_turing', 'col_ampere'] - assert Sout[1] == 'col32' - assert shapeA[-1] == shapeB[-1], f'Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}' + assert SA[1] == "col32" + assert SB[1] in ["col_turing", "col_ampere"] + assert Sout[1] == "col32" + assert ( + shapeA[-1] == shapeB[-1] + ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}" formatB = SB[1] prev_device = A.device torch.cuda.set_device(A.device) @@ -937,53 +1331,76 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): ptrC = get_ptr(out) k = shapeA[-1] - lda = ct.c_int32(m*32) - if formatB == 'col_turing': + lda = ct.c_int32(m * 32) + if formatB == "col_turing": # turing: tiles with rows filled up to multiple of 8 rows by 32 columns # n = rows - ldb = ct.c_int32(((rows+7)//8)*8*32) + ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32) else: # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns # n = rows - ldb = ct.c_int32(((rows+31)//32)*32*32) + ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32) - ldc = ct.c_int32(m*32) + ldc = ct.c_int32(m * 32) m = ct.c_int32(m) n = ct.c_int32(n) k = ct.c_int32(k) has_error = 0 ptrRowScale = get_ptr(None) - if formatB == 'col_turing': + if formatB == "col_turing": if dtype == torch.int32: - has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + has_error = lib.cigemmlt_turing_32( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) else: - has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) - elif formatB == 'col_ampere': + has_error = lib.cigemmlt_turing_8( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) + elif formatB == "col_ampere": if dtype == torch.int32: - has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + has_error = lib.cigemmlt_ampere_32( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) else: - has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + has_error = lib.cigemmlt_ampere_8( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) if has_error == 1: - raise Exception('cublasLt ran into an error!') + raise Exception("cublasLt ran into an error!") torch.cuda.set_device(prev_device) - return out, Sout -def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None): +def mm_dequant( + A, + quant_state, + row_stats, + col_stats, + out=None, + new_row_stats=None, + new_col_stats=None, +): assert A.dtype == torch.int32 out_shape = quant_state[0] - if len(out_shape) == 3: out_shape = (out_shape[0]*out_shape[1], out_shape[2]) - - if out is None: out = torch.empty(out_shape, dtype=torch.float16, device=A.device) - if new_row_stats is None: new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device) - if new_col_stats is None: new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device) - assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}" - assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}" + if len(out_shape) == 3: + out_shape = (out_shape[0] * out_shape[1], out_shape[2]) + + if out is None: + out = torch.empty(out_shape, dtype=torch.float16, device=A.device) + if new_row_stats is None: + new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device) + if new_col_stats is None: + new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device) + assert ( + new_row_stats.shape[0] == row_stats.shape[0] + ), f"{new_row_stats.shape} vs {row_stats.shape}" + assert ( + new_col_stats.shape[0] == col_stats.shape[0] + ), f"{new_col_stats.shape} vs {col_stats.shape}" ptrA = get_ptr(A) ptrOut = get_ptr(out) @@ -994,27 +1411,47 @@ def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=Non numRows = ct.c_int32(out_shape[0]) numCols = ct.c_int32(out_shape[1]) - lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols) + lib.cdequant_mm_int32_fp16( + ptrA, + ptrRowStats, + ptrColStats, + ptrOut, + ptrNewRowStats, + ptrNewColStats, + numRows, + numCols, + ) return out -def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0): +def get_colrow_absmax( + A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0 +): assert A.dtype == torch.float16 device = A.device cols = A.shape[-1] if len(A.shape) == 3: - rows = A.shape[0]*A.shape[1] + rows = A.shape[0] * A.shape[1] else: rows = A.shape[0] - col_tiles = (cols+255)//256 - tiled_rows = ((rows+15)//16)*16 - if row_stats is None: row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0) - if col_stats is None: col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0) - - if nnz_block_ptr is None and threshold > 0.0: nnz_block_ptr = torch.zeros(((tiled_rows*col_tiles)+1,), dtype=torch.int32, device=device) + col_tiles = (cols + 255) // 256 + tiled_rows = ((rows + 15) // 16) * 16 + if row_stats is None: + row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_( + -50000.0 + ) + if col_stats is None: + col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_( + -50000.0 + ) + + if nnz_block_ptr is None and threshold > 0.0: + nnz_block_ptr = torch.zeros( + ((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device + ) ptrA = get_ptr(A) ptrRowStats = get_ptr(row_stats) @@ -1024,16 +1461,17 @@ def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, thr cols = ct.c_int32(cols) prev_device = pre_call(A.device) - lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols) + lib.cget_col_row_stats( + ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols + ) post_call(prev_device) - if threshold > 0.0: nnz_block_ptr.cumsum_(0) - return row_stats, col_stats, nnz_block_ptr + class COOSparseTensor(object): def __init__(self, rows, cols, nnz, rowidx, colidx, values): assert rowidx.dtype == torch.int32 @@ -1050,6 +1488,7 @@ class COOSparseTensor(object): self.colidx = colidx self.values = values + class CSRSparseTensor(object): def __init__(self, rows, cols, nnz, rowptr, colidx, values): assert rowptr.dtype == torch.int32 @@ -1057,7 +1496,7 @@ class CSRSparseTensor(object): assert values.dtype == torch.float16 assert values.numel() == nnz assert colidx.numel() == nnz - assert rowptr.numel() == rows+1 + assert rowptr.numel() == rows + 1 self.rows = rows self.cols = cols @@ -1066,6 +1505,7 @@ class CSRSparseTensor(object): self.colidx = colidx self.values = values + class CSCSparseTensor(object): def __init__(self, rows, cols, nnz, colptr, rowidx, values): assert colptr.dtype == torch.int32 @@ -1073,7 +1513,7 @@ class CSCSparseTensor(object): assert values.dtype == torch.float16 assert values.numel() == nnz assert rowidx.numel() == nnz - assert colptr.numel() == cols+1 + assert colptr.numel() == cols + 1 self.rows = rows self.cols = cols @@ -1082,13 +1522,17 @@ class CSCSparseTensor(object): self.rowidx = rowidx self.values = values + def coo2csr(cooA): values, counts = torch.unique(cooA.rowidx, return_counts=True) values.add_(1) - rowptr = torch.zeros((cooA.rows+1, ), dtype=torch.int32, device=cooA.rowidx.device) + rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device) rowptr.scatter_(index=values.long(), src=counts.int(), dim=0) rowptr.cumsum_(0) - return CSRSparseTensor(cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values) + return CSRSparseTensor( + cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values + ) + def coo2csc(cooA): val, col2rowidx = torch.sort(cooA.colidx) @@ -1096,11 +1540,12 @@ def coo2csc(cooA): values = cooA.values[col2rowidx] colvalues, counts = torch.unique(val, return_counts=True) colvalues.add_(1) - colptr = torch.zeros((cooA.cols+1, ), dtype=torch.int32, device=cooA.colidx.device) + colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device) colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0) colptr.cumsum_(0) return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values) + def coo_zeros(rows, cols, nnz, device, dtype=torch.half): rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device) colidx = torch.zeros((nnz,), dtype=torch.int32, device=device) @@ -1108,23 +1553,27 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) -def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): +def double_quant( + A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 +): device = A.device assert A.dtype == torch.half - assert device.type == 'cuda' + assert device.type == "cuda" prev_device = pre_call(A.device) cols = A.shape[-1] if len(A.shape) == 3: - rows = A.shape[0]*A.shape[1] + rows = A.shape[0] * A.shape[1] else: rows = A.shape[0] if row_stats is None or col_stats is None: row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold) - if out_col is None: out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) - if out_row is None: out_row = torch.zeros(A.shape, device=device, dtype=torch.int8) + if out_col is None: + out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) + if out_row is None: + out_row = torch.zeros(A.shape, device=device, dtype=torch.int8) coo_tensor = None ptrA = get_ptr(A) @@ -1136,21 +1585,62 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, if threshold > 0.0: nnz = nnz_row_ptr[-1].item() if nnz > 0: - coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device) + coo_tensor = coo_zeros( + A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device + ) ptrRowIdx = get_ptr(coo_tensor.rowidx) ptrColIdx = get_ptr(coo_tensor.colidx) ptrVal = get_ptr(coo_tensor.values) ptrRowPtr = get_ptr(nnz_row_ptr) - lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, ptrRowIdx, ptrColIdx, ptrVal, ptrRowPtr, ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols)) + lib.cdouble_rowcol_quant( + ptrA, + ptrRowStats, + ptrColStats, + ptrOutCol, + ptrOutRow, + ptrRowIdx, + ptrColIdx, + ptrVal, + ptrRowPtr, + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + ) val, idx = torch.sort(coo_tensor.rowidx) coo_tensor.rowidx = val coo_tensor.colidx = coo_tensor.colidx[idx] coo_tensor.values = coo_tensor.values[idx] else: - lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, None, None, None, None, ct.c_float(0.0), ct.c_int32(rows), ct.c_int32(cols)) + lib.cdouble_rowcol_quant( + ptrA, + ptrRowStats, + ptrColStats, + ptrOutCol, + ptrOutRow, + None, + None, + None, + None, + ct.c_float(0.0), + ct.c_int32(rows), + ct.c_int32(cols), + ) else: - lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, None, None, None, None, ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols)) + lib.cdouble_rowcol_quant( + ptrA, + ptrRowStats, + ptrColStats, + ptrOutCol, + ptrOutRow, + None, + None, + None, + None, + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + ) post_call(prev_device) return out_row, out_col, row_stats, col_stats, coo_tensor @@ -1159,69 +1649,81 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, def get_special_format_str(): major, minor = torch.cuda.get_device_capability() if major < 7: - print(f'Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!') + print( + f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!" + ) assert major >= 7 - if major == 7: return 'col_turing' - elif major == 8: return 'col_ampere' - else: return 'col_turing' - - + if major == 7: + return "col_turing" + elif major == 8: + return "col_ampere" + else: + return "col_turing" -def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): - if state is None: state = (A.shape, from_order) - else: from_order = state[1] - if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) - else: new_state = (state[0], to_order) # (shape, order) +def transform( + A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None +): + if state is None: + state = (A.shape, from_order) + else: + from_order = state[1] + if out is None: + out, new_state = get_transform_buffer( + state[0], A.dtype, A.device, to_order, state[1], transpose + ) + else: + new_state = (state[0], to_order) # (shape, order) shape = state[0] if len(shape) == 2: dim1 = ct.c_int32(shape[0]) dim2 = ct.c_int32(shape[1]) else: - dim1 = ct.c_int32(shape[0]*shape[1]) + dim1 = ct.c_int32(shape[0] * shape[1]) dim2 = ct.c_int32(shape[2]) ptrA = get_ptr(A) ptrOut = get_ptr(out) - if to_order == 'col32': + if to_order == "col32": if transpose: lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) else: lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == 'col_turing': + elif to_order == "col_turing": if transpose: lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2) else: lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == 'col_ampere': + elif to_order == "col_ampere": if transpose: lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2) else: lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == 'row': - if from_order == 'col_turing': + elif to_order == "row": + if from_order == "col_turing": lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2) - elif from_order == 'col_ampere': + elif from_order == "col_ampere": lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) else: - raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') - - - + raise NotImplementedError( + f"Transform function not implemented: From {from_order} to {to_order}" + ) return out, new_state + def spmm_coo(cooA, B, out=None): - if out is None: out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype) + if out is None: + out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype) nnz = cooA.nnz assert cooA.rowidx.numel() == nnz assert cooA.colidx.numel() == nnz assert cooA.values.numel() == nnz assert cooA.cols == B.shape[0] - transposed_B = (False if B.is_contiguous() else True) + transposed_B = False if B.is_contiguous() else True ldb = B.stride()[(1 if transposed_B else 0)] ldc = B.shape[1] @@ -1240,19 +1742,37 @@ def spmm_coo(cooA, B, out=None): cldb = ct.c_int32(ldb) cldc = ct.c_int32(ldc) - lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B)) + lib.cspmm_coo( + ptr, + ptrRowidx, + ptrColidx, + ptrValues, + cnnz, + crowsA, + ccolsA, + ccolsB, + cldb, + ptrB, + cldc, + ptrC, + ct.c_bool(transposed_B), + ) return out + def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): - if out is None: out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype) + if out is None: + out = torch.zeros( + (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype + ) nnz = cooA.nnz assert cooA.rowidx.numel() == nnz assert cooA.colidx.numel() == nnz assert cooA.values.numel() == nnz - assert cooA.cols == B.shape[0], f'{cooA.cols} vs {B.shape}' + assert cooA.cols == B.shape[0], f"{cooA.cols} vs {B.shape}" - transposed_B = (False if B.is_contiguous() else True) + transposed_B = False if B.is_contiguous() else True ldb = B.stride()[(1 if transposed_B else 0)] ldc = B.shape[1] @@ -1262,7 +1782,9 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): max_count, max_idx = torch.sort(counts, descending=True) max_idx = max_idx.int() max_count = max_count.int() - assert max_count[0] <= 32, f'Current max count per row is 8 but found {max_count[0]}.' + assert ( + max_count[0] <= 32 + ), f"Current max count per row is 8 but found {max_count[0]}." assert B.dtype in [torch.float16, torch.int8] ptrOffset = get_ptr(offset) ptrMaxCount = get_ptr(max_count) @@ -1282,134 +1804,183 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): ccolsB = ct.c_int32(B.shape[1]) cldb = ct.c_int32(ldb) cldc = ct.c_int32(ldc) - #print(cooA.rowidx[:64]) - #print(cooA.colidx[:64].sort()[0]) + # print(cooA.rowidx[:64]) + # print(cooA.colidx[:64].sort()[0]) if B.dtype == torch.float16: - lib.cspmm_coo_very_sparse_naive_fp16(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB) + lib.cspmm_coo_very_sparse_naive_fp16( + ptrMaxCount, + ptrMaxIdx, + ptrOffset, + ptrRowidx, + ptrColidx, + ptrValues, + ptrB, + ptrC, + ptrDequantStats, + cnnz_rows, + cnnz, + crowsA, + crowsB, + ccolsB, + ) elif B.dtype == torch.int8: - lib.cspmm_coo_very_sparse_naive_int8(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB) - #else: assertion error + lib.cspmm_coo_very_sparse_naive_int8( + ptrMaxCount, + ptrMaxIdx, + ptrOffset, + ptrRowidx, + ptrColidx, + ptrValues, + ptrB, + ptrC, + ptrDequantStats, + cnnz_rows, + cnnz, + crowsA, + crowsB, + ccolsB, + ) + # else: assertion error return out C = 127.0 -def vectorwise_quant(x, dim=1, quant_type='vector'): - if quant_type == 'linear': + +def vectorwise_quant(x, dim=1, quant_type="vector"): + if quant_type == "linear": max1 = torch.abs(x).max().float() - xq = torch.round(x/max1*127).to(torch.int8) + xq = torch.round(x / max1 * 127).to(torch.int8) return xq, max1 - elif quant_type in ['vector', 'row']: + elif quant_type in ["vector", "row"]: max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) - xq = torch.round(x*(C/max1)).to(torch.int8) + xq = torch.round(x * (C / max1)).to(torch.int8) return xq, max1 - elif quant_type == 'zeropoint': + elif quant_type == "zeropoint": dtype = x.dtype x = x.float() dyna = x.max() - x.min() - if dyna == 0: dyna = 1 - qx = 255./dyna + if dyna == 0: + dyna = 1 + qx = 255.0 / dyna minx = x.min() - zpx = torch.round(minx* qx) - x = torch.round(qx*x - zpx) + zpx + zpx = torch.round(minx * qx) + x = torch.round(qx * x - zpx) + zpx return x, qx - elif quant_type in ['vector-zeropoint', 'row-zeropoint']: + elif quant_type in ["vector-zeropoint", "row-zeropoint"]: dtype = x.dtype x = x.float() - dyna = (torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True)) - dyna[dyna==0] = 1 - qx = 255./dyna + dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin( + x, dim=dim, keepdim=True + ) + dyna[dyna == 0] = 1 + qx = 255.0 / dyna minx = torch.amin(x, dim=dim, keepdim=True) - zpx = torch.round(minx* qx) - x = torch.round(qx*x - zpx) + zpx + zpx = torch.round(minx * qx) + x = torch.round(qx * x - zpx) + zpx return x, qx - elif quant_type == 'truncated-vector': + elif quant_type == "truncated-vector": with torch.no_grad(): absx = torch.abs(x) max1 = torch.amax(absx, dim=dim, keepdim=True) - max1 = max1*0.7 - idx = (absx > max1.expand_as(absx)) + max1 = max1 * 0.7 + idx = absx > max1.expand_as(absx) sign = torch.sign(x[idx]) - x[idx] = max1.expand_as(absx)[idx]*sign - xq = torch.round(x/max1*C).to(torch.int8) + x[idx] = max1.expand_as(absx)[idx] * sign + xq = torch.round(x / max1 * C).to(torch.int8) return xq, max1 - else: return None + else: + return None + -def vectorwise_dequant(xq, max1, quant_type='vector'): - if quant_type == 'vector': - x = (xq/C*max1).to(torch.float32) +def vectorwise_dequant(xq, max1, quant_type="vector"): + if quant_type == "vector": + x = (xq / C * max1).to(torch.float32) return x - else: return None + else: + return None -def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type='vector'): - if quant_type == 'linear': - norm = S1*S2/(C*C) + +def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"): + if quant_type == "linear": + norm = S1 * S2 / (C * C) # double cast needed to prevent overflows - return (xq.float()*norm).to(dtype) - elif quant_type == 'zeropoint': - norm = 1.0/(S1*S2) - return (xq.float()*norm).to(dtype) - elif quant_type == 'row-zeropoint': - norm = 1.0/(S1*S2) + return (xq.float() * norm).to(dtype) + elif quant_type == "zeropoint": + norm = 1.0 / (S1 * S2) + return (xq.float() * norm).to(dtype) + elif quant_type == "row-zeropoint": + norm = 1.0 / (S1 * S2) x = xq.float() - if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0) - if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0) + if len(S1.shape) == 3 and len(x.shape) == 2: + S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: + S2 = S2.squeeze(0) if len(S1.shape) == 2: x *= norm else: x *= norm return x.to(dtype) - elif quant_type == 'vector-zeropoint': + elif quant_type == "vector-zeropoint": x = xq.float() - if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0) - if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0) + if len(S1.shape) == 3 and len(x.shape) == 2: + S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: + S2 = S2.squeeze(0) if len(S1.shape) == 2: - x *= 1.0/S1 + x *= 1.0 / S1 else: - x *= 1.0/S1 - x *= 1.0/S2.t() + x *= 1.0 / S1 + x *= 1.0 / S2.t() return x.to(dtype) - elif quant_type == 'row': + elif quant_type == "row": x = xq.float() - if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0) - if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0) + if len(S1.shape) == 3 and len(x.shape) == 2: + S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: + S2 = S2.squeeze(0) if len(S1.shape) == 2: - x *= S1*S2/(C*C) + x *= S1 * S2 / (C * C) else: - x *= S1*S2/(C*C) + x *= S1 * S2 / (C * C) return x.to(dtype) - elif quant_type in ['truncated-vector', 'vector']: + elif quant_type in ["truncated-vector", "vector"]: x = xq.float() - if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0) - if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0) + if len(S1.shape) == 3 and len(x.shape) == 2: + S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: + S2 = S2.squeeze(0) if len(S1.shape) == 2: - x *= S1/C + x *= S1 / C else: - x *= S1/C - x *= S2/C + x *= S1 / C + x *= S2 / C return x.to(dtype) - else: return None + else: + return None def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): - offset = B.float().t().sum(0)*(SA[0]+SA[1]) + offset = B.float().t().sum(0) * (SA[0] + SA[1]) x = xq.float() - if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0) + if len(xq.shape) == 2 and len(SB.shape) == 3: + SB = SB.squeeze(0) if len(SB.shape) == 2: - x *= SB.t()/127 + x *= SB.t() / 127 else: - x *= SB/127 - x *= SA[1]/127 - x +=offset + x *= SB / 127 + x *= SA[1] / 127 + x += offset return x.to(dtype) + def extract_outliers(A, SA, idx): shapeA = SA[0] formatA = SA[1] - assert formatA in ['col_turing', 'col_ampere'] - assert A.device.type == 'cuda' + assert formatA in ["col_turing", "col_ampere"] + assert A.device.type == "cuda" out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device) @@ -1420,13 +1991,9 @@ def extract_outliers(A, SA, idx): ptrIdx = get_ptr(idx) ptrOut = get_ptr(out) - if formatA == 'col_turing': + if formatA == "col_turing": lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) - elif formatA == 'col_ampere': + elif formatA == "col_ampere": lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) return out - - - - -- cgit v1.2.3