From c771b3a75a6ebbfbfc398a028a477246b0799cf0 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Fri, 22 Jul 2022 14:41:05 -0700 Subject: Most tests passing. --- bitsandbytes/__init__.py | 3 +- bitsandbytes/autograd/__init__.py | 0 bitsandbytes/autograd/_functions.py | 307 ++++++ bitsandbytes/cextension.py | 2 + bitsandbytes/functional.py | 869 ++++++++++++++++- bitsandbytes/nn/__init__.py | 2 +- bitsandbytes/nn/modules.py | 124 ++- csrc/kernels.cu | 874 +++++++++++++++++ csrc/kernels.cuh | 12 + csrc/ops.cu | 406 ++++++++ csrc/ops.cuh | 104 +++ csrc/pythonInterface.c | 127 ++- tests/test_autograd.py | 270 ++++++ tests/test_functional.py | 1763 +++++++++++++++++++++++++++++++++-- tests/test_modules.py | 478 +++++++++- tests/test_optim.py | 87 +- 16 files changed, 5269 insertions(+), 159 deletions(-) create mode 100644 bitsandbytes/autograd/__init__.py create mode 100644 bitsandbytes/autograd/_functions.py create mode 100644 tests/test_autograd.py diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 02ca804..3c3affa 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -4,12 +4,13 @@ # LICENSE file in the root directory of this source tree. from .nn import modules +from .autograd._functions import mm_cublas, bmm_cublas, matmul_cublas, matmul, MatmulLtState from .cextension import COMPILED_WITH_CUDA if COMPILED_WITH_CUDA: from .optim import adam -__pdoc__ = {'libBitsNBytes': False, +__pdoc__ = {'libbitsandbytes': False, 'optim.optimizer.Optimizer8bit': False, 'optim.optimizer.MockArgs': False } diff --git a/bitsandbytes/autograd/__init__.py b/bitsandbytes/autograd/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py new file mode 100644 index 0000000..815a4f1 --- /dev/null +++ b/bitsandbytes/autograd/_functions.py @@ -0,0 +1,307 @@ +import torch +import bitsandbytes as bnb +import bitsandbytes.functional as F + +from dataclasses import dataclass + +tensor = torch.Tensor + +''' + This class pools outlier dimensions across layers. + This is particularly important for small models where outlier features + are less systematic and occur with low frequency. +''' +class GlobalOutlierPooler(object): + _instance = None + + def __init__(self): + raise RuntimeError('Call get_instance() instead') + + def initialize(self): + self.outliers = set() + self.model_dim = None + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.initialize() + return cls._instance + + def add_outliers(self, outlier_idx, feature_dim): + if self.model_dim is None: self.model_dim = feature_dim + if feature_dim != self.model_dim: return # we do not encode outliers for the 2nd FFN layer + + self.outliers.update(outlier_idx.tolist()) + + def get_current_outlier_idx(self): + return torch.Tensor(list(self.outliers)).to(torch.int64) + +class MatMul8bit(torch.autograd.Function): + + @staticmethod + def forward(ctx, A, B, out=None, quant_type='vector', precision=[8, 8, 8]): + + if precision[0] != 8: + with torch.no_grad(): + output = torch.matmul(A, B) + else: + if len(B.shape) == 2: dim = 0 + else: dim = 1 + qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type) + qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type) + iout = F.igemm(qA, qB) + output = F.vectorwise_mm_dequant(iout, SA, SB, A.dtype, quant_type) + + if A.requires_grad or B.requires_grad: + ctx.save_for_backward(A, B) + + ctx.quant_type = quant_type + ctx.precision = precision + + return output + + @staticmethod + def backward(ctx, grad_output): + A, B = ctx.saved_tensors + quant_type = ctx.quant_type + precision = ctx.precision + grad_A = grad_B = None + + if B.requires_grad: + if len(A.shape) == 3: + dims = [0, 1] + # bsi -> ibs + permute_dim = [0, 2, 1] + else: + dims = [0] + # bs -> sb + permute_dim = [1, 0] + + if precision[1] != 8: + with torch.no_grad(): + grad_B = torch.matmul(A.permute(permute_dim), grad_output) + else: + if len(B.shape) == 2 and len(A.shape) == 3: + grad_output = grad_output.contiguous() + if not grad_output.is_contiguous(): grad_output.contiguous() + qgrad_output, S1 = F.vectorwise_quant(grad_output.view(-1, grad_output.shape[2]), dim=0, quant_type=quant_type) + if not A.is_contiguous(): A = A.contiguous() + qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type) + igrad_B = F.igemm(qA.t(), qgrad_output) + grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type) + else: + qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type) + qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type) + igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output) + grad_B = F.vectorwise_mm_dequant(igrad_B, S2.permute(permute_dim), S1, grad_output.dtype, quant_type) + + if A.requires_grad: + if len(grad_output.shape) == 3: dims = [2] + else: dims = [1] + + if len(B.shape) == 3: + # bio -> boi + permute_dim = [0, 2, 1] + dim_B = dims + else: + # io -> oi + permute_dim = [1, 0] + dim_B = [1] + + if precision[2] != 8: + with torch.no_grad(): + grad_A = torch.matmul(grad_output, B.permute(permute_dim)) + else: + qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type) + qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type) + igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim)) + grad_A = F.vectorwise_mm_dequant(igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type) + + return grad_A, grad_B, None, None, None + + +mm_cublas = MatMul8bit.apply +bmm_cublas = MatMul8bit.apply +matmul_cublas = MatMul8bit.apply + +@dataclass +class MatmulLtState: + CB = None + CxB = None + SB = None + SCB = None + + CxBt = None + SBt = None + CBt = None + + subB = None + + outlier_pool = None + has_accumulated_gradients = False + threshold = 0.0 + idx = None + is_training = True + has_fp16_weights = True + use_pool = False + formatB = F.get_special_format_str() + + def reset_grads(self): + self.CB = None + self.CxB = None + self.SB = None + self.SCB = None + + self.CxBt = None + self.SBt = None + self.CBt = None + + +class MatMul8bitLt(torch.autograd.Function): + + @staticmethod + def forward(ctx, A, B, out=None, state=MatmulLtState()): + # 1. Quantize A + # 2. Quantize B + # 3. Matmul + # 4. Mixed-precision decomposition matmul + # 5. Save state + requires_gradA = A.requires_grad + requires_gradB = B.requires_grad + formatB = state.formatB + input_shape = A.shape + if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance() + assert A.dtype == torch.float16, f'The input data type needs to be fp16 but {A.dtype} was found!' + + # 1. Quantize A + if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous() + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold) + + if state.threshold > 0.0 and coo_tensorA is not None: + if state.has_fp16_weights: + idx = torch.unique(coo_tensorA.colidx).long() + CA[:, idx] = 0 + CAt[:, idx] = 0 + subA = A[:, idx] + state.subB = B[:, idx].t().contiguous() + state.idx = idx + else: + if state.CxB is None: + # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions + # we also need to convert it to the turing/ampere format + state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None: + # generate outlier index and subB + outlier_idx = torch.unique(coo_tensorA.colidx).long() + state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) + if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: + # do not use pool for 2nd FFN layer + state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) + else: + state.idx = outlier_idx + state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half() + + if state.idx is not None: + # extract outliers + CA[:, state.idx] = 0 + CAt[:, state.idx] = 0 + subA = A[:, state.idx] + else: + subA = None + else: + if not state.has_fp16_weights and state.CxB is None: + state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + subA = None + + C32A, SA = F.transform(CA, 'col32') + + # 2. Quantize B + if state.has_fp16_weights: + has_grad = (True if (getattr(B, 'grad', None) is not None) else False) + is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1) + if is_transposed: B = B.contiguous() + + if (state.is_training and not has_grad) or state.CxB is None: + state.reset_grads() + CB, state.CBt, state.SCB, state.SCBt, coo_tensorB = F.double_quant(B) + state.CxB, state.SB = F.transform(CB, to_order=formatB) + else: + has_grad = False + + shapeB = state.SB[0] + + if len(input_shape) == 3: + output_shape = (input_shape[0], input_shape[1], shapeB[0]) + else: + output_shape = (input_shape[0], shapeB[0]) + + # 3. Matmul + out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) + output = F.mm_dequant(out32, Sout32, SCA, state.SCB) + + # 4. Mixed-precision decomposition matmul + if state.threshold > 0.0 and coo_tensorA is not None and subA is not None: + output += torch.matmul(subA, state.subB) + + # 5. Save state + ctx.state = state + + ctx.formatB = formatB + ctx.grad_shape = input_shape + ctx.req_grads = [requires_gradA, requires_gradB] + + if requires_gradA or requires_gradB: + ctx.tensors = (CAt, subA) + ctx.tensor_states = (SCAt, state.idx) + else: + ctx.tensors = [None, None] + ctx.tensor_states = (None, None) + ctx.save_for_backward(None, None) + + #clone_func = torch.clone if len(output_shape) == 3 else lambda x : x + clone_func = torch.clone + return clone_func(output.view(output_shape)) + + @staticmethod + def backward(ctx, grad_output): + req_gradA, req_gradB = ctx.req_grads + CAt, subA = ctx.tensors + SCAt, idx = ctx.tensor_states + formatB = ctx.formatB + state = ctx.state + assert state.has_fp16_weights, 'Backprop only supported for fp16 weights.' + + if len(grad_output.shape) == 3: + grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous() + + grad_A = grad_B = None + + Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output) + if req_gradB: + CxAt, SAt = F.transform(CAt, formatB, transpose=True) + C32grad, Sgrad = F.transform(Cgradt, 'col32', transpose=True) + gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) + grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) + if state.threshold > 0.0 and subA is not None: + grad_B[:, idx] += torch.matmul(grad_output.t(), subA) + + if req_gradA: + C32grad, Sgrad = F.transform(Cgrad, 'col32') + if state.CxBt is None: + state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True) + gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) + grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) + + return grad_A, grad_B, None, None, None, None, None + + +matmul = MatMul8bitLt.apply + + +def matmul(A : tensor, B : tensor, out : tensor=None, state : MatmulLtState = None, threshold=0.0): + state = state or MatmulLtState() + if threshold > 0.0: + state.threshold = threshold + return MatMul8bitLt.apply(A, B, out, state) + diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 63d627e..2374c35 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -6,6 +6,8 @@ lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so') try: lib.cadam32bit_g32 + lib.get_context.restype = ct.c_void_p + lib.get_cusparse.restype = ct.c_void_p COMPILED_WITH_CUDA = True except AttributeError: warn("The installed version of bitsandbytes was compiled without GPU support. " diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index ab4e565..806c254 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -36,9 +36,51 @@ if COMPILED_WITH_CUDA: str2optimizer8bit_blockwise['rmsprop'] = (lib.crmsprop_8bit_blockwise_fp32, lib.crmsprop_8bit_blockwise_fp16) str2optimizer8bit_blockwise['adagrad'] = (lib.cadagrad_8bit_blockwise_fp32, lib.cadagrad_8bit_blockwise_fp16) -optimal_normal = [-0.9939730167388916, -0.8727636337280273, -0.8097418546676636, -0.7660024166107178, -0.7318882346153259, -0.6793879270553589, -0.657649040222168, -0.6385974884033203, -0.6211113333702087, -0.5901028513908386, -0.5762918591499329, -0.5630806684494019, -0.5509274005889893, -0.5394591689109802, -0.5283197164535522, -0.517780065536499, -0.5074946284294128, -0.4980469048023224, -0.48867011070251465, -0.48003149032592773, -0.47125306725502014, -0.4629971981048584, -0.4547359049320221, -0.446626216173172, -0.43902668356895447, -0.43158355355262756, -0.4244747757911682, -0.4173796474933624, -0.41038978099823, -0.4055633544921875, -0.4035947024822235, -0.39701032638549805, -0.39057496190071106, -0.38439232110977173, -0.3782760500907898, -0.3721940815448761, -0.3661896586418152, -0.3604033589363098, -0.354605108499527, -0.34892538189888, -0.34320303797721863, -0.3376772701740265, -0.3323028087615967, -0.3269782066345215, -0.32166096568107605, -0.316457599401474, -0.3112771809101105, -0.3061025142669678, -0.30106794834136963, -0.2961243987083435, -0.2912728488445282, -0.28644347190856934, -0.28165507316589355, -0.2769731283187866, -0.2722635865211487, -0.26779335737228394, -0.26314786076545715, -0.2586647868156433, -0.2541804611682892, -0.2496625930070877, -0.24527113139629364, -0.24097171425819397, -0.23659978806972504, -0.23218469321727753, -0.22799566388130188, -0.22380566596984863, -0.21965542435646057, -0.2154538631439209, -0.2113603949546814, -0.20735277235507965, -0.20334717631340027, -0.19932441413402557, -0.19530178606510162, -0.19136647880077362, -0.18736697733402252, -0.18337111175060272, -0.17951400578022003, -0.1757056713104248, -0.17182783782482147, -0.1680615097284317, -0.16431649029254913, -0.16053077578544617, -0.15685945749282837, -0.15298527479171753, -0.1493264138698578, -0.14566898345947266, -0.14188314974308014, -0.13819937407970428, -0.1344561129808426, -0.1306886374950409, -0.1271020770072937, -0.12346585839986801, -0.11981867253780365, -0.11614970862865448, -0.11256207525730133, -0.10889036953449249, -0.10525048524141312, -0.1016591489315033, -0.09824034571647644, -0.09469068050384521, -0.0911419615149498, -0.08773849159479141, -0.08416644483804703, -0.08071305602788925, -0.07720902562141418, -0.07371306419372559, -0.07019119709730148, -0.06673648208379745, -0.06329209357500076, -0.059800852090120316, -0.0564190037548542, -0.05296570807695389, -0.049522045999765396, -0.04609023034572601, -0.04262964054942131, -0.039246633648872375, -0.03577171266078949, -0.03236335143446922, -0.028855687007308006, -0.02542758360505104, -0.022069433704018593, -0.018754752352833748, -0.015386369079351425, -0.01194947212934494, -0.008439815603196621, -0.004995611496269703, -0.0016682245768606663, 0.0, 0.0015510577941313386, 0.005062474869191647, 0.008417150937020779, 0.011741090565919876, 0.015184164978563786, 0.018582714721560478, 0.02204744517803192, 0.025471193715929985, 0.02889077737927437, 0.0323684960603714, 0.03579240292310715, 0.039281025528907776, 0.0427563451230526, 0.04619763046503067, 0.04968220740556717, 0.05326594039797783, 0.05679265409708023, 0.060245808213949203, 0.06372645497322083, 0.06721872836351395, 0.0706876739859581, 0.0742349922657013, 0.07774098962545395, 0.08123527467250824, 0.08468879014253616, 0.08810535818338394, 0.09155989438295364, 0.09498448669910431, 0.0985206812620163, 0.10206405073404312, 0.10563778132200241, 0.10921968519687653, 0.11284469068050385, 0.11653254181146622, 0.12008969485759735, 0.12368203699588776, 0.1272617131471634, 0.13089501857757568, 0.134552001953125, 0.1382799744606018, 0.14194637537002563, 0.14563234150409698, 0.14930322766304016, 0.15303383767604828, 0.1567956507205963, 0.16050070524215698, 0.16431072354316711, 0.16813558340072632, 0.17204202711582184, 0.1758781224489212, 0.17973239719867706, 0.1836014688014984, 0.18753431737422943, 0.19138391315937042, 0.19535475969314575, 0.19931404292583466, 0.20333819091320038, 0.20738255977630615, 0.21152682602405548, 0.21568812429904938, 0.21978361904621124, 0.22393859922885895, 0.22814159095287323, 0.23241068422794342, 0.23675410449504852, 0.24123944342136383, 0.24569889903068542, 0.2500703036785126, 0.25904011726379395, 0.26349544525146484, 0.2682226300239563, 0.272907555103302, 0.2774306833744049, 0.28220856189727783, 0.2869136929512024, 0.2916390895843506, 0.29649388790130615, 0.30142995715141296, 0.3065022826194763, 0.3114383816719055, 0.31648796796798706, 0.3216581642627716, 0.32700115442276, 0.3322487473487854, 0.33778008818626404, 0.3431521952152252, 0.3487405776977539, 0.3543166518211365, 0.3601346015930176, 0.36605337262153625, 0.37217751145362854, 0.378179669380188, 0.3843980133533478, 0.3906566798686981, 0.39714935421943665, 0.40357843041419983, 0.4104187488555908, 0.4171563684940338, 0.42418959736824036, 0.43136918544769287, 0.4389212429523468, 0.44673123955726624, 0.45457619428634644, 0.4627031683921814, 0.47130417823791504, 0.4798591434955597, 0.48897242546081543, 0.4979848861694336, 0.5, 0.5076631307601929, 0.5177803635597229, 0.5282770991325378, 0.5392990112304688, 0.5506287813186646, 0.5632893443107605, 0.5764452815055847, 0.5903191566467285, 0.6051878333091736, 0.6209936141967773, 0.6382884979248047, 0.6573970913887024, 0.6795773506164551, 0.7037051916122437, 0.7327037453651428, 0.7677436470985413, 0.8111193776130676, 0.875165581703186, 1.0] -optimal_half_normal = [0.0025565922260284424, 0.005811259150505066, 0.00961565226316452, 0.010822802782058716, 0.013123787939548492, 0.014242202043533325, 0.0143156498670578, 0.016469404101371765, 0.017666727304458618, 0.01773911714553833, 0.0199756920337677, 0.0210941880941391, 0.021161124110221863, 0.02451971173286438, 0.024580076336860657, 0.02685210108757019, 0.028012827038764954, 0.030198264867067337, 0.0302925705909729, 0.03136435151100159, 0.03374280035495758, 0.03487399220466614, 0.035243816673755646, 0.037192340940237045, 0.03822284936904907, 0.04164902865886688, 0.04173608124256134, 0.04401407018303871, 0.04508155584335327, 0.047482021152973175, 0.04756556823849678, 0.050963032990694046, 0.05196474492549896, 0.055417388677597046, 0.05793146416544914, 0.05799369141459465, 0.05887940526008606, 0.05895659327507019, 0.062420234084129333, 0.06493274495005608, 0.06499008461833, 0.06935599446296692, 0.07197384163737297, 0.07201516255736351, 0.07276943325996399, 0.07283210754394531, 0.07550075277686119, 0.07975354790687561, 0.07980883121490479, 0.08257630094885826, 0.0867777168750763, 0.08682405948638916, 0.08967285975813866, 0.09323835000395775, 0.09386616945266724, 0.09735457599163055, 0.09739077091217041, 0.10092401504516602, 0.10444298386573792, 0.10447832942008972, 0.10770941898226738, 0.10803905129432678, 0.11161200702190399, 0.1151546835899353, 0.11520349979400635, 0.11875157058238983, 0.11879390478134155, 0.1222602017223835, 0.122351735830307, 0.12240418791770935, 0.12594850733876228, 0.12597402930259705, 0.12602100148797035, 0.12960633635520935, 0.1296597123146057, 0.12966342642903328, 0.13227657973766327, 0.13325360417366028, 0.1333133578300476, 0.13691483438014984, 0.1371927298605442, 0.14066261053085327, 0.14088113978505135, 0.1447291411459446, 0.14805573225021362, 0.148526418954134, 0.15170684456825256, 0.15178103744983673, 0.15225710347294807, 0.1554398238658905, 0.15609459951519966, 0.15618794038891792, 0.1592724472284317, 0.1629735231399536, 0.16382690146565437, 0.16676269471645355, 0.16873238794505596, 0.17066434025764465, 0.17068277299404144, 0.1717144437134266, 0.17558929696679115, 0.17827065289020538, 0.17835864424705505, 0.18222273886203766, 0.18353315070271492, 0.18604370951652527, 0.18611834943294525, 0.1876586265861988, 0.18996606767177582, 0.19170701876282692, 0.19398853182792664, 0.19786442816257477, 0.19795633852481842, 0.20195159316062927, 0.2058800607919693, 0.2099103182554245, 0.2122517265379429, 0.21410366892814636, 0.21819619834423065, 0.22221362590789795, 0.22233009338378906, 0.22500130906701088, 0.2251257635653019, 0.22638091444969177, 0.23067741096019745, 0.23368822410702705, 0.2348879873752594, 0.2382080741226673, 0.2390350103378296, 0.2391497790813446, 0.24253453686833382, 0.24265171959996223, 0.2470107562839985, 0.24764248728752136, 0.24777774512767792, 0.2516774423420429, 0.256104726344347, 0.2564055472612381, 0.2607169933617115, 0.265461727976799, 0.26985861361026764, 0.2701106257736683, 0.2702729292213917, 0.274574413895607, 0.2750340588390827, 0.27919672429561615, 0.283704474568367, 0.28386808931827545, 0.28953738883137703, 0.2896753139793873, 0.29320384562015533, 0.29451676085591316, 0.295327290892601, 0.29802779853343964, 0.29818175733089447, 0.29972871020436287, 0.30290623009204865, 0.30305664241313934, 0.30486901476979256, 0.31299956142902374, 0.31518544629216194, 0.31790371239185333, 0.3205283172428608, 0.3230419009923935, 0.32595496252179146, 0.32612212374806404, 0.3282426446676254, 0.3283906430006027, 0.33146094158291817, 0.3316439874470234, 0.33365286886692047, 0.33723779395222664, 0.3390095978975296, 0.3427443392574787, 0.34853987768292427, 0.34869300201535225, 0.35457711294293404, 0.35537679493427277, 0.3604113645851612, 0.36124424636363983, 0.3665340431034565, 0.36667295172810555, 0.3727492541074753, 0.3729033060371876, 0.37888188660144806, 0.37907837703824043, 0.3792510814964771, 0.38557394221425056, 0.38573457673192024, 0.39108292758464813, 0.39911722019314766, 0.40589402988553047, 0.40604450181126595, 0.410498782992363, 0.4106704741716385, 0.4129834659397602, 0.4131447561085224, 0.4172855168581009, 0.4202354736626148, 0.4204071946442127, 0.43538858368992805, 0.4355536885559559, 0.4432900734245777, 0.44603554904460907, 0.4461968094110489, 0.451409537345171, 0.4598204083740711, 0.46002377942204475, 0.46178819239139557, 0.46868549659848213, 0.46995367109775543, 0.4868385046720505, 0.48702501133084297, 0.4958047419786453, 0.4960057884454727, 0.5051481872797012, 0.506847757846117, 0.5148334950208664, 0.5150565356016159, 0.5174009390175343, 0.5249751061201096, 0.5283288545906544, 0.5355450958013535, 0.539984006434679, 0.5467876642942429, 0.5522958822548389, 0.5584012717008591, 0.5706631988286972, 0.5836620181798935, 0.5836880058050156, 0.5942088551819324, 0.5975865572690964, 0.6102624125778675, 0.6124880760908127, 0.6286389082670212, 0.646102175116539, 0.6471664495766163, 0.665437325835228, 0.6687244363129139, 0.687017485499382, 0.6932839937508106, 0.7115348428487778, 0.7218200154602528, 0.7219699807465076, 0.7747527211904526, 0.7749756425619125, 0.8192005604505539, 0.8194110840559006, 0.8830635994672775, 0.9217727445065975, 0.9245667457580566, 0.947742685675621, 0.9674464613199234, 0.9890814647078514, 0.9891453236341476, 0.9925699159502983] +class CUBLAS_Context(object): + _instance = None + + def __init__(self): + raise RuntimeError('Call get_instance() instead') + + def initialize(self): + self.context = {} + #prev_device = torch.cuda.current_device() + #for i in range(torch.cuda.device_count()): + # torch.cuda.set_device(torch.device('cuda', i)) + # self.context.append(ct.c_void_p(lib.get_context())) + #torch.cuda.set_device(prev_device) + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.initialize() + return cls._instance + + def get_context(self, device): + if device.index not in self.context: + prev_device = torch.cuda.current_device() + torch.cuda.set_device(device) + self.context[device.index] = ct.c_void_p(lib.get_context()) + torch.cuda.set_device(prev_device) + return self.context[device.index] + +class Cusparse_Context(object): + _instance = None + + def __init__(self): + raise RuntimeError('Call get_instance() instead') + + def initialize(self): + self.context = ct.c_void_p(lib.get_cusparse()) + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.initialize() + return cls._instance def create_linear_map(signed=True): if signed: @@ -89,6 +131,16 @@ def create_dynamic_map(signed=True, n=7): data.sort() return Tensor(data) +def get_special_format_str(): + major, minor = torch.cuda.get_device_capability() + if major < 7: + print(f'Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!') + assert major >= 7 + + if major == 7: return 'col_turing' + elif major == 8: return 'col_ampere' + else: return 'col_turing' + def get_ptr(A: Tensor) -> ct.c_void_p: ''' Get the ctypes pointer from a PyTorch Tensor. @@ -105,6 +157,105 @@ def get_ptr(A: Tensor) -> ct.c_void_p: if A is None: return None else: return ct.c_void_p(A.data.storage().data_ptr()) +def pre_call(device): + prev_device = torch.cuda.current_device() + torch.cuda.set_device(device) + return prev_device + +def post_call(prev_device): + torch.cuda.set_device(prev_device) + +def get_transform_func(dtype, orderA, orderOut, transpose=False): + name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}' + if not hasattr(lib, name): + print(name) + raise ValueError(f'Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}') + else: + return getattr(lib, name) + +class GlobalData(object): + _instance = None + + def __init__(self): + raise RuntimeError('Call get_instance() instead') + + def initialize(self): + self.data = {} + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.initialize() + return cls._instance + + +def get_transform_buffer(shape, dtype, device, to_order, from_order='row', transpose=False): + #init_func = torch.empty + init_func = torch.zeros + dims = len(shape) + + if dims == 2: + rows = shape[0] + elif dims == 3: + rows = shape[0]*shape[1] + cols = shape[-1] + + state = (shape, to_order) + if transpose: + # swap dims + tmp = rows + rows = cols + cols = tmp + state = (shape[::-1], to_order) + + if to_order == 'row' or to_order == 'col': + return init_func(shape, dtype=dtype, device=device), state + elif to_order == 'col32': + # blocks of 32 columns (padded) + cols = 32*((cols+31)//32) + return init_func((rows, cols), dtype=dtype, device=device), state + elif to_order == 'col_turing': + # blocks of 32 columns and 8 rows + cols = 32*((cols+31)//32) + rows = 8*((rows+7)//8) + return init_func((rows, cols), dtype=dtype, device=device), state + elif to_order == 'col_ampere': + # blocks of 32 columns and 32 rows + cols = 32*((cols+31)//32) + rows = 32*((rows+31)//32) + return init_func((rows, cols), dtype=dtype, device=device), state + else: + raise NotImplementedError(f'To_order not supported: {to_order}') + +def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): + if state is None: state = (A.shape, from_order) + else: from_order = state[1] + if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1]) + else: new_state = (state[1], to_order) + func = get_transform_func(A.dtype, from_order, to_order, transpose) + + shape = state[0] + if len(shape) == 2: + dim1 = ct.c_int32(shape[0]) + dim2 = ct.c_int32(shape[1]) + elif ld is not None: + n = math.prod(shape) + dim1 = math.prod([shape[i] for i in ld]) + dim2 = ct.c_int32(n//dim1) + dim1 = ct.c_int32(dim1) + else: + dim1 = ct.c_int32(shape[0]*shape[1]) + dim2 = ct.c_int32(shape[2]) + + ptr = CUBLAS_Context.get_instance().get_context(A.device) + ptrA = get_ptr(A) + ptrOut = get_ptr(out) + func(ptr, get_ptr(A), get_ptr(out), dim1, dim2) + + + return out, new_state + def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tensor: ''' Estimates 256 equidistant quantiles on the input tensor eCDF. @@ -544,3 +695,717 @@ def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, maxdim1 = ct.c_int32(histogram.shape[0]) n = ct.c_int32(index1.numel()) lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n) + +def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8): + if not torch.cuda.is_initialized(): torch.cuda.init() + if A.dtype != expected_type or B.dtype != expected_type: + raise TypeError(f'Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}') + + sA = A.shape + sB = B.shape + tA = transposed_A + tB = transposed_B + + correct = True + + if len(sA) == 2 and len(sB) == 2: + if not tA and not tB and A.shape[1] != B.shape[0]: correct = False + elif tA and not tB and A.shape[0] != B.shape[0]: correct = False + elif tA and tB and A.shape[0] != B.shape[1]: correct = False + elif not tA and tB and A.shape[1] != B.shape[1]: correct = False + elif len(sA) == 3 and len(sB) == 2: + if not tA and not tB and A.shape[2] != B.shape[0]: correct = False + elif tA and not tB and A.shape[1] != B.shape[0]: correct = False + elif tA and tB and A.shape[1] != B.shape[1]: correct = False + elif not tA and tB and A.shape[2] != B.shape[1]: correct = False + elif len(sA) == 3 and len(sB) == 3: + if not tA and not tB and A.shape[2] != B.shape[1]: correct = False + elif tA and not tB and A.shape[1] != B.shape[1]: correct = False + elif tA and tB and A.shape[1] != B.shape[2]: correct = False + elif not tA and tB and A.shape[2] != B.shape[2]: correct = False + + if out is not None: + sout = out.shape + # special case common in backprop + if not correct and len(sA) == 3 and len(sB) == 3: + if (sout[0] == sA[2] and sout[1] == sB[2] and + sA[0] == sB[0] and sA[1] == sB[1]): + correct = True + else: + if len(sA) == 2 and len(sB) == 2: + if not tA and not tB: sout = (sA[0], sB[1]) + elif tA and tB: sout = (sA[1], sB[0]) + elif tA and not tB: sout = (sA[1], sB[1]) + elif not tA and tB: sout = (sA[0], sB[0]) + elif len(sA) == 3 and len(sB) == 2: + if not tA and not tB: sout = (sA[0], sA[1], sB[1]) + elif tA and tB: sout = (sA[0], sA[2], sB[0]) + elif tA and not tB: sout = (sA[0], sA[2], sB[1]) + elif not tA and tB: sout = (sA[0], sA[1], sB[0]) + elif len(sA) == 3 and len(sB) == 3: + if not tA and not tB: sout = (sA[0], sA[1], sB[2]) + elif tA and tB: sout = (sA[0], sA[2], sB[1]) + elif tA and not tB: sout = (sA[0], sA[2], sB[2]) + elif not tA and tB: sout = (sA[0], sA[1], sB[1]) + + + if not correct: + raise ValueError(f'Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}.') + + return sout + +def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed_B=False): + sout = check_matmul(A, B, out, transposed_A, transposed_B) + if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) + if len(A.shape) == 3 and len(B.shape) == 3: + if A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1]: + return batched_igemm(A, B, out) + + sA = A.shape + sB = B.shape + if transposed_A and len(sA) == 2: sA = (sA[1], sA[0]) + elif transposed_A and len(sA) == 3: sA = (sA[0], sA[2], sA[0]) + if transposed_B and len(sB) == 2: sB = (sB[1], sB[0]) + elif transposed_B and len(sB) == 3: sB = (sB[0], sB[2], sB[0]) + # this is a mess: cuBLAS expect column major, but PyTorch is row major. + # So to perform the matrix multiplication, we have to treat A, B, and C matrices + # (transpose of row major is column major) + # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these + + # matrices in the input arguments for cuBLAS + # column major: A @ B = C: [m, k] @ [k, n] = [m, n] + # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n] + # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m] + if len(sB) == 2: + if B.stride()[0] == B.shape[1]: transposed_B = False + elif B.stride()[1] == B.shape[0]: transposed_B = True + if len(A.shape) == 2: + if A.stride()[0] == A.shape[1]: transposed_A = False + elif A.stride()[1] == A.shape[0]: transposed_A = True + else: + if A.stride()[1] == A.shape[2]: transposed_A = False + elif A.stride()[2] == A.shape[1]: transposed_A = True + + if len(sA) == 2: + n = sA[0] + ldb = A.stride()[1 if transposed_A else 0] + elif len(sA) == 3 and len(sB) == 2: + n = sA[0]*sA[1] + ldb = sA[2] + + + m = sB[1] + k = sB[0] + lda = B.stride()[(1 if transposed_B else 0)] + ldc = sB[1] + elif len(sB) == 3: + # special case + assert len(sA) == 3 + if not (sA[0] == sB[0] and sA[1] == sB[1]): + raise ValueError(f'Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}') + + transposed_A = True + transposed_B = False + + m = sB[2] + n = sA[2] + k = sB[0]*sB[1] + + lda = m + ldb = sA[2] + ldc = m + + + ptr = CUBLAS_Context.get_instance().get_context(A.device) + + # B^T @ A^T = C^T + # [km, nk -> mn] + lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), + get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc)) + return out + + +def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed_B=False): + if not len(A.shape) == 3 or not len(B.shape) == 3: + raise ValueError(f'Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}') + sout = check_matmul(A, B, out, transposed_A, transposed_B) + if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) + + if B.is_contiguous(): + lda = B.stride()[1] + transposed_A = False + else: + s = B.stride() + if s[0] != B.shape[0]: + B = B.contiguous() + lda = B.stride()[1] + elif s[2] == B.shape[1]: + transposed_A = True + lda = B.stride()[2] + else: + if s[2] == 1: + B = B.contiguous() + lda = B.stride()[1] + elif s[1] == 1: + B = B.contiguous() + lda = B.stride()[1] + else: + B = B.contiguous() + lda = B.stride()[1] + + if A.is_contiguous(): + ldb = A.stride()[1] + transposed_B = False + else: + s = A.stride() + if s[0] != A.shape[0]: + A = A.contiguous() + ldb = A.stride()[1] + transposed_B = False + elif s[2] == A.shape[1]: + ldb = A.stride()[2] + transposed_B = True + else: + A = A.contiguous() + ldb = A.stride()[1] + transposed_B = False + + # this is a mess: cuBLAS expect column major, but PyTorch is row major. + # So to perform the matrix multiplication, we have to treat A, B, and C matrices + # (transpose of row major is column major) + # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these + # matrices in the input arguments for cuBLAS + + # column major: A @ B = C: [batch, m, k] @ [batch, k, n] = [batch, m, n] + # row major: B^T @ A^T = C^T: [batch, m, k] @ [batch, k, n] = [batch, m, n] + # column major with row major layout: B^T @ A^T = C^T: [batch, k, m] @ [batch, n, k] = [batch, n, m] + num_batch = A.shape[0] + n = A.shape[1] + m = B.shape[2] + k = B.shape[1] + + ldc = m + + strideA = B.shape[1]*B.shape[2] + strideB = A.shape[1]*A.shape[2] + strideC = A.shape[1]*B.shape[2] + + ptr = CUBLAS_Context.get_instance().get_context(A.device) + + lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), + get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc), + ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch)) + return out + +def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32): + shapeA = SA[0] + shapeB = SB[0] + dimsA = len(shapeA) + dimsB = len(shapeB) + if dimsA == 2: + m = shapeA[0] + elif dimsA == 3: + m = shapeA[0]*shapeA[1] + + if dimsB == 2: + rows = n = shapeB[0] + elif dimsB == 3: + rows = n = shapeB[0]*shapeB[1] + + if dimsA == 2 and out is None: + out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, 'col32', 'row') + elif dimsA == 3 and out is None: + out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, 'col32', 'row') + + if row_scale is not None: assert row_scale.numel() == out.shape[0] + assert dimsB != 3, 'len(B.shape)==3 not supported' + assert A.device.type == 'cuda' + assert B.device.type == 'cuda' + assert A.dtype == torch.int8 + assert B.dtype == torch.int8 + assert out.dtype == dtype + assert SA[1] == 'col32' + assert SB[1] in ['col_turing', 'col_ampere'] + assert Sout[1] == 'col32' + assert shapeA[-1] == shapeB[-1], f'Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}' + formatB = SB[1] + prev_device = A.device + torch.cuda.set_device(A.device) + + ptr = CUBLAS_Context.get_instance().get_context(A.device) + ptrA = get_ptr(A) + ptrB = get_ptr(B) + ptrC = get_ptr(out) + ptrRowScale = get_ptr(row_scale) + + k = shapeA[-1] + lda = ct.c_int32(m*32) + if formatB == 'col_turing': + # turing: tiles with rows filled up to multiple of 8 rows by 32 columns + # n = rows + ldb = ct.c_int32(((rows+7)//8)*8*32) + else: + # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns + # n = rows + ldb = ct.c_int32(((rows+31)//32)*32*32) + + ldc = ct.c_int32(m*32) + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + + has_error = 0 + if formatB == 'col_turing': + if dtype == torch.int32: + has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + elif row_scale is None: + has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + else: + has_error = lib.cigemmlt_turing_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + elif formatB == 'col_ampere': + if dtype == torch.int32: + has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + elif row_scale is None: + has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + else: + has_error = lib.cigemmlt_ampere_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + + if has_error == 1: + raise Exception('cublasLt ran into an error!') + + torch.cuda.set_device(prev_device) + + + return out, Sout + + +def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None): + assert A.dtype == torch.int32 + out_shape = quant_state[0] + if len(out_shape) == 3: out_shape = (out_shape[0]*out_shape[1], out_shape[2]) + + if out is None: out = torch.empty(out_shape, dtype=torch.float16, device=A.device) + if new_row_stats is None: new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device) + if new_col_stats is None: new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device) + assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}" + assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}" + + ptrA = get_ptr(A) + ptrOut = get_ptr(out) + ptrRowStats = get_ptr(row_stats) + ptrColStats = get_ptr(col_stats) + ptrNewRowStats = get_ptr(new_row_stats) + ptrNewColStats = get_ptr(new_col_stats) + numRows = ct.c_int32(out_shape[0]) + numCols = ct.c_int32(out_shape[1]) + + lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols) + + return out + + +def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0): + assert A.dtype == torch.float16 + device = A.device + + cols = A.shape[-1] + if len(A.shape) == 3: + rows = A.shape[0]*A.shape[1] + else: + rows = A.shape[0] + + col_tiles = (cols+255)//256 + tiled_rows = ((rows+15)//16)*16 + if row_stats is None: row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0) + if col_stats is None: col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0) + + if nnz_block_ptr is None and threshold > 0.0: nnz_block_ptr = torch.zeros(((tiled_rows*col_tiles)+1,), dtype=torch.int32, device=device) + + ptrA = get_ptr(A) + ptrRowStats = get_ptr(row_stats) + ptrColStats = get_ptr(col_stats) + ptrNnzrows = get_ptr(nnz_block_ptr) + rows = ct.c_int32(rows) + cols = ct.c_int32(cols) + + prev_device = pre_call(A.device) + lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols) + post_call(prev_device) + + + if threshold > 0.0: + nnz_block_ptr.cumsum_(0) + + + return row_stats, col_stats, nnz_block_ptr + +class COOSparseTensor(object): + def __init__(self, rows, cols, nnz, rowidx, colidx, values): + assert rowidx.dtype == torch.int32 + assert colidx.dtype == torch.int32 + assert values.dtype == torch.float16 + assert values.numel() == nnz + assert rowidx.numel() == nnz + assert colidx.numel() == nnz + + self.rows = rows + self.cols = cols + self.nnz = nnz + self.rowidx = rowidx + self.colidx = colidx + self.values = values + +class CSRSparseTensor(object): + def __init__(self, rows, cols, nnz, rowptr, colidx, values): + assert rowptr.dtype == torch.int32 + assert colidx.dtype == torch.int32 + assert values.dtype == torch.float16 + assert values.numel() == nnz + assert colidx.numel() == nnz + assert rowptr.numel() == rows+1 + + self.rows = rows + self.cols = cols + self.nnz = nnz + self.rowptr = rowptr + self.colidx = colidx + self.values = values + +class CSCSparseTensor(object): + def __init__(self, rows, cols, nnz, colptr, rowidx, values): + assert colptr.dtype == torch.int32 + assert rowidx.dtype == torch.int32 + assert values.dtype == torch.float16 + assert values.numel() == nnz + assert rowidx.numel() == nnz + assert colptr.numel() == cols+1 + + self.rows = rows + self.cols = cols + self.nnz = nnz + self.colptr = colptr + self.rowidx = rowidx + self.values = values + +def coo2csr(cooA): + values, counts = torch.unique(cooA.rowidx, return_counts=True) + values.add_(1) + rowptr = torch.zeros((cooA.rows+1, ), dtype=torch.int32, device=cooA.rowidx.device) + rowptr.scatter_(index=values.long(), src=counts.int(), dim=0) + rowptr.cumsum_(0) + return CSRSparseTensor(cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values) + +def coo2csc(cooA): + val, col2rowidx = torch.sort(cooA.colidx) + rowidx = cooA.rowidx[col2rowidx] + values = cooA.values[col2rowidx] + colvalues, counts = torch.unique(val, return_counts=True) + colvalues.add_(1) + colptr = torch.zeros((cooA.cols+1, ), dtype=torch.int32, device=cooA.colidx.device) + colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0) + colptr.cumsum_(0) + return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values) + +def coo_zeros(rows, cols, nnz, device, dtype=torch.half): + rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device) + colidx = torch.zeros((nnz,), dtype=torch.int32, device=device) + values = torch.zeros((nnz,), dtype=dtype, device=device) + return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) + + +def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): + device = A.device + assert A.dtype == torch.half + assert device.type == 'cuda' + prev_device = pre_call(A.device) + + cols = A.shape[-1] + if len(A.shape) == 3: + rows = A.shape[0]*A.shape[1] + else: + rows = A.shape[0] + + if row_stats is None or col_stats is None: + row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold) + + if out_col is None: out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) + if out_row is None: out_row = torch.zeros(A.shape, device=device, dtype=torch.int8) + + coo_tensor = None + ptrA = get_ptr(A) + ptrColStats = get_ptr(col_stats) + ptrRowStats = get_ptr(row_stats) + ptrOutCol = get_ptr(out_col) + ptrOutRow = get_ptr(out_row) + + if threshold > 0.0: + nnz = nnz_row_ptr[-1].item() + if nnz > 0: + coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device) + ptrRowIdx = get_ptr(coo_tensor.rowidx) + ptrColIdx = get_ptr(coo_tensor.colidx) + ptrVal = get_ptr(coo_tensor.values) + ptrRowPtr = get_ptr(nnz_row_ptr) + + lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, ptrRowIdx, ptrColIdx, ptrVal, ptrRowPtr, ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols)) + val, idx = torch.sort(coo_tensor.rowidx) + coo_tensor.rowidx = val + coo_tensor.colidx = coo_tensor.colidx[idx] + coo_tensor.values = coo_tensor.values[idx] + else: + lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, None, None, None, None, ct.c_float(0.0), ct.c_int32(rows), ct.c_int32(cols)) + else: + lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, None, None, None, None, ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols)) + post_call(prev_device) + + return out_row, out_col, row_stats, col_stats, coo_tensor + + +def get_special_format_str(): + major, minor = torch.cuda.get_device_capability() + if major < 7: + print(f'Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!') + assert major >= 7 + + if major == 7: return 'col_turing' + elif major == 8: return 'col_ampere' + else: return 'col_turing' + + + + +def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): + if state is None: state = (A.shape, from_order) + else: from_order = state[1] + if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) + else: new_state = (state[0], to_order) # (shape, order) + + shape = state[0] + if len(shape) == 2: + dim1 = ct.c_int32(shape[0]) + dim2 = ct.c_int32(shape[1]) + else: + dim1 = ct.c_int32(shape[0]*shape[1]) + dim2 = ct.c_int32(shape[2]) + + ptrA = get_ptr(A) + ptrOut = get_ptr(out) + if to_order == 'col32': + if transpose: + lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) + else: + lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2) + elif to_order == 'col_turing': + if transpose: + lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2) + else: + lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2) + elif to_order == 'col_ampere': + if transpose: + lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2) + else: + lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2) + elif to_order == 'row': + if from_order == 'col_turing': + lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2) + elif from_order == 'col_ampere': + lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) + else: + raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') + + + + + return out, new_state + +def spmm_coo(cooA, B, out=None): + if out is None: out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype) + nnz = cooA.nnz + assert cooA.rowidx.numel() == nnz + assert cooA.colidx.numel() == nnz + assert cooA.values.numel() == nnz + assert cooA.cols == B.shape[0] + + transposed_B = (False if B.is_contiguous() else True) + + ldb = B.stride()[(1 if transposed_B else 0)] + ldc = B.shape[1] + + ptr = Cusparse_Context.get_instance().context + + ptrRowidx = get_ptr(cooA.rowidx) + ptrColidx = get_ptr(cooA.colidx) + ptrValues = get_ptr(cooA.values) + ptrB = get_ptr(B) + ptrC = get_ptr(out) + cnnz = ct.c_int32(cooA.nnz) + crowsA = ct.c_int32(cooA.rows) + ccolsA = ct.c_int32(cooA.cols) + ccolsB = ct.c_int32(B.shape[1]) + cldb = ct.c_int32(ldb) + cldc = ct.c_int32(ldc) + + lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B)) + + return out + +def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): + if out is None: out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype) + nnz = cooA.nnz + assert cooA.rowidx.numel() == nnz + assert cooA.colidx.numel() == nnz + assert cooA.values.numel() == nnz + assert cooA.cols == B.shape[0], f'{cooA.cols} vs {B.shape}' + + transposed_B = (False if B.is_contiguous() else True) + + ldb = B.stride()[(1 if transposed_B else 0)] + ldc = B.shape[1] + + values, counts = torch.unique(cooA.rowidx, return_counts=True) + offset = counts.cumsum(0).int() + max_count, max_idx = torch.sort(counts, descending=True) + max_idx = max_idx.int() + max_count = max_count.int() + assert max_count[0] <= 32, f'Current max count per row is 8 but found {max_count[0]}.' + assert B.dtype in [torch.float16, torch.int8] + ptrOffset = get_ptr(offset) + ptrMaxCount = get_ptr(max_count) + ptrMaxIdx = get_ptr(max_idx) + + ptrRowidx = get_ptr(cooA.rowidx) + ptrColidx = get_ptr(cooA.colidx) + ptrValues = get_ptr(cooA.values) + ptrB = get_ptr(B) + ptrC = get_ptr(out) + ptrDequantStats = get_ptr(dequant_stats) + cnnz_rows = ct.c_int32(counts.numel()) + cnnz = ct.c_int32(cooA.nnz) + crowsA = ct.c_int32(cooA.rows) + ccolsA = ct.c_int32(cooA.cols) + crowsB = ct.c_int32(B.shape[1]) + ccolsB = ct.c_int32(B.shape[1]) + cldb = ct.c_int32(ldb) + cldc = ct.c_int32(ldc) + #print(cooA.rowidx[:64]) + #print(cooA.colidx[:64].sort()[0]) + + if B.dtype == torch.float16: + lib.cspmm_coo_very_sparse_naive_fp16(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB) + elif B.dtype == torch.int8: + lib.cspmm_coo_very_sparse_naive_int8(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB) + #else: assertion error + + return out + + +C = 127.0 + +def vectorwise_quant(x, dim=1, quant_type='vector'): + if quant_type == 'linear': + max1 = torch.abs(x).max().float() + xq = torch.round(x/max1*127).to(torch.int8) + return xq, max1 + elif quant_type in ['vector', 'row']: + max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) + xq = torch.round(x*(C/max1)).to(torch.int8) + return xq, max1 + elif quant_type == 'zeropoint': + dtype = x.dtype + x = x.float() + dyna = x.max() - x.min() + if dyna == 0: dyna = 1 + qx = 255./dyna + minx = x.min() + zpx = torch.round(minx* qx) + x = torch.round(qx*x - zpx) + zpx + return x, qx + elif quant_type in ['vector-zeropoint', 'row-zeropoint']: + dtype = x.dtype + x = x.float() + dyna = (torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True)) + dyna[dyna==0] = 1 + qx = 255./dyna + minx = torch.amin(x, dim=dim, keepdim=True) + zpx = torch.round(minx* qx) + x = torch.round(qx*x - zpx) + zpx + return x, qx + elif quant_type == 'truncated-vector': + with torch.no_grad(): + absx = torch.abs(x) + max1 = torch.amax(absx, dim=dim, keepdim=True) + max1 = max1*0.7 + idx = (absx > max1.expand_as(absx)) + sign = torch.sign(x[idx]) + x[idx] = max1.expand_as(absx)[idx]*sign + xq = torch.round(x/max1*C).to(torch.int8) + return xq, max1 + else: return None + +def vectorwise_dequant(xq, max1, quant_type='vector'): + if quant_type == 'vector': + x = (xq/C*max1).to(torch.float32) + return x + else: return None + +def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type='vector'): + if quant_type == 'linear': + norm = S1*S2/(C*C) + # double cast needed to prevent overflows + return (xq.float()*norm).to(dtype) + elif quant_type == 'zeropoint': + norm = 1.0/(S1*S2) + return (xq.float()*norm).to(dtype) + elif quant_type == 'row-zeropoint': + norm = 1.0/(S1*S2) + x = xq.float() + if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0) + if len(S1.shape) == 2: + x *= norm + else: + x *= norm + return x.to(dtype) + elif quant_type == 'vector-zeropoint': + x = xq.float() + if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0) + if len(S1.shape) == 2: + x *= 1.0/S1 + else: + x *= 1.0/S1 + x *= 1.0/S2.t() + return x.to(dtype) + elif quant_type == 'row': + x = xq.float() + if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0) + if len(S1.shape) == 2: + x *= S1*S2/(C*C) + else: + x *= S1*S2/(C*C) + return x.to(dtype) + elif quant_type in ['truncated-vector', 'vector']: + x = xq.float() + if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0) + if len(S1.shape) == 2: + x *= S1/C + else: + x *= S1/C + x *= S2/C + return x.to(dtype) + else: return None + + +def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): + offset = B.float().t().sum(0)*(SA[0]+SA[1]) + x = xq.float() + if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0) + if len(SB.shape) == 2: + x *= SB.t()/127 + else: + x *= SB/127 + x *= SA[1]/127 + x +=offset + return x.to(dtype) diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py index 27ad6ca..03b4655 100644 --- a/bitsandbytes/nn/__init__.py +++ b/bitsandbytes/nn/__init__.py @@ -2,4 +2,4 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .modules import StableEmbedding, Embedding +from .modules import StableEmbedding, Linear8bit, Linear8bitLt, Int8Params diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index c5460fb..5013d0b 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -3,14 +3,19 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import torch +import bitsandbytes as bnb -from typing import Optional +from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict -from torch import Tensor +from torch import Tensor, device, dtype +from torch import nn +from torch.nn.parameter import Parameter import torch.nn.functional as F from bitsandbytes.optim import GlobalOptimManager +T = TypeVar('T', bound='torch.nn.Module') + class StableEmbedding(torch.nn.Embedding): def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False, @@ -70,3 +75,118 @@ class Embedding(torch.nn.Embedding): self.norm_type, self.scale_grad_by_freq, self.sparse) return emb + +class Int8Params(torch.nn.Parameter): + def __new__(cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None): + cls.has_fp16_weights = has_fp16_weights + cls.CB = None + cls.SCB = None + if data is None: + data = torch.empty(0) + return torch.Tensor._make_subclass(cls, data, requires_grad) + + def cuda(self, device): + if self.has_fp16_weights: + return super().cuda(device) + else: + # we store the 8-bit rows-major weight + # we convert this weight to the turning/ampere weight during the first inference pass + B = self.data.contiguous().half().cuda(device) + CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) + del CBt + del SCBt + self.data = CB + setattr(self, 'CB', CB) + setattr(self, 'SCB', SCB) + + return self + + @overload + def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., + non_blocking: bool = ...) -> T: + ... + + @overload + def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: + ... + + @overload + def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: + ... + + def to(self, *args, **kwargs): + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) + + if device is not None and device.type == 'cuda' and self.data.device.type == 'cpu': return self.cuda(device) + else: + new_param = Int8Params(super().to(device=device, dtype=dtype, non_blocking=non_blocking), requires_grad=self.requires_grad, has_fp16_weights=self.has_fp16_weights) + new_param.CB = self.CB + new_param.SCB = self.SCB + + return new_param + + + +class Linear8bitLt(nn.Linear): + def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, threshold=0.0, index=None): + super(Linear8bitLt, self).__init__(input_features, output_features, bias) + self.state = bnb.MatmulLtState() + self.index=index + + self.state.threshold = threshold + self.state.has_fp16_weights = has_fp16_weights + if threshold > 0.0 and not has_fp16_weights: + self.state.use_pool = True + + self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights) + + def init_8bit_state(self): + self.state.CB = self.weight.CB + self.state.SCB = self.weight.SCB + self.weight.CB = None + self.weight.SCB = None + + def forward(self, x): + self.state.is_training = self.training + + if self.weight.CB is not None: self.init_8bit_state() + #assert not self.state.has_fp16_weights + #if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None + + out = bnb.matmul(x, self.weight, state=self.state) + + if self.bias is not None: + out += self.bias.unsqueeze(0).expand_as(out) + + if not self.state.has_fp16_weights and self.state.CB is not None: + # we converted 8-bit row major to turing/ampere format in the first inference pass + # we no longer need the row-major weight + del self.state.CB + self.weight.data = self.state.CxB + + return out + +class Linear8bit(nn.Linear): + def __init__(self, input_features, output_features, bias=True, quant_type='vector', index=None, args=None, sparse_decomp=False): + super(Linear8bit, self).__init__(input_features, output_features, bias) + self.quant_type = quant_type + self.index = index + self.args = args + self.iter = 0 + + def forward(self, x): + self.iter += 1 + if self.iter % self.args.clip_freq == 0: + with torch.no_grad(): + maxval, maxidx = torch.topk(torch.abs(self.weight.flatten()), k=self.args.clip_idx) + if not dist.is_initialized() or dist.get_rank() == 0: + print('clip', maxval[-1].item()) + self.weight.clip_(-maxval[-1], maxval[-1]) + + + if self.args is not None: + out = bnb.nn.functional.sparse_decomposed_linear8bit(x, self.weight, self.bias, qval=self.args.sparse_decomp_val, quant_type=self.args.quant_type) + else: + out = bnb.nn.functional.linear8bit(x, self.weight, self.bias, quant_type=self.args.quant_type) + + return out diff --git a/csrc/kernels.cu b/csrc/kernels.cu index d0aabff..1c3e723 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1737,10 +1737,884 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char } } +template __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols) +{ + // 0. reset stats to -FLT_MAX + // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD) + // 2. compute col max (per thread); store in smem due to register pressure + // 3. compute row max (per block); store in smem to accumulate full global mem transation + // 4. store data via atomicMax + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (blockIdx.x*TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + const int items_per_load = ITEMS_PER_THREAD*THREADS; + + typedef cub::BlockLoad LoadT; + typedef cub::BlockReduce BlockRowReduce; + typedef cub::BlockReduce BlockRowSum; + typedef cub::BlockExchange BlockExchange; + + __shared__ union { + typename BlockExchange::TempStorage exchange; + typename BlockRowReduce::TempStorage rowreduce; + typename BlockRowSum::TempStorage rowsum; + typename LoadT::TempStorage loadt; + } temp_storage; + + __shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS]; + __shared__ int smem_row_nnz_values[TILE_ROWS]; + //__shared__ float smem_col_absmax_values[ITEMS_PER_THREAD*THREADS]; + + half local_data[ITEMS_PER_THREAD]; + float local_data_fp32[ITEMS_PER_THREAD]; + float local_col_absmax_values[ITEMS_PER_THREAD]; + int local_row_nnz_count = 0; + float row_absmax = -FLT_MAX; + + // 0. reset stats to -FLT_MAX + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + //smem_col_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX; + smem_row_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX; + smem_row_nnz_values[threadIdx.x + (j*THREADS)] = 0; + } + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_col_absmax_values[j] = -FLT_MAX; + + __syncthreads(); + + int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col; + int i = base_idx; + // we load row after row from the base_position + // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD) + for(int row = 0; row < TILE_ROWS; row++) + { + if(base_row+row >= rows){ break; } + local_row_nnz_count = 0; + i = base_idx + ((row)*cols); + // each thread gets data from the same column + __syncthreads(); + LoadT(temp_storage.loadt).Load(&(A[i]), local_data, valid_items, __float2half(0.0f)); + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data[j] = fabsf(local_data[j]); + + + if(SPARSE_DECOMP) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + if((float)local_data[j] >= nnz_threshold) + { + local_row_nnz_count += 1; + local_data[j] = 0.0f; + } + } + + // 2. compute col max (per thread); store in smem due to register pressure + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + // take the col max for this row + // we use shared memory because register pressure is too high if we do this locally + //smem_col_absmax_values[threadIdx.x + (j*THREADS)] = fmaxf(smem_col_absmax_values[threadIdx.x + (j*THREADS)], __half2float(local_data[j])); + local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j])); + + // 3. compute row max (per block); store in smem to accumulate full global mem transation + __syncthreads(); + + // this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data_fp32[j] = local_data[j]; + + row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, cub::Max()); + if(SPARSE_DECOMP) + { + __syncthreads(); + local_row_nnz_count = BlockRowSum(temp_storage.rowsum).Sum(local_row_nnz_count); + } + // we store the data temporarily in shared memory so we + // can execute a full atomic block transaction into global memory later + // we use a striped arrangement [0, 8, 16, 24, ..] for t0 for faster stores + if(threadIdx.x == 0) + { + smem_row_absmax_values[(row % ITEMS_PER_THREAD) + ((row/ITEMS_PER_THREAD)*ITEMS_PER_THREAD)] = row_absmax; + // each blockIdx.x process 16 rows and 64*4=256 columns -> we sum nnz over 256 columns and have 16 values per block + smem_row_nnz_values[row] = local_row_nnz_count; + } + + __syncthreads(); + + } + + // 4. store data via atomicMax + // to store col data efficienctly we need to rewrite the smem blocked data [0, 1, 2, 3...] for t0 + // into a striped arangement: [0, 8, 16, 24, ..] for t0 + __syncthreads(); + BlockExchange(temp_storage.exchange).BlockedToStriped(local_col_absmax_values); + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + if(base_col+threadIdx.x+(j*THREADS) < cols) + { + float val = colStats[base_col+(threadIdx.x+(j*THREADS))]; + if(val < local_col_absmax_values[j]) + atomicMax(&colStats[base_col+(threadIdx.x+(j*THREADS))], local_col_absmax_values[j]); + } + + for(int j = 0; j < ITEMS_PER_THREAD; j++) + if(base_row+threadIdx.x+(j*THREADS) < rows) + { + float val = rowStats[base_row+(threadIdx.x+(j*THREADS))]; + if(val < smem_row_absmax_values[threadIdx.x+(j*THREADS)]) + atomicMax(&rowStats[base_row+(threadIdx.x+(j*THREADS))], smem_row_absmax_values[threadIdx.x+(j*THREADS)]); + } + + if(SPARSE_DECOMP) + if(threadIdx.x < TILE_ROWS) + nnz_count_row[blockIdx.x*TILE_ROWS+threadIdx.x+1] = smem_row_nnz_values[threadIdx.x]; + +} + +template __global__ void kgetColRowStats(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); +template __global__ void kgetColRowStats(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); + +#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f) + +template __global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n) +{ + + // Strategy: To dequantize we need to load col/row statistics. This can be very expensive + // since different row/col stats need to be loaded with each thread. + // (1, bad algorithm) Loading 32 items per thread would only occur 1 row load, but this increases register pressure + // and would lead to low global load utilization. + // (2, bad algorithm) If each thread loads some columns and multiple rows one needs to do lot of row loads + // for each thread and this is duplicated by a factor of 32/num-cols-per-thread. + // (3, good algorithm) Combining (1) and (2) we use sub-tiles of size 32xk in shared memory per threadblock. + // This allows for efficient row/col loading from shared memory within the tile. + // We can run for example 32x128 sub-tiles and warp-strided loads of 4 elements so that each thread has + // the same col statistic but needs to load 4 row stats from shared memory. To prevent bank conflicts + // we use a block-striped shared memory config [1, 31, 63, 95] so no bank conflicts happen during the + // shared memory loads. + + // data is in 32 column-tile major with tile width 32 columns and numRows rows + // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. + // L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + // C1. Compute val(row_stat*col_stat)/(127*127) (load 1/(127*127 into register)) + // C2. Compute normalization values and store col values in register + // S1. Store C1 into 16-bit output + // S2. Store col/row statistics of new buffer in shared memory + + // We allow for sub-tiles to span multiple col32 tiles. This is okay + // since the items per thread only rely on a single column statistic. + + + const int n_out = numRows*numCols; + + int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1); + // we have tiles of size numRows*32, thus col only increases every numRows + // num_row_tiles is the tiles after which the column increases by 32 + // blockIdx.x is the index of the current tile + int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32)); + // base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached + int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS); + + // SUBTILE_ROWS is independent from ITEMS_PER_THREAD is independent from THREADS + // subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD + // Total subtiles should be n/(32*SUBTILE_ROWS) where each subtile has SUBTILE_ROW*32/4 threads. + // For example for a 1024x1024 matrix with 128 SUBTILE_ROWS and 4 ITEMS_PER_THREAD we have + // 1024*1024/(128*32) = 256 tiles + // 256 tiles are 256*128*32/4 = 256*1024 threads + + // 1. Figure out how index relates to the start of the sub-tile + // 2. Each thread < SUBTILE_ROWS calculates row index + // 3. Load striped and store in shared memory + + int local_values[ITEMS_PER_THREAD]; + half local_output[ITEMS_PER_THREAD]; + float local_rowStats[ITEMS_PER_THREAD]; + __shared__ float smem_rowStats[SUBTILE_ROWS]; + + typedef cub::BlockLoad LoadInt32; + typedef cub::BlockExchange ExchangeInt32; + __shared__ typename LoadInt32::TempStorage loadint32; + __shared__ typename ExchangeInt32::TempStorage exchangeint32; + + + // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. + float colStat = col >= numCols ? 0.0f : colStats[col]; + // no block loads for rows for now -- keep it simple + for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x) + { + // todo: is this global mem access slow due to overlaps or does the L1 cache work well here? + int row = (base_row+j) % numRows; // wrap around + // each warp accesses the same element, for four consequitive elements + // todo: update description about striped shared memory, it is not needed + // rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consequitive elements + smem_rowStats[j] = rowStats[row]; + } + __syncthreads(); + + + // each block processes SUBTILE_ROWS*32 elements + const int items_per_load = THREADS*ITEMS_PER_THREAD; + const int rows_per_load = items_per_load/32; + + int subtile_base_row = (threadIdx.x / 32)*ITEMS_PER_THREAD; // row within the tile + int row_offset = 0; + // subtile_idx starts at the base_row*32 + the total offset for a full numRow*32 tile is passed + int subtile_start = (blockIdx.x/num_row_tiles)*(numRows*32) + (base_row*32); + for(int subtile_idx = subtile_start; subtile_idx < subtile_start + (SUBTILE_ROWS*32); subtile_idx+=items_per_load) + { + int valid_rows = numRows - (base_row+row_offset) > rows_per_load ? rows_per_load : numRows - (base_row+row_offset); + int valid_items = valid_rows*32; + if(valid_items <= 0) // the sub-tile might have more elements than the tile itself + break; + + // L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0); + ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values); + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j]; + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_output[j] = __float2half(local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat); + //absmax_col = fmax(fabsf(local_output[j]), absmax_col); + + // we store data in row major + // to store data efficiently, we want to use block exchange: [0, 32, 64, 92] -> [0, 1, 2, 3] + // so that each thread holds ITEMS_PER_THREAD consecutive items for each row + // this way throughput into storage is increased by a factor of ~2x + // for now we use a simple store + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int outIdx = col + ((base_row+subtile_base_row+row_offset+j)*numCols); + if(outIdx< n_out && col < numCols) + out[outIdx] = local_output[j]; + } + + row_offset += rows_per_load; + } +} + + +template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols) +{ + // assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD + // Each thread reads the same column but multiple rows + // Rows are loaded in shared memory and access is shared across the threadblock (broadcast) + + // 0. Load row stats data into shared memory; load col stat (1 fixed per thread) + // 1. Load data row by row (should be at least with TILE_SIZE = 512) + // 2. quantize data with row/col stats + // 3. Store data (TILE_SIZE = 512 is a bit slow, but should still be close enough to good performance) + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (blockIdx.x*TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + const int items_per_load = ITEMS_PER_THREAD*THREADS; + + typedef cub::BlockLoad LoadHalf; + __shared__ typename LoadHalf::TempStorage loadhalf; + typedef cub::BlockStore StoreInt8; + __shared__ typename StoreInt8::TempStorage storeint8; + + __shared__ float smem_row_stats[TILE_ROWS]; + __shared__ unsigned int smem_nnz_row_idx[TILE_ROWS]; + + half local_data[ITEMS_PER_THREAD]; + float local_col_stats[ITEMS_PER_THREAD]; + char local_quantized_data[ITEMS_PER_THREAD]; + + // 0. Load row stats data into shared memory; load col stat (1 fixed per thread) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + if(base_col+(threadIdx.x*ITEMS_PER_THREAD) + j < cols) + local_col_stats[j] = __fdividef(127.0f, colStats[base_col+(threadIdx.x*ITEMS_PER_THREAD)+j]); + + for(int i = threadIdx.x; i < TILE_ROWS; i+=blockDim.x) + { + if(base_row + i < rows) + smem_row_stats[i] = rowStats[base_row+i]; + + if(SPARSE_DECOMP) + smem_nnz_row_idx[i] = nnz_block_ptr[(TILE_ROWS*blockIdx.x) + i]; + } + __syncthreads(); + + // we load row after row from the base_position + // 1. Load data row by row (should be at least with TILE_SIZE = 512) + for(int row = 0; row < TILE_ROWS; row++) + { + if(base_row + row >= rows){ break; } + int i = base_idx + (row*cols); + int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col; + + + LoadHalf(loadhalf).Load(&(A[i]), local_data, valid_items, 0.0f); + float row_stat = __fdividef(127.0f, smem_row_stats[row]); + + // 2. quantize data with row/col stats + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + // we already pre-normalized the col/row stat: + // what this does is float/absmax*127 = int8 + if(SPARSE_DECOMP) + { + if(fabsf((float)local_data[j]) >= threshold) + { + local_quantized_data[j] = 0; + + int old_idx = atomicInc(&smem_nnz_row_idx[row], UINT_MAX); + + rowidx[old_idx] = base_row+row; + colidx[old_idx] = base_col+(threadIdx.x*ITEMS_PER_THREAD)+j; + val[old_idx] = local_data[j]; + } + else + { + local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat)); + } + } + else + local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat)); + } + + StoreInt8(storeint8).Store(&(out_row_normed[i]), local_quantized_data, valid_items); + + // 2. quantize data with row/col stats + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + // we already pre-normalized the col/row stat: + // what this does is float/absmax*127 = int8 + local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*local_col_stats[j])); + } + + __syncthreads(); + StoreInt8(storeint8).Store(&(out_col_normed[i]), local_quantized_data, valid_items); + + } +} + +template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols) +{ + + // 0. Load data into 32*32 shared memory tiles + // 1. transpose / reorder in shared memory + // 2. store + + // COL32 FORMAT: + // rows*32 tiles + + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + + // AMPERE FORMAT: + // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows: + // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32] + + + // To have efficient loads and stores if we transpose we need 128 consequitive bytes which at 1 byte are 128 values + // As such we need: + // at least 32*4 shared memory tiles for col32; preferably 32*32 + // at least 32*6 shared memory tiles for col32_ampere: preferably 32*32 + // at least 32*8 shared memory tiles for col4_turing: preferably 32*32 + // for efficient loading of row major we need to load 128 elements and repeat this 32 items + // this would imply a 32x128 shared memory tile -> 4kb + // It is more efficient to have more than 1 warp, so with 64 threads we need 32x128 -> 8 kb + // we have 64k sharded mem per SM in Turing which is 8 blocks per SM which is 2*8 = 32 warps = 100% occupancy + // for turing and 50% for A100 and 75% for RTX 30s / A40 which is probably good enough + // register pressure should be low with: 8 registers from local memoryh per block and 64 registers per SM + // + // to make the shared memory work with that occupancy we might need to union the block loads/stores + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (blockIdx.x*TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + + // we load 128 bytes per warp with + // 32 rows for transposes that fill col32 types + // so that we can have contiguous stores + __shared__ char smem_data[32*33*ITEMS_PER_THREAD]; + char local_data[ITEMS_PER_THREAD]; + typedef cub::BlockExchange BlockExchange; + __shared__ typename BlockExchange::TempStorage temp_storage; + + // we load row after row from the base_position + // Load data row by row + int warps = blockDim.x/32; + int warp_id = threadIdx.x/32; + int warp_lane = threadIdx.x % 32; + int offset = 0; + + int smem_row = 0; + // each warp loads one row of 128 bytes + for(int row = warp_id; row < TILE_ROWS; row+=warps) + { + int i = base_idx + (row*cols); + // we load up to 128 bytes/items per load + int valid_items = cols - base_col > 32*ITEMS_PER_THREAD ? 32*ITEMS_PER_THREAD : cols - base_col; + + // 0. Load data into 32*32 shared memory tiles + if(base_row + row < rows) + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int col_idx = warp_lane+(j*32); + if(col_idx < valid_items) + local_data[j] = A[i+col_idx]; + else + local_data[j] = 0; + } + } + else + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data[j] = 0; + } + + if(TRANSPOSE) + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int local_col = (32*j)+warp_lane; + //int local_row = row; + // store as 256x32 + smem_data[(local_col*33) + row] = local_data[j]; + } + } + else + { + // treat smem as 32x256, that is 32 rows and 256 columns + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + smem_data[row*32*ITEMS_PER_THREAD + (warp_lane) + (j*32)] = local_data[j]; + } + + + + smem_row += warps; + + // 1. transpose / reorder in shared memory + if(smem_row % 32 == 0) + { + smem_row = 0; + __syncthreads(); + + for(int subrow = warp_id; subrow < 32; subrow+=warps) + { + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + + switch(FORMAT) + { + case COL32: + if(TRANSPOSE) + { + // data lies in shared memory in the following way: + // row0 [col0 col1 ... col31] + // row1 [col0 col1 ... col31] + // ... + // + // As such we read consequtive entries with 256 threads (8rows x 32 columns) + // as j increase, the row increase by a factor of 8 + // We load 8 rows per subrow loop, and subrow increase by 8 per loop + // so we have an offset of 8 rows every loop or (subrow/warps)*8 = (subrow/8)*8 + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size outRows*32 and base_row is done in increments of 32 + offset = base_row*outRows; + out[offset + (base_col + jrow + subrow_loop_row)*32 + threadIdx.x] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + offset = (base_col/32)*(32*rows); + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + out[offset+(base_row+subrow)*32 + ((j)*rows*32)+warp_lane] = data; + } + } + break; + case COL_TURING: + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + // + // [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...] + if(TRANSPOSE) + { + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size 8*32 = 256 elements offset + // for each row offset of 8 we increaes the tile first + // after all rows are exhausted, we increase the col + int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/8)*256; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows + + // we increase by row_tile_column every 32 columns + // base_row increase in increments of 32 + //int row_tile_column = 256*outRows/8; // there are outRows/8 row tiles, and each tile is 256 elements + //int col_offset = (base_row/32)*row_tile_column; + // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8 + // 256*outRows/8*base_row/32 = outRows*base_row + int col_offset = outRows*base_row; + + offset = row_offset+col_offset; + + // since we process even number of rows with each j (8) and with each subrow (8j) we can determine + // odd or even rows with the warp_id (each warp processes one row) + // the col is warp_lane (max 32 columns per row) and the row warp_id + if(warp_id % 2 == 1) + // odd + offset += 128 + (warp_lane/4)*16 + (warp_lane%4) + (((warp_id%8)-1)*2); + else + // even + offset += 0 + (warp_lane/4)*16 + (warp_lane%4) + ((warp_id%8)*2); + + out[offset] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + // set offset designates the tile offset among the 8*32 tiles + // we first increase rows and then columns. Since we load 128 columns at once + // we increase the offset by outRows*32 every 32 columns + // additionally, we increase the offset by 8*32=256 every 8 rows + offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/8)*256); // global offset (8x32 tile) + // first 4 rows are reserved for even rows, [0, 2, 4, 6], the next 4 for odd + // each of these has 32 values in total for 32*4 = 128 as offset if odd + // every set of 4 columns increases the total offset by 16 + // each even row increase the offset by 4, for example row 2 is offset by 4, 4 by 6 etc so: subrow/2*4 = subrow*2 + // this happends every 8 rows anew (subrow % 8) + // one writes 4 columns at once that is (col % 4) for the particular index in the subtile + int subcol = warp_lane; + + // add local offset (4x4 sub-tile) + if(subrow % 2 == 1) + // odd + offset += 128 + (subcol/4)*16 + (subcol%4) + (((subrow%8)-1)*2); + else + // even + offset += 0 + (subcol/4)*16 + (subcol%4) + ((subrow%8)*2); + + out[offset] = data; + } + } + break; + case COL_AMPERE: + // AMPERE FORMAT: + // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows: + // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32] + if(TRANSPOSE) + { + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size 32*32 = 1024 elements offset + // for each row offset of 32 we increaes the tile first + // after all rows are exhausted, we increase the col + int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/32)*1024; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows + + // we increase by row_tile_column every 32 columns + // base_row increase in increments of 32 + //int row_tile_column = 1024*outRows/32; // there are outRows/32 row tiles, and each tile is 1024 elements + //int col_offset = (base_row/32)*row_tile_column; + // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8 + // 1024*outRows/32*base_row/32 = outRows*base_row + int col_offset = outRows*base_row; + + offset = row_offset+col_offset; + + + // same as in the non-transpose case (see below) + // the difference is that now rows = cols + // in this case warp_id = subrow + + // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc + // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row + // every 2 rows, the offset increases by two [0, 1, 8, 9...] + // every 2 rows, the row index increase by 8 [0, 1, 8, 9...] + int local_row = (jrow + warp_id) % 32; // offset for row > 32 is already calculated into row_offset + int ampere_row = ((local_row % 8)/2)*8 + (local_row/8)*2 + (local_row % 2); + + // global offset + row with 32 cols each + 32 cols per j + col_idx=warp_lane + out[offset + (ampere_row*32) + warp_lane] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + + // set offset designates the tile offset among the 32*32 tiles + // we first increase rows and then columns. Since we load 128 columns at once + // we increase the offset by outRows*32 every 32 columns + // additionally, we increase the offset by 32*32=1024 every 32 rows + offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/32)*1024); // global offset (32x32 tile) + + // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc + // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row + // every 2 rows, the offset increases by two [0, 1, 8, 9...] + // every 2 rows, the row index increase by 8 [0, 1, 8, 9...] + int local_row = ((subrow % 8)/2)*8 + (subrow/8)*2 + (subrow % 2); + + // global offset + row with 32 cols each + 32 cols per j + col_idx + out[offset + (local_row*32) + warp_lane] = data; + } + } + break; + } + } + } + } + } +} + +#define C 1.0f/127.0f +#define MAX_SPARSE_COUNT 32 +#define SMEM_SIZE 8*256 +template +__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB) +{ + + // 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block + // If a block finishes, the next one is scheduled. Since the last blocks like have fewer + // elements they finish faster "fillin up" the gaps left by larger blocks + + // without tensor cores + // 1. use rowidx_length to find what to load (as many blocks as there are rows) + // 2. Load A into registers + // 3. each warp loads all required rows of B but each warp is offset by k + // 4. Do mma operations that accumulate into registers + // 5. Each warp stores its output row into matrix C + + const int count = max_count[blockIdx.x]; + const int local_max_idx = max_idx[blockIdx.x]; + const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1]; + const int local_row_idx = rowidx[offset]; + + const int warp_id = threadIdx.x / 32; + const int warp_idx = threadIdx.x % 32; + const int warp_offset = (warp_id*32)*SPMM_ITEMS; + const int num_items = BITS == 8 ? 8 : 8; + int idx_col_B = warp_offset; + int local_idx_col_B_offset = 0; + + half local_valA[MAX_SPARSE_COUNT]; + int local_colidxA[MAX_SPARSE_COUNT]; + half local_valC[SPMM_ITEMS]; + T local_valsB[num_items]; + half local_valOut[num_items]; + // 128 byte loads per warp == 4 bytes per thread + + // 2. Load A into registers + for(int j = 0; j < MAX_SPARSE_COUNT; j++) + { + local_valA[j] = j < count ? values[offset+j] : __float2half(0.0f); + local_colidxA[j] = j < count ? colidx[offset+j] : 0; + } + + // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=x192 + // we expect each warp to be SPMM_ITEMS*32 apart + // we have a total of 128 bytes for the bank with a bank size of 4 bytes + // added 3 bytes = 6 values between warps should reduce bank conflicts + __shared__ half smem_dequant_stats[SMEM_SIZE]; + + + while(idx_col_B < colsB) + { + + if(dequant_stats != NULL) + { + for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x) + if((idx_col_B+i-local_idx_col_B_offset) < colsB) + smem_dequant_stats[i] = __ldg(&dequant_stats[idx_col_B+i-local_idx_col_B_offset]); + + __syncthreads(); + } + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j++) + local_valC[j] = 0.0f; + + #pragma unroll + for(int i = 0; i < count; i++) + { + // 3. each warp loads all required rows of B but each warp is offset by k + int row_offset = colsB*local_colidxA[i]; + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j+=num_items) + { + // 4. Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached + int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j; + if(idx >= colsB){ break; } + //printf("%i %i\n", (row_offset+idx) % num_items, row_offset+idx); + if((idx+num_items < colsB)) + { + if(BITS == 8) + reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; + else + reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; + } + else + { + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + if(idx+k < colsB) + local_valsB[k] = B[row_offset+idx+k]; + else + local_valsB[k] = 0.0f; + } + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + { + //if((float)local_valsB[k] != 0.0) + // printf("%f %i %i %i\n", (float)local_valsB[k], k, idx, colsB); + if(BITS == 8 && dequant_stats != NULL) + // we do texture cache reads (__ldg) on dequant_stats which should be super fast + { + float valB = local_valsB[k]; + float valA = local_valA[i]; + if(valB != 0.0 && valA != 0.0) + local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*C*valB*valA; + } + else + local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i]; + } + } + } + + int idx_row_C = (colsB*local_row_idx); + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j+=num_items) + { + //int idx_col_C = idx_col_B + (32*j) + warp_idx; + int idx_col_C = idx_col_B + warp_idx*SPMM_ITEMS + j; + int idx_val = idx_col_C + idx_row_C; + + if(idx_col_C +num_items < colsB) + { + + // load outputs to do inplace addition + reinterpret_cast(local_valOut)[0] = reinterpret_cast(out)[idx_val/num_items]; + + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + local_valC[(j/num_items) + k] = (float)local_valC[(j/num_items) + k] + (float)local_valOut[k]; + + reinterpret_cast(out)[idx_val/num_items] = reinterpret_cast(local_valC)[j/num_items]; + } + else + { + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + if(idx_col_C + k < colsB) + out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k]; + } + } + + idx_col_B += blockDim.x*SPMM_ITEMS; + local_idx_col_B_offset += blockDim.x*SPMM_ITEMS; + } +} + //============================================================== // TEMPLATE DEFINITIONS //============================================================== +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); + +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); + +template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n); + +template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); +template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); + template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x); template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 0a3676c..cbfbeba 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -106,6 +106,18 @@ template __global__ void kPercentileCl __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n); + +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); + +template __global__ void kdequant_mm_int32_fp16( + int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, + half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n); + +template __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); +template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); + +template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); + #endif diff --git a/csrc/ops.cu b/csrc/ops.cu index 40c185c..8946015 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -8,6 +8,7 @@ #include #include #include +#include #include @@ -188,11 +189,416 @@ template void percentileClipping(T * g, float *gnorm_vec, int step, CUDA_CHECK_RETURN(cudaPeekAtLastError()); } +void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc) +{ + const int falpha = 1; + const int fbeta = 0; + const void * alpha = &falpha; + const void * beta = &fbeta; + cublasStatus_t status; + + status = cublasGemmEx(context->m_handle, + transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, + transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, + m, n, k, + alpha, A, CUDA_R_8I, lda, B, CUDA_R_8I, ldb, beta, + C, CUDA_R_32I, ldc, + CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP); + + if (status != CUBLAS_STATUS_SUCCESS) + { + std::cout << "CUBLAS ERROR: Status " << status << std::endl; + } + +} + +void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, + long long int strideA, long long int strideB, long long int strideC, int batchCount) +{ + const int falpha = 1; + const int fbeta = 0; + const void * alpha = &falpha; + const void * beta = &fbeta; + cublasStatus_t status; + + //cout << transposeA << transposeB << endl; + //printf("%i %i %i\n", m,n,k); + //printf("%i %i %i\n", lda,ldb,ldc); + //printf("%i %i %i\n", strideA, strideB, strideC); + //printf("%i\n", batchCount); + + status = cublasGemmStridedBatchedEx(context->m_handle, + transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, + transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, + m, n, k, + alpha, A, CUDA_R_8I, lda, (long long int)strideA, B, CUDA_R_8I, ldb, (long long int)strideB, beta, + C, CUDA_R_32I, ldc, (long long int)strideC, batchCount, + CUDA_R_32I, CUBLAS_GEMM_DEFAULT); + + if (status != CUBLAS_STATUS_SUCCESS) + { + std::cout << "CUBLAS ERROR: Status " << status << std::endl; + } + +} + +int roundoff(int v, int d) { + return (v + d - 1) / d * d; +} + + +template cublasLtOrder_t get_order() +{ + switch(ORDER) + { + case ROW: + return CUBLASLT_ORDER_ROW; + break; + case COL: + return CUBLASLT_ORDER_COL; + break; + case COL32: + return CUBLASLT_ORDER_COL32; + break; + case COL_TURING: + return CUBLASLT_ORDER_COL4_4R2_8C; + break; + case COL_AMPERE: + return CUBLASLT_ORDER_COL32_2R_4R4; + break; + } +} + +template cublasLtOrder_t get_order(); +template cublasLtOrder_t get_order(); +template cublasLtOrder_t get_order(); +template cublasLtOrder_t get_order(); +template cublasLtOrder_t get_order(); + + +template int get_leading_dim(int dim1, int dim2) +{ + switch(ORDER) + { + case ROW: + return dim2; + break; + case COL: + return dim1; + break; + case COL32: + // 32*row tiles + return dim1*32; + break; + case COL_TURING: + return 32*roundoff(dim1, 8); + break; + case COL_AMPERE: + // 32*32 tiles + return 32*roundoff(dim1, 32); + break; + } +} + +template int get_leading_dim(int dim1, int dim2); +template int get_leading_dim(int dim1, int dim2); +template int get_leading_dim(int dim1, int dim2); + +template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2) +{ + + cublasLtOrder_t orderA = get_order(); + cublasLtOrder_t orderOut = get_order(); + int ldA = get_leading_dim(dim1, dim2); + int ldOut = get_leading_dim(dim1, dim2); + + cublasLtMatrixLayout_t A_desc = NULL, out_desc = NULL; + cublasLtMatrixTransformDesc_t A2Out_desc = NULL; + cublasOperation_t opTranspose = CUBLAS_OP_T; + float transformAlpha = 1.0f, transformBeta = 0.0f; + + + if(DTYPE == 8) + { + checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_8I, dim1, dim2, ldA)); + checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_8I, dim1, dim2, ldOut)); + } + else if(DTYPE == 32) + { + checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_32I, dim1, dim2, ldA)); + checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_32I, dim1, dim2, ldOut)); + } + else + { + printf("ERROR WRONG TYPE FOR TRANSFORM: %i\n", DTYPE); + } + + checkCublasStatus(cublasLtMatrixLayoutSetAttribute(A_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA))); + checkCublasStatus(cublasLtMatrixLayoutSetAttribute(out_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut))); + + checkCublasStatus(cublasLtMatrixTransformDescCreate(&A2Out_desc, CUDA_R_32F)); + + if(transpose){ checkCublasStatus(cublasLtMatrixTransformDescSetAttribute(A2Out_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); } + + checkCublasStatus(cublasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, NULL, NULL, out, out_desc, 0)); + + if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc)); + if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc)); + if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc)); +} + +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); + +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) +{ + int has_error = 0; + cublasLtMatmulDesc_t matmulDesc = NULL; + cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; + cublasOperation_t opT = CUBLAS_OP_T; + cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; + cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32; + cublasLtOrder_t col_turing = CUBLASLT_ORDER_COL4_4R2_8C; + cublasLtOrder_t col_ampere = CUBLASLT_ORDER_COL32_2R_4R4; + + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, m, k, lda)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, n, k, ldb)); + + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); + if(FORMATB == COL_TURING) + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing))); + else + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere))); + + if(DTYPE_OUT == 32) + { + has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32I)); + has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); + int alpha = 1, beta = 0; + has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, NULL, NULL, 0, 0)); + } + else + { + has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32F)); + has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_8I, m, n, ldc)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); + if(!SCALE_ROWS) + { + float alpha = 1.0f, beta = 0.0f; + has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0)); + } + else + { + has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec))); + has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, NULL, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0)); + } + } + + + if (Cdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Cdesc)); + if (Bdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Bdesc)); + if (Adesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Adesc)); + if (matmulDesc) has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc)); + if(has_error == 1) + printf("error detected"); + + return has_error; +} + +int fill_up_to_nearest_multiple(int value, int multiple) +{ + return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple))); +} + +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols) +{ + int threads = 512; + int tileCols = fill_up_to_nearest_multiple(numCols, 32); + int n = numRows*tileCols; + int subtile_rows = 128; + int tilesize = 32*subtile_rows; + int num_blocks = numRows/subtile_rows; + num_blocks += (numRows % subtile_rows == 0) ? 0 : 1; + num_blocks = num_blocks*(tileCols/32); + assert(threads <= tilesize); + + //cout << num_blocks << " blocks" << endl; + + kdequant_mm_int32_fp16<4, 128, 512><<>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +#define STATS_THREADS 64 +#define STATS_ITEMS 4 +#define STATS_ROWS 16 +void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols) +{ + int tile_cols = STATS_THREADS*STATS_ITEMS; + int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); + int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS); + int num_blocks = (tiledCols/tile_cols) * (tiledRows/STATS_ROWS); + + if(nnz_threshold == 0.0) + kgetColRowStats<<>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); + else if(nnz_threshold != 0.0) + kgetColRowStats<<>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + +} + +void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols) +{ + int threads = 64; + int items_per_thread = 4; + int tile_cols = threads*items_per_thread; + int tile_rows = 16; + int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); + int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); + int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows); + + //cout << cols << " " << tiledCols << " " << tiledRows << endl; + //cout << "num blocks " << num_blocks << endl; + + //cout << A << " " << out_col_normed << endl; + if(threshold > 0.0f) + kDoubleRowColQuant<64, 4, 16, 64*4, 1><<>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); + else + kDoubleRowColQuant<64, 4, 16, 64*4, 0><<>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); + + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +template void transformRowToFormat(char * A, char *out, int rows, int cols) +{ + int threads = 256; + int items_per_thread = 8; + // we load 128 column values per warp + int tile_cols = 32*items_per_thread; + int tile_rows = 32; + int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); + int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); + int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows); + int outCols = fill_up_to_nearest_multiple(cols, 32); + int outRows = fill_up_to_nearest_multiple(rows, 32); + if(FORMAT == COL_TURING) + { + if(TRANSPOSE) + outRows = fill_up_to_nearest_multiple(cols, 8); + else + outRows = fill_up_to_nearest_multiple(rows, 8); + } + else if(FORMAT == COL_AMPERE) + { + if(TRANSPOSE) + outRows = fill_up_to_nearest_multiple(cols, 32); + else + outRows = fill_up_to_nearest_multiple(rows, 32); + } + else + { + if(TRANSPOSE) + { + outCols = fill_up_to_nearest_multiple(rows, 32); + outRows = cols; + } + } + + //cout << cols << " " << tiledCols << " " << tiledRows << " " << outCols << endl; + //cout << "num blocks " << num_blocks << endl; + + //cout << A << " " << out_col_normed << endl; + kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<>>(A, out, rows, cols, tiledCols, outRows, outCols); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) +{ + + cusparseSpMatDescr_t descA; + cusparseDnMatDescr_t descB, descC; + + float alpha = 1.0f; + float beta = 0.0f; + void *dBuffer = NULL; + size_t bufferSize = 0; + + CHECK_CUSPARSE( cusparseCreateCoo(&descA, A_rows, A_cols, A_nnz, + A_rowidx, A_colidx, A_vals, + CUSPARSE_INDEX_32I, + CUSPARSE_INDEX_BASE_ZERO, CUDA_R_16F) ); + // Create dense matrix C + CHECK_CUSPARSE( cusparseCreateDnMat(&descC, A_rows, B_cols, ldc, C, + CUDA_R_16F, CUSPARSE_ORDER_ROW) ); + // Create dense matrix B + if(transposed_B) + { + int tmp = A_cols; + A_cols = B_cols; + B_cols = tmp; + } + + CHECK_CUSPARSE( cusparseCreateDnMat(&descB, A_cols, B_cols, ldb, B, + CUDA_R_16F, CUSPARSE_ORDER_ROW) ); + // allocate an external buffer if needed + CHECK_CUSPARSE( cusparseSpMM_bufferSize( + handle, + CUSPARSE_OPERATION_NON_TRANSPOSE, + transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE, + &alpha, descA, descB, &beta, descC, CUDA_R_32F, + CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize) ); + CUDA_CHECK_RETURN( cudaMalloc(&dBuffer, bufferSize) ); + + // execute SpMM + CHECK_CUSPARSE( cusparseSpMM(handle, + CUSPARSE_OPERATION_NON_TRANSPOSE, + transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE, + &alpha, descA, descB, &beta, descC, CUDA_R_32F, + CUSPARSE_SPMM_ALG_DEFAULT, dBuffer)); + + // destroy matrix/vector descriptors + CHECK_CUSPARSE( cusparseDestroySpMat(descA) ); + CHECK_CUSPARSE( cusparseDestroyDnMat(descB) ); + CHECK_CUSPARSE( cusparseDestroyDnMat(descC) ); + CUDA_CHECK_RETURN( cudaFree(dBuffer) ); +} + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) +{ + + kspmm_coo_very_sparse_naive<<>>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} //============================================================== // TEMPLATE DEFINITIONS //============================================================== +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); + +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); + +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); + template void estimateQuantiles(half *A, float *code, float offset, int n); template void estimateQuantiles(float *A, float *code, float offset, int n); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 8fb4cec..4e719df 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -14,6 +14,11 @@ #include #include +#include +#include +#include +#include +#include #define CUDA_CHECK_RETURN(value) { \ cudaError_t _m_cudaStat = value; \ @@ -25,6 +30,34 @@ #define THREADS_PER_BLOCKS (512) +#define CHECK_CUSPARSE(value) { \ + cusparseStatus_t _m_cudaStat = value; \ + if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) { \ + fprintf(stderr, "Error %s at line %d in file %s\n", \ + cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ + exit(1); \ + } } + + +#define THREADS_PER_BLOCKS (512) + + +inline void checkCudaStatus(cudaError_t status) { + if (status != cudaSuccess) { + printf("cuda API failed with status %d: %s\n", status, cudaGetErrorString(status)); + throw std::logic_error("cuda API failed"); + } +} + +inline int checkCublasStatus(cublasStatus_t status) { + if (status != CUBLAS_STATUS_SUCCESS) { + printf("cuBLAS API failed with status %d\n", status); + //throw std::logic_error("cuBLAS API failed"); + return 1; + } + return 0; +} + typedef enum Operations_t { ksmul = 0, @@ -39,6 +72,57 @@ typedef enum Optimizer_t ADAGRAD = 4, } Optimizer_t; +typedef enum Transform_t +{ + ROW = 0, + COL = 1, + COL32 = 2, + COL_TURING = 3, + COL_AMPERE = 4, +} Transform_t; + +class Context +{ + public: + cublasHandle_t m_handle; + + Context() + { + cublasHandle_t handle; + cublasCreate_v2(&handle); + m_handle = handle; + } + +}; + +class ContextLt +{ + public: + cublasLtHandle_t m_handle; + + ContextLt() + { + cublasLtHandle_t handle; + cublasLtCreate(&handle); + m_handle = handle; + } + +}; + +class ContextCusparse +{ + public: + cusparseHandle_t m_handle; + + ContextCusparse() + { + cusparseHandle_t handle; + cusparseCreate(&handle); + m_handle = handle; + } + +}; + template void estimateQuantiles(T *A, float *code, float offset, int n); @@ -70,4 +154,24 @@ template void percentileClipping(T * g, float *gnorm_vec, int step, void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n); +void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); +void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, + long long int strideA, long long int strideB, long long int strideC, int batchCount); + + +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); + +template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2); +void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols); +void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols); +void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, + int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols); + +template void transformRowToFormat(char * A, char *out, int rows, int cols); + +void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B); + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); + #endif diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index c2fed6b..03c8d92 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -84,6 +84,52 @@ void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } #endif +#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ +void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \ +{ \ + transform(ltHandle, A, out, dim1, dim2); \ +} \ + +MAKE_FUNC_TRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8); +MAKE_FUNC_TRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8); +MAKE_FUNC_TRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8); +MAKE_FUNC_TRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32); +MAKE_FUNC_TRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8); +MAKE_FUNC_TRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8); +MAKE_FUNC_TRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8); +MAKE_FUNC_TRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32); + +void transform_row2col32(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } +void transform_row2col32T(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } +void transform_row2turing(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } +void transform_row2turingT(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } +void transform_row2ampere(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } +void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } + + int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + + int igemmlt_turing_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + + int igemmlt_turing_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + + int igemmlt_ampere_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + + int igemmlt_ampere_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + + int igemmlt_ampere_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + +void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) +{ spmm_coo_very_sparse_naive(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } + +void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) +{ spmm_coo_very_sparse_naive(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } + extern "C" { #if BUILD_CUDA @@ -155,7 +201,86 @@ extern "C" void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); } void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); } - #endif + void cigemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc) + { gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc); } + void cbatched_igemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, + long strideA, long strideB, long strideC, int batchCount) + { strided_gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc, strideA, strideB, strideC, batchCount); } + + Context *get_context(){ return new Context(); } + ContextCusparse *get_cusparse(){ return new ContextCusparse(); } + + int cigemmlt_turing_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + //{ (cublasLtHandle_t)context->m_handle; return 0; } + //{ return 0; }//igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + + int cigemmlt_turing_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt_turing_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + + int cigemmlt_turing_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt_turing_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + + int cigemmlt_ampere_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt_ampere_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + + int cigemmlt_ampere_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + + int cigemmlt_ampere_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + + #define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ + void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \ + { \ + transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose((cublasLtHandle_t) context->m_handle, A, out, dim1, dim2); \ + } \ + + MAKE_FUNC_CTRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8) + MAKE_FUNC_CTRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8) + MAKE_FUNC_CTRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8) + MAKE_FUNC_CTRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32) + MAKE_FUNC_CTRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8) + MAKE_FUNC_CTRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8) + MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8) + MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32) + + void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols) + { dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols); } + void cget_col_row_stats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols) + { getColRowStats(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols); } + + void cdouble_rowcol_quant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_row_ptr, float threshold, int rows, int cols) + { doubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_row_ptr, threshold, rows, cols); } + + void ctransform_row2col32(char * A, char *out, int rows, int cols) + { transform_row2col32(A, out, rows, cols); } + + void ctransform_row2col32T(char * A, char *out, int rows, int cols) + { transform_row2col32T(A, out, rows, cols); } + + void ctransform_row2turing(char * A, char *out, int rows, int cols) + { transform_row2turing(A, out, rows, cols); } + + void ctransform_row2turingT(char * A, char *out, int rows, int cols) + { transform_row2turingT(A, out, rows, cols); } + + void ctransform_row2ampere(char * A, char *out, int rows, int cols) + { transform_row2ampere(A, out, rows, cols); } + + void ctransform_row2ampereT(char * A, char *out, int rows, int cols) + { transform_row2ampereT(A, out, rows, cols); } + + void cspmm_coo(ContextCusparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) + { spmm_coo((cusparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); } + + void cspmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) + { spmm_coo_very_sparse_naive_fp16(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } + + void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) + { spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } + +#endif void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); } void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); } } diff --git a/tests/test_autograd.py b/tests/test_autograd.py new file mode 100644 index 0000000..d2b5d59 --- /dev/null +++ b/tests/test_autograd.py @@ -0,0 +1,270 @@ +import pytest + +import torch +import bitsandbytes as bnb + +from itertools import product + +n = 1 +k = 25 +dim1 = torch.randint(16,64, size=(n,)).tolist() +dim2 = torch.randint(32,96, size=(n,)).tolist() +dim3 = torch.randint(32,96, size=(n,)).tolist() +dim4 = torch.randint(32,96, size=(n,)).tolist() +funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)] +str_funcs = ['bmm', 'matmul'] +req_grad = [(False, False), (True, False), (True, True), (False, True)] +req_grad_str = ['FF', 'TF', 'TT', 'FT'] +transpose = [(False, False), (False, True), (True, True), (True, False)] +str_transpose = ['FF', 'FT', 'TT', 'TF'] +dtype = [torch.float32, torch.float16] +values = list(product(dim1,dim2,dim3,dim4,funcs, dtype, req_grad, transpose)) +str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, str_transpose)) +names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}'.format(*vals) for vals in str_values] +@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names) +def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): + dim2 = dim2 - (dim2 % 16) + dim3 = dim3 - (dim3 % 16) + dim4 = dim4 - (dim4 % 16) + for i in range(k): + + # normal multiply + if funcs[0] in [torch.mm, torch.matmul]: + dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) + dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) + A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0]) + B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1]) + target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1]) + torch.nn.init.xavier_uniform_(B) + + if not transpose[0] and not transpose[1]: + out_torch = funcs[0](A, B) + out_bnb = funcs[1](A, B) + elif not transpose[0] and transpose[1]: + out_torch = funcs[0](A, B.t()) + out_bnb = funcs[1](A, B.t()) + elif transpose[0] and not transpose[1]: + out_torch = funcs[0](A.t(), B) + out_bnb = funcs[1](A.t(), B) + elif transpose[0] and transpose[1]: + out_torch = funcs[0](A.t(), B.t()) + out_bnb = funcs[1](A.t(), B.t()) + + n = out_bnb.numel() + idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) + assert (idx==0).sum().item() < n*0.0175 + idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) + assert (idx==0).sum().item() < n*0.001 + + if any(req_grad): + out_bnb.data.copy_(out_torch) + torch.cuda.synchronize() + loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() + loss_bnb.backward() + gradA1 = A.grad + gradB1 = B.grad + A.grad = None + B.grad = None + + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + loss_torch.backward() + gradA2 = A.grad + gradB2 = B.grad + A.grad = None + B.grad = None + + if req_grad[0]: + torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1) + if req_grad[1]: + n = gradB1.numel() + idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) + assert (idx==0).sum().item() < n*0.1 + idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) + assert (idx==0).sum().item() < n*0.02 + torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3) + + # batched matrix multiply + if funcs[0] in [torch.bmm, torch.matmul]: + A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0]) + B = torch.randn(size=(dim1, dim3, dim4), device='cuda', requires_grad=req_grad[1]) + target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1]) + torch.nn.init.xavier_uniform_(B) + + out_torch = funcs[0](A, B) + out_bnb = funcs[1](A, B) + + n = out_bnb.numel() + idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) + assert (idx==0).sum().item() < n*0.01 + torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2) + + if any(req_grad): + out_bnb.data.copy_(out_torch) + torch.cuda.synchronize() + loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() + loss_bnb.backward() + gradA1 = A.grad + gradB1 = B.grad + A.grad = None + B.grad = None + + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + loss_torch.backward() + gradA2 = A.grad + gradB2 = B.grad + A.grad = None + B.grad = None + + if req_grad[0]: + torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1) + if req_grad[1]: + n = gradB1.numel() + idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) + assert (idx==0).sum().item() < n*0.1 + idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) + assert (idx==0).sum().item() < n*0.02 + + if funcs[0] in [torch.matmul]: + dim1 = dim1 - (dim1 % 16) + A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0]) + dimB = (dim4, dim3) if transpose[1] else (dim3, dim4) + B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1]) + target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1]) + torch.nn.init.xavier_uniform_(B) + + if transpose[1]: + out_torch = funcs[0](A, B.t()) + out_bnb = funcs[1](A, B.t()) + else: + out_torch = funcs[0](A, B) + out_bnb = funcs[1](A, B) + + n = out_bnb.numel() + idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) + assert (idx==0).sum().item() < n*0.0175 + idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) + assert (idx==0).sum().item() < n*0.001 + + if any(req_grad): + out_bnb.data.copy_(out_torch) + torch.cuda.synchronize() + loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() + loss_bnb.backward() + gradA1 = A.grad + gradB1 = B.grad + A.grad = None + B.grad = None + + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + loss_torch.backward() + gradA2 = A.grad + gradB2 = B.grad + A.grad = None + B.grad = None + + if req_grad[0]: + torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1) + if req_grad[1]: + n = gradB1.numel() + idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) + assert (idx==0).sum().item() < n*0.1 + idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) + assert (idx==0).sum().item() < n*0.02 + + +n = 1 +k = 3 +dim1 = torch.randint(16,64, size=(n,)).tolist() +dim2 = torch.randint(32,96, size=(n,)).tolist() +dim3 = torch.randint(32,96, size=(n,)).tolist() +dim4 = torch.randint(32,96, size=(n,)).tolist() + +#dim1 = (17,) +#dim2 = (7,) +#dim3 = (37,) +#dim4 = (23,) + +decomp = [0.0, 6.0] +funcs = [(torch.matmul, bnb.matmul)] +str_funcs = ['matmul'] +req_grad = [(False, False), (True, False), (True, True), (False, True)] +req_grad_str = ['FF', 'TF', 'TT', 'FT'] +transpose = [(False, True), (False, False)] +str_transpose = ['NT', 'NN'] +dtype = [torch.float16] +has_fp16_weights = [True, False] +values = list(product(dim1,dim2,dim3,dim4,funcs, dtype, req_grad, transpose, decomp, has_fp16_weights)) +str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, str_transpose, decomp, has_fp16_weights)) +names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}'.format(*vals) for vals in str_values] +@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights", values, ids=names) +def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights): + dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) + dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) + outlier_dim = torch.randint(0, dimA[1], size=(dimA[1]//8,), device='cuda') + + for i in range(k): + + # normal multiply + if funcs[0] in [torch.mm, torch.matmul]: + A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0], dtype=dtype) + if decomp == 6.0: + with torch.no_grad(): + A[:, outlier_dim] = 6.0 + B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1], dtype=dtype) + target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1], dtype=dtype) + torch.nn.init.xavier_uniform_(B) + B2 = B.clone() + + state = bnb.MatmulLtState() + state.threshold = decomp + state.has_fp16_weights = has_fp16_weights + if not has_fp16_weights: + if not transpose[0] and not transpose[1]: B2 = B2.t().contiguous() + state.CB, CBt, state.SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B2) + B2 = state.CB + + if not transpose[0] and transpose[1]: + out_torch = funcs[0](A, B.t()) + out_bnb = funcs[1](A, B2, state=state) + elif not transpose[0] and not transpose[1]: + out_torch = funcs[0](A, B) + out_bnb = funcs[1](A, B2.t(), state=state) + + n = out_bnb.numel() + err = torch.abs(out_bnb-out_torch).mean().item() + #print(f'abs error {err:.4f}') + idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) + assert (idx==0).sum().item() < n*0.0175 + idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) + assert (idx==0).sum().item() < n*0.001 + + if has_fp16_weights: + if any(req_grad): + out_bnb.data.copy_(out_torch) + torch.cuda.synchronize() + loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() + loss_bnb.backward() + gradA1 = A.grad + gradB1 = B.grad + A.grad = None + B.grad = None + + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + loss_torch.backward() + gradA2 = A.grad + gradB2 = B.grad + A.grad = None + B.grad = None + + if req_grad[0]: + torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1) + if req_grad[1]: + n = gradB1.numel() + assert torch.abs(gradB1).sum() > 0.0 + assert torch.abs(gradB2).sum() > 0.0 + idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) + assert (idx==0).sum().item() < n*0.1 + idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) + assert (idx==0).sum().item() < n*0.02 + torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3) + diff --git a/tests/test_functional.py b/tests/test_functional.py index 2a7d308..6cbe58f 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,15 +1,76 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. import pytest +import math +import random +import time import torch import bitsandbytes as bnb +import einops from itertools import product from bitsandbytes import functional as F +torch.set_printoptions(precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000) +k = 20 + +def assert_all_approx_close(a, b, rtol, atol, count): + idx = torch.isclose(a, b, rtol, atol) + sumval = (idx==0).sum().item() + if sumval > count: + print(f'Too many values not close: assert {sumval} < {count}') + torch.testing.assert_allclose(a, b, rtol, atol) + +class FFN(torch.nn.Module): + def __init__(self, input_features, hidden_size, bias=True): + super(FFN, self).__init__() + self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias) + self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias) + + with torch.no_grad(): + torch.nn.init.xavier_uniform_(self.fc1.weight) + torch.nn.init.xavier_uniform_(self.fc2.weight) + + def forward(self, x): + x = torch.relu(self.fc1(x)) + x = self.fc2(x) + return x + +class Timer(object): + def __init__(self): + self.starts = {} + self.ends = {} + self.agg = {} + + def tick(self, name='default'): + if name not in self.starts: + self.starts[name] = torch.cuda.Event(enable_timing=True) + self.ends[name] = torch.cuda.Event(enable_timing=True) + self.starts[name].record() + else: + ms = self.tock(name, evict=True, print_ms=False) + + def tock(self, name='default', evict=True, print_ms=True): + if name in self.ends: + self.ends[name].record() + torch.cuda.synchronize() + ms = self.starts[name].elapsed_time(self.ends[name]) + if name not in self.agg: self.agg[name] = 0.0 + self.agg[name] += ms + if evict: + self.starts.pop(name) + self.ends.pop(name) + + if print_ms and name in self.agg: + print('{0} took: {1:.5f}s'.format(name, self.agg[name]/1000.0)) + + return self.agg[name] + + def reset(self): + self.starts = {} + self.ends = {} + self.agg = {} + print('Resetting benchmark data') + def setup(): pass @@ -64,8 +125,8 @@ def test_dynamic_quantization(): diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) assert diff.mean().item() < 0.0135 - print(sum(diffs)/len(diffs)) - print(sum(reldiffs)/len(reldiffs)) + #print(sum(diffs)/len(diffs)) + #print(sum(reldiffs)/len(reldiffs)) for i in range(100): A1 = torch.rand(1024, 1024, device='cuda') @@ -88,8 +149,8 @@ def test_dynamic_blockwise_quantization(): diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) assert diffs[-1] < 0.011 - print(sum(diffs)/len(diffs)) - print(sum(reldiffs)/len(reldiffs)) + #print(sum(diffs)/len(diffs)) + #print(sum(reldiffs)/len(reldiffs)) diffs = [] for i in range(100): @@ -125,7 +186,7 @@ def test_percentile_clipping(gtype): n = 4 step = 0 percentile=5 - for i in range(1000): + for i in range(k): step += 1 g = torch.randn(n, n, dtype=gtype, device='cuda') gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile) @@ -145,69 +206,1653 @@ def test_percentile_clipping(gtype): torch.testing.assert_allclose(gnorm1, gnorm2) +def quant(x): + max1 = torch.abs(x).max() + x = torch.round(x/max1*127) + return max1, x.to(torch.int8) + +def dequant(c, maxC): + return c.float()*(maxC/127) + +def mm_dequant(maxA, maxB, C): + return C.float()*(maxA/127)*(maxB/127) + +def quant_multi(x, dim): + max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) + max1[max1==0] = 1.0 + x = torch.round(x/max1*127) + return max1, x.to(torch.int8) + +def quant_multi_chunk(x, dim, chunk_size=32): + if dim==1: + x_chunked = einops.rearrange(x, '(c a) b -> c a b', c=chunk_size) + max1 = torch.amax(torch.abs(x_chunked), dim=dim+1, keepdim=True) + max1 = torch.tile(max1, (1, 1, x.shape[1])) + max1 = max1.view(x.shape) + elif dim==0: + x_chunked = einops.rearrange(x, 'a (b c) -> a b c', c=chunk_size) + max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True) + max1 = torch.tile(max1, (x.shape[0], 1, 1)) + max1 = max1.view(x.shape) + max1[max1==0] = 1.0 + x = torch.round(x/max1*127) + return max1, x.to(torch.int8) + +def quant_minmax(A): + minA = A.min() + maxA = A.max() + +def mean(xx): + return sum(xx)/float(len(xx)) + +#dim1 = torch.randint(1,1024*4, size=(4,)).tolist() +#dim2 = torch.randint(1,1024*4, size=(4,)).tolist() +dim1 = [1024*2] +dim2 = [1024*16] +methods = [(lambda x, dim: quant(x), lambda x, dim: quant(x), dequant, dequant, mm_dequant)] +methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant)) +#methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant)) +method_names = ['linear', 'vectorwise'] +batched = [False, True] +values = list(product(dim1,dim2, methods, batched)) +values_names = list(product(dim1,dim2, method_names, batched)) +names = ['dim1_{0}_dim2_{1}_quant_{2}_batched_{3}'.format(*vals) for vals in values_names] +@pytest.mark.parametrize("dim1, dim2, quant_methods, batched", values, ids=names) +def test_approx_igemm(dim1, dim2, quant_methods, batched): + dim1 = dim1 - (dim1 % 32) + dim2 = dim2 - (dim2 % 32) + errors = [] + relerrors = [] + print('') + for i in range(5): + if batched: + A = torch.normal(0, 0.5, size=(32, dim1, dim2//32), device='cuda') + B = torch.normal(0, 0.5, size=(32, dim2//32, dim1), device='cuda') + maxA, Ac = quant_methods[0](A, 2) + maxB, Bc = quant_methods[1](B, 1) + else: + A = torch.normal(0, 0.5, size=(dim1, dim2), device='cuda') + B = torch.normal(0, 0.5, size=(dim2, dim1), device='cuda') + maxA, Ac = quant_methods[0](A, 1) + maxB, Bc = quant_methods[1](B, 0) + torch.testing.assert_allclose(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05) + if batched: + out2 = torch.bmm(A, B) + C = torch.bmm(Ac.float(), Bc.float()) + else: + out2 = torch.mm(A, B) + C = F.igemm(Ac, Bc) + out = quant_methods[4](maxA, maxB, C) + std = out2.std() + out/= std + out2/= std + err = torch.abs(out-out2) + relerr = err/torch.abs(out2) + errors.append(err.mean().item()) + relerrors.append(relerr.mean().item()) + print(mean(errors)) + print(mean(relerrors)) + + + + + + def test_stable_embedding(): layer = bnb.nn.StableEmbedding(1024, 1024) layer.reset_parameters() -def test_dynamic_blockwise_quantization_cpu(): - #A1 = torch.randn(1024, 1024, device='cpu') - #code = F.create_dynamic_map() - #for i in range(1000): - # C, S = F.quantize_blockwise(A1, code=code) - # A2 = F.dequantize_blockwise(C, S) - for i in range(10): - # equivalence with GPU blockwise quantization - A1 = torch.randn(1024, 1024, device='cpu') - C1, S1 = F.quantize_blockwise(A1) - C2, S2 = F.quantize_blockwise(A1.cuda()) - torch.testing.assert_allclose(S1[0], S2[0].cpu()) - # there seems to be some issues with precision in CUDA vs CPU - # not all elements are usually close, with couple off elements in a million - idx = torch.isclose(C1, C2.cpu()) - assert (idx==0).sum().item() < 15 +n = 2 +hidden_dim = torch.randint(32,256, size=(n,)).tolist() +batch_dim = torch.randint(16,256, size=(n,)).tolist() +seq_dim = torch.randint(16,256, size=(n,)).tolist() +transpose = [(False, False), (False, True), (True, False), (True, True)] +values = list(product(hidden_dim,batch_dim, transpose, seq_dim)) +names = ['hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}'.format(*vals) for vals in values] +@pytest.mark.parametrize("hidden_dim, batch_dim, transpose, seq_dim", values, ids=names) +def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): + hidden_dim = hidden_dim - (hidden_dim % 32) + batch_dim = batch_dim - (batch_dim % 16) + seq_dim = seq_dim - (seq_dim % 16) + for i in range(k): + shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim) + shapeB = ((32*random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32*random.randint(1, 4))) + A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) + B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8) + if not transpose[0] and not transpose[1]: + out2 = torch.matmul(A.float(), B.float()) + out = F.igemm(A, B) + elif not transpose[0] and transpose[1]: + out2 = torch.matmul(A.float(), B.t().float()) + out = F.igemm(A, B.t()) + elif transpose[0] and not transpose[1]: + out2 = torch.matmul(A.t().float(), B.float()) + out = F.igemm(A.t(), B) + elif transpose[0] and transpose[1]: + out2 = torch.matmul(A.t().float(), B.t().float()) + out = F.igemm(A.t(), B.t()) + torch.testing.assert_allclose(out.float(), out2) - diffs = [] - reldiffs = [] + for i in range(k): + shapeA = (batch_dim, seq_dim, hidden_dim) + shapeB = ((32*random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32*random.randint(1, 4))) + A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) + B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8) + if not transpose[0] and not transpose[1]: + out2 = torch.matmul(A.float(), B.float()) + out = F.igemm(A, B) + elif not transpose[0] and transpose[1]: + out2 = torch.matmul(A.float(), B.t().float()) + out = F.igemm(A, B.t()) + + torch.testing.assert_allclose(out.float(), out2) + + +n = 3 +seq_dim = torch.randint(32,512, size=(n,)).tolist() +hidden_dim = torch.randint(32,1024*4, size=(n,)).tolist() +batch_dim = torch.randint(2,16, size=(n,)).tolist() +values = list(product(seq_dim,hidden_dim,batch_dim)) +names = ['seq_dim{0}_hidden_dim{1}_batch_dim{2}'.format(*vals) for vals in values] +@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names) +def test_dim3_igemm(seq_dim, hidden_dim, batch_dim): + seq_dim = seq_dim - (seq_dim % 32) + hidden_dim = hidden_dim - (hidden_dim % 32) + batch_dim = batch_dim - (batch_dim % 2) + for i in range(25): + A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device='cuda').to(torch.int8) + B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device='cuda').to(torch.int8) + out2 = torch.einsum('bsi, bso->io', A.float(), B.float()) + iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device) + out = F.igemm(A, B, out=iout) + + torch.testing.assert_allclose(out.float(), out2) + +n = 2 +seq_dim = torch.randint(32,512, size=(n,)).tolist() +hidden_dim = torch.randint(32,1024*4, size=(n,)).tolist() +batch_dim = torch.randint(2,16, size=(n,)).tolist() +transpose = [False, True] +values = list(product(seq_dim,hidden_dim,batch_dim, transpose)) +names = ['seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}'.format(*vals) for vals in values] +@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim, transpose", values, ids=names) +def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose): + + def min_max(x): + maxA = torch.amax(x, dim=2, keepdim=True) + minA = torch.amin(x, dim=2, keepdim=True) + scale = (maxA-minA)/2.0 + return (127*(x-minA-scale)/scale).to(torch.int8), minA, scale + + seq_dim = seq_dim - (seq_dim % 16) + hidden_dim = hidden_dim - (hidden_dim % 16) + batch_dim = batch_dim - (batch_dim % 2) + errs = [] + relerrs = [] + errs2 = [] + relerrs2 = [] + for i in range(k): + A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device='cuda') + if transpose: + B = torch.normal(0, 0.5, size=(256, hidden_dim), device='cuda') + else: + B = torch.normal(0, 0.5, size=(hidden_dim, 256), device='cuda') + Ac, minA, scale = min_max(A) + if transpose: + maxB, Bc = quant_multi(B, dim=(1 if transpose else 0)) + out = F.igemm(Ac, Bc.t()) + out2 = torch.matmul(A,B.t()) + offset = B.t().sum(0)*(minA+scale) + out = out.float() + out = (out*maxB.t()*scale/(127*127))+offset + + maxA, Ac = quant_multi(A, dim=2) + out3 = F.igemm(Ac, Bc.t()) + out3 = mm_dequant(maxA, maxB.t(), out3) + else: + maxB, Bc = quant_multi(B, dim=0) + offset = B.sum(0)*(minA+scale) + out = F.igemm(Ac, Bc) + out2 = torch.matmul(A,B) + out = out.float() + out = (out*maxB*scale/(127*127))+offset + + maxA, Ac = quant_multi(A, dim=2) + out3 = F.igemm(Ac, Bc) + out3 = mm_dequant(maxA, maxB, out3) + + std = out2.std() + out2 /= std + out /= std + out3 /= std + + err = torch.abs(out-out2) + relerr = err/(torch.abs(out2)+1e-7) + + err2 = torch.abs(out3-out2) + relerr2 = err2/(torch.abs(out2)+1e-7) + + errs.append(err.mean().item()) + relerrs.append(relerr.mean().item()) + errs2.append(err2.mean().item()) + relerrs2.append(relerr2.mean().item()) + #print(mean(errs)) + #print(mean(relerrs)) + #print(mean(errs2)) + #print(mean(relerrs2)) + assert mean(errs) < 0.015 + assert mean(relerrs) < 0.3 + +n = 2 +dim1 = torch.randint(1,64, size=(n,)).tolist() +dim2 = torch.randint(32,128, size=(n,)).tolist() +dim3 = torch.randint(32,256, size=(n,)).tolist() +dim4 = torch.randint(32,256, size=(n,)).tolist() +transpose = [(False, False), (True, False), (False, True), (True, True)] +values = list(product(dim1,dim2,dim3,dim4,transpose)) +names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names) +def test_ibmm(dim1, dim2, dim3, dim4, transpose): + dim2 = dim2 - (dim2 % 16) + dim3 = dim3 - (dim3 % 16) + dim4 = dim4 - (dim4 % 16) + for i in range(k): + shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3) + shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4) + A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) + B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8) + + if not transpose[0] and not transpose[1]: + out2 = torch.bmm(A.float(), B.float()) + out = F.igemm(A, B) + elif not transpose[0] and transpose[1]: + out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float()) + out = F.igemm(A, B.permute([0, 2, 1])) + elif transpose[0] and not transpose[1]: + out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float()) + out = F.igemm(A.permute([0, 2, 1]), B) + elif transpose[0] and transpose[1]: + out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()) + out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1])) + torch.testing.assert_allclose(out.float(), out2.float()) + +n = 1 +dim1 = torch.randint(1,64, size=(n,)).tolist() +dim2 = torch.randint(32,128, size=(n,)).tolist() +dim3 = torch.randint(32,256, size=(n,)).tolist() +values = list(product(dim1,dim2,dim3)) +names = ['dim1_{0}_dim2_{1}_dim3_{2}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names) +def test_vector_quant(dim1, dim2, dim3): + dim2 = dim2 - (dim2 % 16) + dim3 = dim3 - (dim3 % 16) + for i in range(k): + A = torch.randn(size=(dim2, dim3), device='cuda') + qA, SA = F.vectorwise_quant(A, dim=0) + A1 = F.vectorwise_dequant(qA, SA) + torch.testing.assert_allclose(A1, A, atol=0.01, rtol=0.1) + + + +n = 2 +dim1 = torch.randint(2,256, size=(n,)).tolist() +dim2 = torch.randint(2,256, size=(n,)).tolist() +dim3 = torch.randint(2,256, size=(n,)).tolist() +#dim1, dim2 = (256,), (256,) +dtype = [torch.int8, torch.int32] +a_order = ['row'] +out_order = ['col', 'row', 'col32'] +transpose = [False] +dims = [2, 3] +values = list(product(dim1,dim2,dim3, dims,dtype, a_order, out_order, transpose)) + +names = ['dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names) +def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): + if dims == 3 and out_order != 'col32': return + if dtype == torch.int32 and out_order != 'col32': return + func = F.get_transform_func(dtype, orderA, orderOut, transpose) + + if dims == 2: + A = torch.randint(-128, 127, size=(dim1, dim2), device='cuda').to(dtype) + elif dims == 3: + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(dtype) + + out, S = F.nvidia_transform(A, to_order=orderOut) + + if orderOut == 'row': + torch.testing.assert_allclose(A.flatten(), out.flatten()) + elif orderOut == 'col': + torch.testing.assert_allclose(A.t().flatten(), out.flatten()) + elif orderOut == 'col32': + if dims == 2: + n = A.shape[0]*(A.shape[1] + (32 - (A.shape[1]%32))) + elif dims == 3: + n = A.shape[0]*A.shape[1]*(A.shape[2] + (32 - (A.shape[2]%32))) + assert out.numel() == n + elif orderOut == 'col_turing': + # 32 col 8 row tiles + n = (A.shape[0]+(8- A.shape[0]%8))*(A.shape[1] + (32 - (A.shape[1]%32))) + assert out.numel() == n + total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0) + for row in range(A.shape[0]): + for col in range(A.shape[1]): + i = row*A.shape[1] + j = col + + coltile = (col // 32) + (1 if col % 32 != 0 else 0) + rowtile = ((row // 8) + (1 if row % 8 != 0 else 0))*total_coltile + offset = 32*8*(rowtile+coltile) + col2 = col % 32 + row2 = (row%8)*32 + + + assert A.flatten()[i+j] == A[row, col] + #assert A.flatten()[i+j] == out.flatten()[row2+col2] + #torch.testing.assert_allclose(A.flatten()[i+j], A[row, col]) + #torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset]) + + if orderOut == 'col32': + out2, S = F.nvidia_transform(out, from_order=orderOut, to_order='row', state=S) + torch.testing.assert_allclose(A, out2) + + +n = 1 +dim1 = torch.randint(1,256, size=(n,)).tolist() +dim2 = torch.randint(32,512, size=(n,)).tolist() +dim3 = torch.randint(32,1024, size=(n,)).tolist() +dim4 = torch.randint(32,1024, size=(n,)).tolist() + +#dim1 = [2] +#dim2 = [2] +#dim3 = [2] +#dim4 = [2] + +dims = (2,3) +ldb = [0] +#ldb = list(range(256, 1*1024, 256)) +values = list(product(dim1,dim2,dim3,dim4,dims, ldb)) +names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names) +def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): + for i in range(k): + if dims == 2: + A = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8) + elif dims == 3: + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8) + B = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8) + C1 = torch.matmul(A.float(), B.t().float()) + + A2, SA = F.transform(A, 'col32') + B2, SB = F.transform(B, 'col_turing') + C2, SC = F.igemmlt(A2, B2, SA, SB) + C3, S = F.nvidia_transform(C2, 'row', state=SC) + torch.testing.assert_allclose(C1, C3.float()) + + # transpose + B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8) + C1 = torch.matmul(A.float(), B.float()) + + B2t, SBt = F.transform(B, 'col_turing', transpose=True) + C2, SC = F.igemmlt(A2, B2t, SA, SBt) + C3, S = F.nvidia_transform(C2, 'row', state=SC) + torch.testing.assert_allclose(C1, C3.float()) + +dim1 = [32] +dim2 = [32] +dim3 = [32] +dim4 = [32] + +dims = (2,) +#ldb = list(range(256, 1*1024, 256)) +values = list(product(dim1,dim2,dim3,dim4,dims)) +names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names) +def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): + formatB = F.get_special_format_str() + for i in range(k): + if dims == 2: + A = torch.normal(0, 0.5, size=(dim1, dim3), device='cuda').half() + elif dims == 3: + A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device='cuda').half() + B = torch.randn((dim4, dim3), device='cuda').half() + torch.nn.init.xavier_uniform_(B) + C1 = torch.matmul(A, B.t()) + C2 = bnb.matmul(A, B.t()) + + A = A.view(-1, A.shape[-1]) + + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B) + C32A, SA = F.transform(CA, 'col32') + CxB, SB = F.transform(CB, to_order=formatB) + out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB) + output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt) + + #print('') + #print(output.flatten()[:10]) + #print(C1.flatten()[:10]) + #print(C2.flatten()[:10]) + + + #torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) + + # transpose + #B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8) + #C1 = torch.matmul(A.float(), B.float()) + + #B2t, SBt = F.transform2(B, 'col_turing', transpose=True) + #C2, SC = F.igemmlt(A2, B2t, SA, SBt) + #C3, S = F.transform(C2, 'row', state=SC) + #torch.testing.assert_allclose(C1, C3.float()) + +batch_size = 2 +seqdim = 512 +#values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)] +values = [(batch_size, seqdim, 4*1024, 3*4*1024),(batch_size, seqdim, 5120, 3*5120),(batch_size, seqdim, 12*1024, 4*12*1024)] + + +#values = list(product(batch, seq, model, hidden)) +names = ['batch_{0}_seq_{1}_model_{2}_hidden_{3}'.format(*vals) for vals in values] +@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) +def test_bench_8bit_training(batch, seq, model, hidden): + formatB = F.get_special_format_str() + A = torch.randn(batch, seq, model, device='cuda').half() + grad = torch.randn(batch, seq, model, device='cuda').half() + w1 = torch.randint(-128, 127, size=(hidden, model), device='cuda').half() + w2 = torch.randint(-128, 127, size=(model, hidden), device='cuda').half() + print('') + + #torch.cuda.synchronize() + ## warmup + #for i in range(100): + # torch.matmul(A, w1.t()) + #torch.cuda.synchronize() + + dtype = torch.int8 + A = A.view(-1, A.shape[-1]).contiguous() + grad = grad.view(-1, grad.shape[-1]).contiguous() + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + + out1 = torch.matmul(A, w1.t()) # fc1 + #out2 = torch.matmul(out1, w2.t())# fc2 + + #d1 = torch.matmul(grad, w2) # delta1 + #d2 = torch.matmul(d1, w1) # delta2 + + #grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2 + #grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1 + + torch.cuda.synchronize() + t16 = time.time() - t0 + print(t16) + + #torch.cuda.empty_cache() + + #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) + #Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) + + #CTw1, Sw1 = F.transform2(Cw1, formatB) + #CTw2, Sw2 = F.transform2(Cw2, formatB) + #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True) + #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True) + + #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + #C32A, SA = F.transform2(CA, 'col32') + ## fc1 + #out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype) + ##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t) + + ## fc2 + #Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1) + #C32out1, Sout1 = F.transform2(Cout1, 'col32') + #out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype) + ##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t) + + ## delta1 + #Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad) + #C32grad, Sgrad = F.transform2(Cgrad, 'col32') + ##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype) + ##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2) + + ## delta2 + #Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1) + #C32d1, Sd1 = F.transform2(Cd1, 'col32') + ##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype) + ##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1) + + ## grad1 + #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True) + #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True) + ##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype) + ##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad) + + ## grad2 + #C32At, SAt = F.transform2(CAt, 'col32', transpose=True) + #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True) + ##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype) + ##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1) + + #Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) + + #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) + #Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) + + #CTw1, Sw1 = F.transform2(Cw1, formatB) + #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True) + #CTw2, Sw2 = F.transform2(Cw2, formatB) + #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(k): + # #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) + # #CTw1, Sw1 = F.transform2(Cw1, formatB) + # #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) + # #CTw1, Sw1 = F.transform2(Cw1, formatB) + + # #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=3.5) + # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + # #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True) + # #CTw2, Sw2 = F.transform2(Cw2, formatB) + # #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True) + + # C32A, SA = F.transform2(CA, 'col32') + + # # fc1 + # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype) + # #out1dn = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) + + # #print(coo_tensor.nnz) + # #out1sp = F.spmm_coo(coo_tensor, w1.t()) + # #print(w1.t().shape) + # #out1 = out1dn + out1sp + + # # fc2 + # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1) + # C32out1, Sout1 = F.transform2(Cout1, 'col32') + # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype) + # #out2 = F.mm_dequant(out2_32, Sout2_32, statsout1, statsw2) + + # # delta1 + # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad) + # C32grad, Sgrad = F.transform2(Cgrad, 'col32') + # d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype) + # #d1 = F.mm_dequant(d1_32, Sd1_32, statsgrad, statsw2t) + + # # delta2 + # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1) + # C32d1, Sd1 = F.transform2(Cd1, 'col32') + # d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype) + # #d2 = F.mm_dequant(d2_32, Sd2_32, statsd1, statsw1t) + + # # grad1 + # #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True) + # #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True) + # #grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype) + # #grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1t, statsgradt) + + # ## grad2 + # #C32At, SAt = F.transform2(CAt, 'col32', transpose=True) + # #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True) + # #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype) + # #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t) + + #torch.cuda.synchronize() + #t8 = time.time() - t0 + #print(t8) + + + + + +n = 2 +dim1 = torch.randint(64,256, size=(n,)).tolist() +dim4 = torch.randint(64,1024, size=(n,)).tolist() + +#dim1 = [2*1024] +#dim4 = [2*1024] + +#dim1 = [4] +#dim4 = [4] + +dims = (2,) +#ldb = list(range(256, 1*1024, 256)) +formatB = ['col_turing', 'col_ampere'] +values = list(product(dim1,dim4,dims, formatB)) +names = ['dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names) +def test_dequant_mm(dim1, dim4, dims, formatB): + inner = torch.randint(1, 128, size=(1,)).item() + formatB = F.get_special_format_str() + for i in range(k): + A = torch.randn(dim1, inner, device='cuda') + B = torch.randn(dim4, inner, device='cuda') + C1 = torch.matmul(A.half(), B.t().half()) + + A1, maxA = F.vectorwise_quant(A, dim=1) + B1, maxB = F.vectorwise_quant(B, dim=1) + + A2, SA = F.nvidia_transform(A1, 'col32') + B2, SB = F.nvidia_transform(B1, formatB) + C2, SC = F.igemmlt(A2, B2, SA, SB) + + C3, S = F.nvidia_transform(C2, 'row', state=SC) + C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t()) + + count = (torch.isclose(C1, C4, atol=0.01, rtol=0.1) == 0).sum().item() + n = C1.numel() + p = 0.06 + assert count/n < p, f'error in more than {p} of elements: {count}/{n}={count/n}' + + C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten()) + torch.testing.assert_allclose(C5, C4) + #print(C2) + + + +n = 2 +dim1 = [1*1024] +dim2 = [1*1024] +#dim1 = torch.randint(1,4*1024, size=(n,)).tolist() +#dim2 = torch.randint(1,4*1024, size=(n,)).tolist() + +dims = (2,) +#ldb = list(range(256, 1*1024, 256)) +values = list(product(dim1,dim2,dims)) +names = ['dim1_{0}_dim2_{1}_dims_{2}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names) +def test_colrow_absmax(dim1, dim2, dims): + for i in range(k): + threshold = 3.0 + A = torch.randn(dim1, dim2, device='cuda').half() + A_truncated = A.clone() + A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0 + if dims == 2: + row_stats1, _ = torch.abs(A.float()).max(1) + col_stats1, _ = torch.abs(A.float()).max(0) + row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1) + col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0) + else: + assert False + + row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold) + + A_blocked = einops.rearrange(torch.abs(A), '(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size', row_tiles=16, block_size=64*4) + nnz_rows1_counts = (torch.abs(A_blocked)>=threshold).sum(3).flatten() + nnz_block_ptr1 = torch.zeros(nnz_rows1_counts.shape[0]+1, dtype=nnz_rows1_counts.dtype, device=nnz_rows1_counts.device) + nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0) + + torch.testing.assert_allclose(col_stats1_trunc, col_stats2) + torch.testing.assert_allclose(row_stats1_trunc, row_stats2) + torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2) + + row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0) + + torch.testing.assert_allclose(col_stats1, col_stats2) + torch.testing.assert_allclose(row_stats1, row_stats2) + assert nnz_block_ptr2 is None + + + +n = 2 +#dim1 = [8*1024] +#dim2 = [4*1024] +dim1 = torch.randint(1,4*1024, size=(n,)).tolist() +dim2 = torch.randint(1,4*1024, size=(n,)).tolist() + +values = list(product(dim1,dim2)) +names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2", values, ids=names) +def test_double_quant(dim1, dim2): + for i in range(k): + A = torch.randn(dim1, dim2, device='cuda').half() + out_col1, Scol = F.vectorwise_quant(A, dim=0) + out_row1, Srow = F.vectorwise_quant(A, dim=1) + + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + + # max difference is 1 due to rounding differences + torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0) + torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0) + + + n = CAt.numel() + num_not_close_rows = (torch.isclose(CA, out_row1, atol=1)==0).sum().item() + num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1)==0).sum().item() + + # allow for 1:500 error due to rounding differences + min_error = 1/500 + if num_not_close_cols > (min_error*n): + print(f'Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}') + assert False + if num_not_close_rows > (min_error*n): + print(f'Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}') + assert False + + torch.testing.assert_allclose(Srow.flatten(), statsA) + torch.testing.assert_allclose(Scol.flatten(), statsAt) + + +n = 4 +dim1 = torch.randint(1,4*1024, size=(n,)).tolist() +dim4 = torch.randint(1,4*1024, size=(n,)).tolist() +inner = torch.randint(1,4*1024, size=(n,)).tolist() + +dim1 = [6] +dim4 = [4] +inner = [8] + +values = list(zip(dim1, dim4, inner)) +names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) +def test_integrated_igemmlt(dim1, dim4, inner): + for i in range(k): + A = torch.randn(dim1, inner, device='cuda').half() + B = torch.randn(dim4, inner, device='cuda').half() + + out1 = torch.matmul(A.half(), B.t().half()) + + C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A) + C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B) + A1, maxA = F.vectorwise_quant(A, dim=1) + B1, maxB = F.vectorwise_quant(B, dim=1) + + torch.testing.assert_allclose(maxA.flatten(), stats1a) + torch.testing.assert_allclose(maxB.flatten(), stats2a) + torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1) + torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1) + + A2, SA = F.nvidia_transform(C1a, 'col32') + B2, SB = F.nvidia_transform(C2a, 'col_turing') + outC32, SC = F.igemmlt(A2, B2, SA, SB) + out2 = F.mm_dequant(outC32, SC, stats1a, stats2a) + + A2, SA = F.nvidia_transform(A1, 'col32') + B2, SB = F.nvidia_transform(B1, 'col_turing') + C2, SC = F.igemmlt(A2, B2, SA, SB) + + C3, S = F.nvidia_transform(C2, 'row', state=SC) + out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t()) + + err1 = torch.abs(out1-out2).mean().item() + err2 = torch.abs(out1-out3).mean().item() + assert err2 <= err1*1.01 + + +n = 6 +dim1 = torch.randint(1,4*1024, size=(n,)).tolist() +dim4 = torch.randint(1,4*1024, size=(n,)).tolist() +inner = torch.randint(1,4*1024, size=(n,)).tolist() + +values = list(zip(dim1, dim4, inner)) +names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) +def test_igemmlt_row_scale(dim1, dim4, inner): + formatB = F.get_special_format_str() + err1, err2, err3 = [], [], [] + relerr1, relerr2 = [], [] + scale = 1 + for i in range(k): + A = torch.randn(dim1, inner, device='cuda').half() + B = torch.randn(dim4, inner, device='cuda').half() + torch.nn.init.xavier_uniform_(B) + C1 = torch.matmul(A, B.t()) + + out1 = torch.matmul(A.half(), B.t().half()) + + + C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A) + CB, absmaxB = F.vectorwise_quant(B, quant_type='linear') + A2, SA = F.nvidia_transform(C1a, 'col32') + B2, SB = F.nvidia_transform(CB, formatB) + A1, maxA = F.vectorwise_quant(A, dim=1) + + c = 10.0*inner*scale + row_scale = torch.ones_like(maxA)/c + outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) + C3, S = F.nvidia_transform(outC32, 'row', state=SC) + maxval = torch.abs(C3).max() + if maxval == 127: + scale = 1.5 + else: + scale = maxval/120 + out3 = C3*maxA*absmaxB*c/(127*127) + + C4 = torch.matmul(C1a.float(), CB.float().t()) + + + C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B) + B2, SB = F.nvidia_transform(C2a, formatB) + outC32, SC = F.igemmlt(A2, B2, SA, SB) + out2 = F.mm_dequant(outC32, SC, stats1a, stats2a) + + CA, SA = F.vectorwise_quant(A, dim=1, quant_type='vector') + CB, SB = F.vectorwise_quant(B, dim=1, quant_type='linear') + + C = torch.matmul(CA.float(), CB.t().float()) + out4 = C*SA*SB/(127*127) + #out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127) + + #print('='*80) + #print(out1) + #print(out2) + #print(out3) + + #print(out1) + #print(out2) + #print(out3) + err1.append(torch.abs(out1-out2).mean().item()) + err2.append(torch.abs(out1-out3).mean().item()) + err3.append(torch.abs(out1-out4).mean().item()) + + #assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10) + print('') + print(sum(err1)/len(err1)) + print(sum(err2)/len(err2)) + print(sum(err3)/len(err3)) + + +dim1 = [1024, 2048] +inner = [12288*4, 4096*4] +dim4 = [12288, 4096] + +values = list(zip(dim1, dim4, inner)) +names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) +def test_row_scale_bench(dim1, dim4, inner): + err1, err2, err3 = [], [], [] + relerr1, relerr2 = [], [] + scale = 1 + A = torch.randn(dim1, inner, device='cuda').half() + B = torch.randn(dim4, inner, device='cuda').half() + torch.nn.init.xavier_uniform_(B) + # warmpup + for i in range(k): + C1 = torch.matmul(A, B.t()) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + C1 = torch.matmul(A, B.t()) + torch.cuda.synchronize() + print('16', time.time()-t0) + + C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A) + CB, absmaxB = F.vectorwise_quant(B, quant_type='linear') + A2, SA = F.nvidia_transform(C1a, 'col32') + B2, SB = F.nvidia_transform(CB, formatB) + A1, maxA = F.vectorwise_quant(A, dim=1) + + c = 10.0*inner*scale + row_scale = maxA/c + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) + torch.cuda.synchronize() + print('row-wise', time.time()-t0) + + + C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B) + B2, SB = F.nvidia_transform(C2a, formatB) + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + outC32, SC = F.igemmlt(A2, B2, SA, SB) + torch.cuda.synchronize() + print('vector-wise', time.time()-t0) + + + + +n = 2 +dim1 = torch.randint(2,1024, size=(n,)).tolist() +dim2 = torch.randint(2,1024, size=(n,)).tolist() +#dim1 = [8*1024] +#dim2 = [4*1024] + +dim3 = [0] +dtype = [torch.int8] +a_order = ['row'] +out_order = ['col32', 'col_turing', 'col_ampere'] +transpose = [False, True] +dims = [2] +values = list(product(dim1,dim2,dim3, dims,dtype, a_order, out_order, transpose)) +names = ['dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names) +def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): + for i in range(k): + if dims == 2: + A = torch.randint(10, 99, size=(dim1, dim2), device='cuda').to(dtype) + elif dims == 3: + A = torch.randint(10, 99, size=(dim1, dim2, dim3), device='cuda').to(dtype) + + A.view(-1)[-1] = -1 + if transpose: + At = A.t().contiguous() + out1, S1 = F.nvidia_transform(At, to_order=orderOut) + else: + out1, S1 = F.nvidia_transform(A, to_order=orderOut) + out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose) + + assert S1[0][0] == S2[0][0] + assert S1[0][1] == S2[0][1] + #print(out1) + #print(out2) + + torch.testing.assert_allclose(out1, out2) + +n = 2 +#dim1 = torch.randint(2,1024, size=(n,)).tolist() +#dim2 = torch.randint(2,1024, size=(n,)).tolist() +dim1 = [1] +dim2 = [33] + +dtype = [torch.int8] +#a_order = ['col_turing', 'col_ampere'] +a_order = ['col_turing'] +out_order = ['row'] +values = list(product(dim1,dim2,dtype, a_order, out_order)) +names = ['dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, dtype, orderA, orderOut", values, ids=names) +def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut): + for i in range(1): + A = torch.randint(-127, 127, size=(dim1, dim2), device='cuda').to(dtype) + + out2, S2 = F.transform(A, to_order=orderA) + A2, S3 = F.transform(out2, from_order=orderA, to_order='row', state=S2) + assert A2.shape[0] == A.shape[0] + assert A2.shape[1] == A.shape[1] + + + print('') + print(A) + print(out2) + print(A2) + + + #torch.testing.assert_allclose(A, A2) + + + + +def test_overflow(): + formatB = F.get_special_format_str() + for i in range(2): + a = torch.arange(5, 15).cuda().to(torch.int8).view(-1,1 ) + b = torch.arange(5, 15).cuda().to(torch.int8).view(-1,1 ) + + Ca, Sa = F.nvidia_transform(a, 'col32') + Cb, Sb = F.nvidia_transform(b, formatB) + + c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8) + c2 = torch.matmul(a.float(), b.float().t()) + + +n = 2 +dim1 = torch.randint(1,4*1024, size=(n,)).tolist() +dim2 = torch.randint(1,4*1024, size=(n,)).tolist() +#dim1 = [4] +#dim2 = [5] + +values = list(product(dim1,dim2)) +names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2", values, ids=names) +def test_coo_double_quant(dim1, dim2): + threshold = 3.00 + for i in range(k): + A = torch.randn(dim1, dim2, device='cuda').half() + + idx = (torch.abs(A) >= threshold) + CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) + + if coo_tensor is not None: + A1 = A*idx + A2 = torch.zeros_like(A) + A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values + torch.testing.assert_allclose(A1, A2) + + A1 = A*(idx==0) + A2 = (CA.float()*statsA.unsqueeze(1)/127).half() + torch.testing.assert_allclose(A*(idx==0), A2, rtol=0.05, atol=1.5e-2) + +n = 2 +dim1 = torch.randint(1,1*1024, size=(n,)).tolist() +dim2 = torch.randint(1,1*1024, size=(n,)).tolist() +#dim1 = [7] +#dim2 = [11] +transposed_B = [False, True] +values = list(product(dim1,dim2, transposed_B)) +names = ['dim1_{0}_dim2_{1}_transposed_B_{2}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names) +def test_spmm_coo(dim1, dim2, transposed_B): + threshold = 1.5 + dim3 = torch.randint(32, 128, size=(1,)).item() + #dim3 = 17 + for i in range(k): + A = torch.randn(dim1, dim2).cuda().half() + if transposed_B: + B = torch.randn(dim3, dim2).cuda().half() + else: + B = torch.randn(dim2, dim3).cuda().half() + + idx = torch.abs(A) >= threshold + nnz = (idx == 1).sum().item() + rows, cols = torch.where(idx) + values = A[idx] + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) + A2 = A*idx + + if transposed_B: + out2 = F.spmm_coo(cooA, B.t()) + out1 = torch.matmul(A2, B.t()) + else: + out2 = F.spmm_coo(cooA, B) + out1 = torch.matmul(A2, B) + + assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30) + + + +def test_spmm_bench(): + batch = 2 + model = 1024*1 + hidden = model*4 + seq = 1024 + dim1 = batch*seq + dim2 = model + dim3 = hidden + threshold = 4 + A = torch.randn(dim1, dim2, device='cuda').half() + B = torch.randn(dim2, dim3, device='cuda').half() for i in range(10): - A1 = torch.randn(1024, 1024, device='cpu') - C, S = F.quantize_blockwise(A1) - A2 = F.dequantize_blockwise(C, S) - diff = torch.abs(A1-A2) - reldiff = diff/torch.abs(A1+1e-8) - diffs.append(diff.mean().item()) - reldiffs.append(reldiff.mean().item()) - assert diffs[-1] < 0.011 - #print(sum(diffs)/len(diffs)) - #print(sum(reldiffs)/len(reldiffs)) + C1 = bnb.matmul(A, B) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + C1 = bnb.matmul(A, B) + torch.cuda.synchronize() + t8 = time.time()-t0 + + idx = torch.abs(A) >= threshold + nnz = (idx == 1).sum().item() + print(nnz/idx.numel()) + rows, cols = torch.where(idx) + values = A[idx] + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) - diffs = [] for i in range(10): - A1 = torch.rand(1024, 1024, device='cpu') - C, S = F.quantize_blockwise(A1) - A2 = F.dequantize_blockwise(C, S) - diff = torch.abs(A1-A2).mean().item() - assert diff < 0.0033 - diffs.append(diff) - torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) - #print(sum(diffs)/len(diffs)) + out2 = F.spmm_coo(cooA, B) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + out2 = F.spmm_coo(cooA, B) + torch.cuda.synchronize() + tsp = time.time()-t0 + print(tsp, t8) + print(tsp/t8) + + +n = 2 +dim1 = torch.randint(256,1*1024, size=(n,)).tolist() +dim2 = torch.randint(256,1*1024, size=(n,)).tolist() +values = list(product(dim1,dim2)) +names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2", values, ids=names) +def test_integrated_sparse_decomp(dim1, dim2): + threshold = 3.0 + formatB = 'col_turing' + for i in range(k): + A = torch.randn(dim1, dim2).cuda().half() + w1 = torch.randn(dim1, dim2).cuda().half() + out1 = torch.matmul(A, w1.t()) + + Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) + CTw1, Sw1 = F.transform(Cw1, formatB) + + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + C32A, SA = F.transform(CA, 'col32') + + out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) + out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) + + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) + C32A, SA = F.transform(CA, 'col32') + + out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) + out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) + + assert coo_tensor is not None + + out4 = F.spmm_coo(coo_tensor, w1.t()) + out5 = out3 + out4 + + err1 = torch.abs(out1-out2).mean().item() + err2 = torch.abs(out1-out5).mean().item() + assert err2 < err1 + + +def test_matmuls(): + a = torch.randn(256, 256).half().cuda() + b = torch.randn(256, 256).half().cuda() + c1 = torch.matmul(a, b) + c2 = bnb.matmul(a, b) + c3 = bnb.matmul(a, b) + + err1 = torch.abs(c1-c2).mean().item() + err2 = torch.abs(c1-c3).mean().item() + assert err1 < 0.2 + assert err2 < 0.2 + + + +n = 2 +#dim1 = torch.randint(1,1*1024, size=(n,)).tolist() +#dim2 = torch.randint(1,4*1024, size=(n,)).tolist() +dim1 = [1*2048] +dim2 = [12288] +#dim1 = [32] +#dim2 = [32] +#dtype = [torch.float16, torch.int8] +dtype = [torch.float16] +out_function = ['zeros', 'ones'] +values = list(product(dim1,dim2, dtype, out_function)) +names = ['dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names) +def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): + out_func = getattr(torch, out_func) + + threshold = 3.3 + #threshold = 2.8 + #threshold = 0.0 + A = torch.randn(dim1, dim2, device='cuda').half() + if dtype == torch.float16: + B = torch.randn(dim2, dim2*4, device='cuda').half() + torch.nn.init.xavier_uniform_(B) + else: + B = torch.randn(dim2, dim2*4, device='cuda').half() + torch.nn.init.xavier_uniform_(B) + B, SB = F.vectorwise_quant(B, quant_type='linear') + #B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8) + + print('') + idx = torch.abs(A) >= threshold + nnz = (idx == 1).sum().item() + rows, cols = torch.where(idx) + values = A[idx] + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) + A2 = A*idx + out1 = torch.matmul(A2.half(), B.half()) + out = out_func(out1.shape, dtype=torch.float16, device=out1.device) + out1 += out.clone() + out2 = F.spmm_coo_very_sparse(cooA, B, out=out) + #print(B) + #print(out1) + #print(out2) + p = 200/(2048*12288*4) + n = out1.numel() + count = math.ceil(p*n) + std = out1.std() + out1 /= std + out2 /= std + assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count) + #assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count) + + idx_col = torch.randint(0, A2.shape[-1], size=(15,)) + + #torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001) + + #Bt = torch.randn(dim2*4, dim2, device='cuda').half() + #torch.cuda.synchronize() + #t0 = time.time() + #print(A2.shape, B.shape) + #for i in range(100): + # #out3 = F.spmm_coo(cooA, Bt.t()) + # #out2 = F.spmm_coo(cooA, B) + # #out2 = F.spmm_coo_very_sparse(cooA, B) + # #out1 = torch.matmul(A, Bt.t()) + + #torch.cuda.synchronize() + #print(time.time() - t0) + +def test_layout(): + a1 = torch.rand(16, 64, device='cuda', dtype=torch.float16) + a1 = torch.arange(16* 64, device='cuda').reshape(16, 64).byte() + a2, s2 = F.transform(a1, 'col_turing') + print(a2.shape) + + print(a1.flatten()[8*64:8*64+32]) + for i in range(4): + print(a2.flatten()[i*8*32:i*8*32+32], 0) + + +def test_coo2csr(): + threshold = 1 + A = torch.randn(128, 128).half().cuda() + idx = torch.abs(A) >= threshold + nnz = (idx == 1).sum().item() + rows, cols = torch.where(idx) + values = A[idx] + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) + A2 = A*idx + csrA = F.coo2csr(cooA) + counts = csrA.rowptr[1:] - csrA.rowptr[:-1] + assert counts.numel() == A.shape[0] + + torch.testing.assert_allclose(counts, (A2!=0).sum(1)) + idx = (A2!=0) + torch.testing.assert_allclose(A2[idx], csrA.values) + + +def test_coo2csc(): + threshold = 1 + A = torch.randn(128, 128).half().cuda() + idx = torch.abs(A) >= threshold + nnz = (idx == 1).sum().item() + rows, cols = torch.where(idx) + values = A[idx] + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) + A2 = A*idx + cscA = F.coo2csc(cooA) + counts = cscA.colptr[1:] - cscA.colptr[:-1] + assert counts.numel() == A.shape[1] + + torch.testing.assert_allclose(counts, (A2!=0).sum(0)) + # torch uses row-major -> use transpose to transfer to col-major + idx = (A2.t()!=0) + torch.testing.assert_allclose(A2.t()[idx], cscA.values) + + + +n = 2 +#dim1 = torch.randint(1,1*1024, size=(n,)).tolist() +#dim2 = torch.randint(1,4*1024, size=(n,)).tolist() +dim1 = [1*2048] +#dim2 = [12288] +dim2 = [2048] +#dim1 = [2] +#dim2 = [2] +dtype = [torch.int8] +values = list(product(dim1,dim2, dtype)) +names = ['dim1_{0}_dim2_{1}_dtype_{2}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names) +def test_spmm_coo_dequant(dim1, dim2, dtype): + threshold = 6.0 + #threshold = 2.8 + #threshold = 0.0 + A = torch.randn(dim1, dim2, device='cuda').half() + B = torch.empty(dim2, dim2*4, device='cuda', dtype=torch.float16) + torch.nn.init.xavier_uniform_(B) + Bt = B.t().contiguous() + + + CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B) + + rowidx = torch.randint(0, A.shape[-1], size=(15,)) + + A[:, rowidx] = 8.0 + + idx = torch.abs(A) >= threshold + nnz = (idx == 1).sum().item() + rows, cols = torch.where(idx) + values = A[idx] + cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) + A2 = A*idx + out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) + out1 = torch.matmul(A2, B.half()) + out3 = F.spmm_coo_very_sparse(cooA, CBt.half()) + out3 = out3*statsBt.half()/127 + + values, counts = torch.unique(cooA.rowidx, return_counts=True) + offset = counts.cumsum(0).int() + max_count, max_idx = torch.sort(counts, descending=True) + print(torch.median(max_count.float())) + + torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001) + + p = 200/(2048*12288*4) + n = out1.numel() + count = math.ceil(p*n) + assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count) + + + + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(100): + # out2 = F.spmm_coo_very_sparse(cooA, B) + #torch.cuda.synchronize() + #print('fp16', time.time() - t0) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out2 = F.spmm_coo(cooA, B) + torch.cuda.synchronize() + print('cusparse fp16', time.time() - t0) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out2 = F.spmm_coo_very_sparse(cooA, CBt) + torch.cuda.synchronize() + print('int8', time.time() - t0) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) + torch.cuda.synchronize() + print('int8+dequant', time.time() - t0) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out2 = torch.matmul(A, B) + torch.cuda.synchronize() + print('matmul', time.time() - t0) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out1 = bnb.matmul(A, Bt) + out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) + out = out1+out2 + torch.cuda.synchronize() + print('sparse+ matmul', time.time() - t0) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out1 = bnb.matmul(A, Bt) + torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1) + torch.cuda.synchronize() + print('partial matmul', time.time() - t0) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out1 = bnb.matmul(A, Bt) + torch.cuda.synchronize() + print('partial matmul', time.time() - t0) + +batch_size = 1 +seqdim = 2048 +values = [] +values.append((batch_size, seqdim, 768, 4*768)) +#values.append((batch_size, seqdim, 1024, 4*1024)) +#values.append((batch_size, seqdim, 1536, 4*1536)) +#values.append((batch_size, seqdim, 2048, 4*2048)) +#values.append((batch_size, seqdim, 2560, 4*2560)) +#values.append((batch_size, seqdim, 4096, 4*4096)) +#values.append((batch_size, seqdim, 5140, 4*5140)) +#values.append((batch_size, seqdim, 12288, 4*12288)) +names = ['batch_{0}_seq_{1}_model_{2}_hidden_{3}'.format(*vals) for vals in values] +@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) +def test_bench_matmul(batch, seq, model, hidden): + formatB = F.get_special_format_str() + + A = torch.randn(batch, seq, model, device='cuda').half() + B = torch.empty(hidden, model, dtype=torch.float16, device='cuda') + torch.nn.init.xavier_uniform_(B) + + linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() + linear8bit.eval() + + outliers = torch.randint(0, model, size=(5,)).cuda() + A[:, :, outliers] = 8.0 + + linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() + linearMixedBit.eval() + + # warmup + for i in range(100): + torch.matmul(A, B.t()) + torch.cuda.synchronize() + print('') + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + torch.matmul(A, B.t()) + torch.cuda.synchronize() + print(f'pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + bnb.matmul(A, B) + torch.cuda.synchronize() + print(f'bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) + C32A, SA = F.transform(CA, 'col32') + CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) + CxB, SB = F.transform(CB, to_order=formatB) + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) + torch.cuda.synchronize() + print(f'igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + + BA, statsB = F.vectorwise_quant(B, dim=1) + CxB, SB = F.nvidia_transform(CB, to_order=formatB) + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + A2 = A.view(-1, A.shape[-1]).contiguous() + CA, statsA = F.vectorwise_quant(A2, dim=1) + C32A, SA = F.nvidia_transform(CA, 'col32') + out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) + Cout, Sout = F.nvidia_transform(out32, 'row', state=Sout32) + F.vectorwise_mm_dequant(Cout, statsA, statsB.t()) + torch.cuda.synchronize() + print(f'vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + + BA, statsB = F.vectorwise_quant(B, dim=1, quant_type='linear') + CxB, SB = F.nvidia_transform(CB, to_order=formatB) + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + A2 = A.view(-1, A.shape[-1]).contiguous() + CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type='linear') + C32A, SA = F.nvidia_transform(CA, 'col32') + out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) + Cout, Sout = F.nvidia_transform(out32, 'row', state=Sout32) + out = Cout*statsB*statsA*(1.0/(127*127)) + torch.cuda.synchronize() + print(f'linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + + linear8bit(A) + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + linear8bit(A) + torch.cuda.synchronize() + print(f'bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + + + linearMixedBit(A) + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + linearMixedBit(A) + torch.cuda.synchronize() + print(f'bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + + +def test_zeropoint(): + def min_max(x): + maxA = torch.amax(x, dim=1, keepdim=True) + minA = torch.amin(x, dim=1, keepdim=True) + midpoint = (maxA-minA)/2.0 + dyna = 252/(maxA-minA) + #dyna *= 0.98 + x = dyna*x + x = x - torch.round((dyna*(minA+midpoint))) + return x.to(torch.int8), minA, midpoint, dyna + batch = 2 + seq = 2 + model = 4 + hidden = 2*model + #batch = 4 + #seq = 2048 + #model = 1024 + #hidden = 8*model + A = torch.randn(batch*seq, model, device='cuda').half()-0.4 + B = torch.nn.Parameter(torch.randn(model, hidden, device='cuda').half()) + + #A[0] = 0 + #B[:, 0] = 0 + #A = A*(A>0) + #A[0, 0] = 0 + #A[0, 0] = 6.0 + + Ac, minA, midpoint, dyna = min_max(A) + #print(Ac[0, 0], 'zero') + #print(Ac, Ac.min(), Ac.max()) + Bc, maxB = F.vectorwise_quant(B, quant_type='linear') + out = F.igemm(Ac, Bc) + out2 = torch.matmul(A,B) + offset = B.sum(0)*torch.round(dyna*(minA+midpoint))/dyna + out = out.float() + #print(out.shape, maxB.shape, scale.shape, offset.shape) + norm1 = maxB/127 + C4 = (out/dyna)*norm1+offset + + + B1 = torch.nn.Parameter(B.clone()) + B2 = torch.nn.Parameter(B.clone()) + B3 = torch.nn.Parameter(B.clone()) + B4 = torch.nn.Parameter(B.clone()) + + + C1 = torch.matmul(A, B1) + C2 = bnb.matmul_cublas(A, B2, None, 'linear') + C3 = bnb.matmul_cublas(A, B3, None, 'zeropoint') + C4 = bnb.matmul_cublas(A, B4, None, 'vector-zeropoint') + + err1 = torch.abs(C1-C2).mean().item() + err2 = torch.abs(C1-C3).mean().item() + err3 = torch.abs(C1-C4).mean().item() + print(err1, err2, err3) + #assert err1 > err2 + + loss1 = C1.mean() + loss2 = C2.mean() + loss3 = C3.mean() + loss4 = C4.mean() + + loss1.backward() + loss2.backward() + loss3.backward() + loss4.backward() + + print(B.grad) + print(B1.grad) + print(B2.grad) + print(B3.grad) + print(B4.grad) + err1 = torch.abs(B1.grad-B2.grad).mean().item() + err2 = torch.abs(B1.grad-B3.grad).mean().item() + err3 = torch.abs(B1.grad-B4.grad).mean().item() + print(err1, err2, err3) + + + + +def test_zp(): + def quant_zp(x): + dtype = x.dtype + x = x.float() + dyna = x.max() - x.min() + if dyna == 0: dyna = 1 + qx = 254./dyna + minx = x.min() + #zpx = torch.round(minx* qx) + #zpx = 127 - torch.round(x.max()* qx) + zpx = torch.round(x.min()* qx) - 127 + x = (qx*x) + zpx + return x, qx, zpx + batch = 2 + seq = 512 + model = 1024 + hidden = 4*model + A = torch.randn(batch*seq, model, device='cuda').half()*0.1 + B = torch.randn(model, hidden, device='cuda').half()*0.1 + + + C0 = torch.matmul(A, B) + + + #A, SA = F.vectorwise_quant(A, quant_type='linear') + #B, SB = F.vectorwise_quant(B, quant_type='linear') + A = A.float() + B = B.float() + + C1 = torch.matmul(A, B) + C3 = bnb.matmul(A.half(), B.t().contiguous().half()) + + zp = 1 + #C2 = torch.matmul(A-zp, B) + #C2 += B.sum(0).view(1, -1)*zp + C2 = torch.matmul(A, B-zp) + C2 -= A.sum(1).view(-1, 1)*zp + + ca, cqa, cza = quant_zp(A) + print(ca.min(), ca.max()) + print((ca-cza).min(), (ca-cza).max()) + + zp = 1 + scale = 2.0 + C5 = torch.matmul((A*scale)-zp, B) + C5 += B.sum(0)*zp + C5 /= scale + + CA, qa, zpa = quant_zp(A) + C4 = torch.matmul(CA, B) + C4 -= B.sum(0)*zpa + C4 /= qa + zpb = 1 + zpa = 1 + qa = 2 + qb = 2 + C6 = torch.matmul((A*qa)+zpa, (B*qb)+zpb) + C6 -= (qb*B.sum(0).view(1, -1)*zpa) + (qa*A.sum(1).view(-1, 1)*zpb) + C6 -= zpa*zpb*A.shape[1] + C6 /= qa*qb -def test_histogram(): - dim1, dim2 = 32, 32 - source = torch.rand(dim1, dim2, device='cuda') - idx1 = torch.randint(0, 255, size=(dim1, dim2), device='cuda').int() - idx2 = torch.randint(0, 255, size=(dim1, dim2), device='cuda').int() - histogram1 = torch.zeros((256, 256)).cuda() - histogram2 = torch.zeros((256, 256)).cuda() + CA, qa, zpa = quant_zp(A) + CB, qb, zpb = quant_zp(B) + C7 = torch.matmul(CA, CB) + C7 -= (qb*B.sum(0).view(1, -1)*zpa) + (qa*A.sum(1).view(-1, 1)*zpb) + C7 -= zpa*zpb*A.shape[1] + C7 /= qa*qb - F.histogram_scatter_add_2d(histogram2, idx1, idx2, source) + print('') + #print(C0.flatten()[:10]) + print(C1.flatten()[:10]) + print(C2.flatten()[:10]) + print(C3.flatten()[:10]) + print(C5.flatten()[:10]) + print(C6.flatten()[:10]) + print(C7.flatten()[:10]) + err1 = torch.abs(C1-C2).mean().item() + err2 = torch.abs(C1-C3).mean().item() + err3 = torch.abs(C1-C4).mean().item() + err4 = torch.abs(C1-C5).mean().item() + err5 = torch.abs(C1-C6).mean().item() + err6 = torch.abs(C1-C7).mean().item() + print(err1, err2, err3, err4, err5, err6) - for i in range(dim1): - for j in range(dim2): - histogram1[idx1[i, j].item(), idx2[i, j].item()] += source[i, j] - torch.testing.assert_allclose(histogram1, histogram2) - torch.testing.assert_allclose(histogram1.sum(), source.sum()) diff --git a/tests/test_modules.py b/tests/test_modules.py index a0379cb..a2c950b 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -1,42 +1,470 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. import pytest import torch + +from itertools import product +from torch import nn + import bitsandbytes as bnb +class MockArgs(object): + def __init__(self, initial_data): + for key in initial_data: + setattr(self, key, initial_data[key]) + +class MLP8bit(torch.nn.Module): + def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0): + super(MLP8bit, self).__init__() + self.fc1 = bnb.nn.Linear8bitLt(dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold) + self.fc2 = bnb.nn.Linear8bitLt(dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + +def get_args(): + args = MockArgs([]) + args.quant_type = 'vector' + args.use_8bit_training = 'full' + args.clip_freq = 9999 + return args + +def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10): + idx = torch.isclose(a, b, rtol, atol) + sumval = (idx==0).sum().item() + if sumval > count: + print(f'Too many values not close: assert {sumval} < {count}') + torch.testing.assert_allclose(a, b, rtol, atol) + +class LinearFunction(torch.autograd.Function): + + @staticmethod + def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0): + round_func = LinearFunction.round_stoachastic if stochastic else torch.round + norm = math.sqrt(math.pi)/math.sqrt(2.0) + #std = torch.abs(x).mean()*norm + std = torch.std(x) + max1 = std*trim_value + x = x/max1*127 + x = round_func(x) + x[x > 127] = 127 + x[x < -127] = -127 + x = x/127*max1 + + return x + + def quant(x, quant_type, dim=1): + if quant_type == 'linear': + max1 = torch.abs(x).max().float() + xq = torch.round(x/max1*127).to(torch.int8) + return xq, max1 + elif quant_type == 'vector': + max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) + xq = torch.round(x/max1*127).to(torch.int8) + return xq, max1 + elif quant_type == 'min-max': + maxA = torch.amax(x, dim=dim, keepdim=True).float() + minA = torch.amin(x, dim=dim, keepdim=True).float() + scale = (maxA-minA)/2.0 + xq = torch.round(127*(x-minA-scale)/scale).to(torch.int8) + return xq, (minA.float(), scale.float()) + else: return None + + def dequant(xq, S1, S2, dtype, quant_type): + if quant_type == 'linear': + norm = S1*S2/(127*127) + # double cast needed to prevent overflows + return (xq.float()*norm).to(dtype) + elif quant_type == 'vector': + x = xq.float() + if len(xq.shape) == 2 and len(S1.shape) == 3: S1 = S1.squeeze(0) + if len(xq.shape) == 2 and len(S2.shape) == 3: S2 = S2.squeeze(0) + #print(x.shape, S1.shape, S2.shape) + if len(S1.shape) == 2: + x *= S1.t()/127 + else: + x *= S1/127 + x *= S2/127 + return x.to(dtype) + else: return None + + def dequant_min_max(xq, A, B, SA, SB, dtype): + offset = B.float().t().sum(0)*(SA[0]+SA[1]) + x = xq.float() + if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0) + if len(xq.shape) == 2 and len(SA.shape) == 3: SA = SA.squeeze(0) + if len(SB.shape) == 2: + x *= SB.t()/127 + else: + x *= SB/127 + x *= SA[1]/127 + x +=offset + return x.to(dtype) + + + def get_8bit_linear(x, stochastic=False): + round_func = LinearFunction.round_stoachastic if stochastic else torch.round + max1 = torch.abs(x).max() + x = x/max1*127 + x = round_func(x)/127*max1 + #x = torch.round(x)/128*max1 + return x + + @staticmethod + def get_8bit_vector_wise(x, dim, stochastic=False): + round_func = LinearFunction.round_stoachastic if stochastic else torch.round + max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) + max1[max1==0] = 1.0 + x = (x*127)/max1 + x = round_func(x)/127*max1 + return x + + @staticmethod + def round_stoachastic(x): + sign = torch.sign(x) + absx = torch.abs(x) + decimal = absx-torch.floor(absx) + rdm = torch.rand_like(decimal) + return sign*(torch.floor(absx)+(rdm < decimal).to(x.dtype)) + + @staticmethod + def fake_8bit_storage(w, exponent_bits): + code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device) + absmax, C = bnb.functional.quantize_blockwise(w.data, code=code) + out = bnb.functional.dequantize_blockwise(absmax, C, code) + out = out.half() + w.copy_(out) + return out + + @staticmethod + def fake_8bit_storage_quantile(w, args): + code = bnb.functional.estimate_quantiles(w.data, offset=args.offset) + #C = bnb.functional.quantize_no_absmax(code, w) + #out = bnb.functional.dequantize_no_absmax(code, C, out=w.data) + #print(out) + #out = out.half() + code /= torch.max(torch.abs(code)) + absmax, C = bnb.functional.quantize_blockwise(w.data, code=code) + out = bnb.functional.dequantize_blockwise(absmax, C, code) + out = out.half() + w.copy_(out) + return out + + @staticmethod + def fake_8bit_storage_stoachstic(w): + rand = torch.rand(1024, device=w.device) + absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand) + out = bnb.functional.dequantize_blockwise(absmax, C) + out = out.half() + w.copy_(out) + return out + + @staticmethod + def fake_8bit_storage_with_max(w, topk=8): + blocked_w = einops.rearrange(w.flatten(), '(h b) -> h b', b=256) + max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True) + idx = idx[:, :topk] + max_val = max_val[:, :topk] + + mask = torch.zeros_like(blocked_w) + mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val)) + mask = mask.bool() + + # 1. zero out max values + # 2. quantize + dequantize + # 3. write back max values + # 4. copy matrix back to weight + + values = blocked_w[mask] + blocked_w[mask] = 0 + + code = bnb.functional.create_dynamic_map() + code = code.to(w.device) + absmax, C = bnb.functional.quantize_blockwise(blocked_w.data) + bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w) + + blocked_w[mask] = values + + unblocked_w = blocked_w.flatten().view(w.shape) + + w.copy_(unblocked_w) + return unblocked_w + + + @staticmethod + def forward(ctx, x, weight, bias=None, args=None): + if args.use_8bit_training != 'off': + weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1) + x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2) + outputq = bnb.functional.igemm(x8, weight8.t()) + output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type) + #if torch.rand(1) < 0.01: + #output32 = torch.matmul(x, weight.t()) + #err = torch.abs(output-output32).float() + #relerr = err/(torch.abs(output32).float()+1e-8) + #print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy) + else: + #output = torch.matmul(x, weight.t()) + output = torch.einsum('bsi,oi->bso', x, weight) + + ctx.save_for_backward(x, weight, bias) + ctx.args = args + + if bias is not None: + output += bias.unsqueeze(0).expand_as(output) + return output + + @staticmethod + def backward(ctx, grad_output): + x, weight, bias = ctx.saved_tensors + args = ctx.args + stochastic = False + grad_input = grad_weight = grad_bias = None + if bias is not None and ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0) + + # weight and x are already 8bit + # -> transform grad_output to 8-bit + if args.use_8bit_training == 'forward+wgrad': + grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1]) + x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) + grad_weight8 = bnb.functional.igemm(grad_output8, x8) + grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type) + + #grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x) + + grad_input = grad_output.matmul(weight) + elif args.use_8bit_training == 'full': + grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1]) + x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) + grad_weight8 = torch.zeros_like(weight, dtype=torch.int32) + bnb.functional.igemm(grad_output8, x8, out=grad_weight8) + grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type) + + grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2) + weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0) + grad_input8 = bnb.functional.igemm(grad_output8, weight8) + grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type) + + else: + grad_input = grad_output.matmul(weight) + grad_weight = torch.einsum('bsi,bso->oi', x, grad_output) -@pytest.mark.parametrize("embcls", [bnb.nn.Embedding, bnb.nn.StableEmbedding], ids=['Embedding', 'StableEmbedding']) -def test_embeddings(embcls): - bnb.optim.GlobalOptimManager.get_instance().initialize() - emb1 = torch.nn.Embedding(100, 512).cuda() - emb2 = embcls(100, 512).cuda() + return grad_input, grad_weight, grad_bias, None - adam1 = bnb.optim.Adam8bit(emb1.parameters()) - adam2 = bnb.optim.Adam8bit(emb2.parameters()) +class Linear8bit(nn.Module): + def __init__(self, input_features, output_features, bias=True, args=None): + super(Linear8bit, self).__init__() + self.input_features = input_features + self.output_features = output_features + self.args = args - batches = torch.randint(1, 100, size=(100, 4, 32)).cuda() + self.weight = nn.Parameter(torch.empty(output_features, input_features)) + if bias: + self.bias = nn.Parameter(torch.empty(output_features)) + else: + self.register_parameter('bias', None) + torch.nn.init.xavier_uniform_(self.weight) + if self.bias is not None: + torch.nn.init.zeros_(self.bias) + + def forward(self, x): + self.args.training = self.training + + return LinearFunction.apply(x, self.weight, self.bias, self.args) + + + +def test_linear8bit(): + l0 = torch.nn.Linear(32, 64).cuda().half() + l1 = bnb.nn.Linear8bit(32,64, args=get_args()).cuda().half() + l2 = Linear8bit(32, 64, args=get_args()).cuda().half() + l3 = bnb.nn.Linear8bitLt(32,64).cuda().half() + + l0.weight.data = l2.weight.data.clone() + l0.bias.data = l2.bias.data.clone() + + l1.weight.data = l2.weight.data.clone() + l1.bias.data = l2.bias.data.clone() + + l3.weight.data = l2.weight.data.clone() + l3.bias.data = l2.bias.data.clone() + + for i in range(100): + b1 = torch.randn(16, 8, 32, device='cuda').half() + t = torch.randn(16, 8, 64, device='cuda').half() + b2 = b1.clone() + b3 = b1.clone() + b0 = b1.clone() + + o0 = l0(b0) + o1 = l1(b1) + o2 = l2(b2) + o3 = l3(b3) + + assert_all_approx_close(o1, o2, atol=0.013, rtol=0.05, count=1) + assert_all_approx_close(o3, o2, atol=0.013, rtol=0.05, count=1) + + loss0 = torch.nn.functional.mse_loss(o0, t) + loss1 = torch.nn.functional.mse_loss(o1, t) + loss2 = torch.nn.functional.mse_loss(o2, t) + loss3 = torch.nn.functional.mse_loss(o3, t) + + loss0.backward() + loss1.backward() + loss2.backward() + loss3.backward() + + assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2) + assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2) + assert_all_approx_close(l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2) + assert_all_approx_close(l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2) + + err1 = torch.abs(l0.weight.grad-l1.weight.grad).mean().item() + err2 = torch.abs(l0.weight.grad-l2.weight.grad).mean().item() + err3 = torch.abs(l0.weight.grad-l3.weight.grad).mean().item() + + assert err1*0.8 < err2 + assert err2*0.8 < err3 + assert err3*0.8 < err1 + + l0.weight.grad = None + l1.weight.grad = None + l2.weight.grad = None + l3.weight.grad = None + l0.bias.grad = None + l1.bias.grad = None + l2.bias.grad = None + l3.bias.grad = None + + +threshold = [0.0, 3.0] +values = threshold +names = ['threshold_{0}'.format(vals) for vals in values] +@pytest.mark.parametrize("threshold", values, ids=names) +def test_linear8bitlt_inference(threshold): + l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold).cuda().half() + assert l1.weight.device.type == 'cuda' + assert l1.weight.dtype == torch.float16 + + l1.eval() for i in range(100): - batch = batches[i] + b1 = torch.randn(16, 8, 32, device='cuda').half() + o1 = l1(b1) + if i == 1: + assert l1.state.CxB is not None + +def test_linear8bitlt_accumulated_gradient(): + l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32,32).cuda().half() for i in range(2)]) + l2 = torch.nn.Sequential(*[torch.nn.Linear(32,32).cuda().half() for i in range(2)]) + l2[0].weight = torch.nn.Parameter(l1[0].weight.clone()) + l2[0].bias = torch.nn.Parameter(l1[0].bias.clone()) + l2[1].weight = torch.nn.Parameter(l1[1].weight.clone()) + l2[1].bias = torch.nn.Parameter(l1[1].bias.clone()) + opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001) + opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001) + + acc_steps = 10 + - embedded1 = emb1(batch) - embedded2 = emb2(batch) + for i in range(10): + b1 = torch.randn(16, 8, 32, device='cuda').half() + o1 = l1(b1) + o2 = l2(b1) + loss1 = o1.mean() + loss2 = o2.mean() + loss1.backward() + loss2.backward() + if i == 2: + assert l1[0].state.CxB is not None + assert l1[1].state.CxB is not None - l1 = embedded1.mean() - l2 = embedded2.mean() + if i > 0 and i % acc_steps == 0: + opt1.step() + opt1.zero_grad(True) + opt2.step() + opt2.zero_grad(True) + assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2) + assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2) + # we do this copy because otherwise we have small divergences over time that add up + l1[0].weight.data.copy_(l2[0].weight.data) + l1[1].weight.data.copy_(l2[1].weight.data) + else: + torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad) + torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad) - l1.backward() - l2.backward() - adam1.step() - adam2.step() +threshold = [0.0, 2.0] +values = threshold +names = ['threshold_{0}'.format(vals) for vals in values] +@pytest.mark.parametrize("threshold", values, ids=names) +def test_linear8bitlt_no_fp16_weights(threshold): + l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold, has_fp16_weights=False).cuda().half() + assert l1.weight.dtype == torch.int8 - adam1.zero_grad() - adam2.zero_grad() + l1.eval() + for i in range(100): + b1 = torch.randn(16, 8, 32, device='cuda').half() + o1 = l1(b1) + assert o1.dtype == torch.float16 + + mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda() + assert mlp.fc1.weight.dtype == torch.int8 + assert mlp.fc2.weight.dtype == torch.int8 - assert adam1.state[emb1.weight]['state1'].dtype == torch.uint8 - assert adam2.state[emb2.weight]['state1'].dtype == torch.float32 + for i in range(100): + b1 = torch.randn(16, 8, 32, device='cuda').half() + o1 = mlp(b1) + assert o1.dtype == torch.float16 + if threshold > 0: assert mlp.fc1.state.idx is not None + if threshold > 0: assert mlp.fc2.state.idx is not None + mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half() + assert mlp.fc1.weight.dtype == torch.int8 + assert mlp.fc2.weight.dtype == torch.int8 + + for i in range(100): + b1 = torch.randn(16, 8, 32, device='cuda').half() + o1 = mlp(b1) + assert o1.dtype == torch.float16 + if threshold > 0: assert mlp.fc1.state.idx is not None + if threshold > 0: assert mlp.fc2.state.idx is not None + mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda() + + for i in range(100): + b1 = torch.randn(16, 8, 32, device='cuda').half() + o1 = mlp(b1) + assert o1.dtype == torch.float16 + if threshold > 0: assert mlp.fc1.state.idx is not None + if threshold > 0: assert mlp.fc2.state.idx is not None + assert mlp.fc1.weight.dtype == torch.int8 + assert mlp.fc2.weight.dtype == torch.int8 + + + mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to('cuda') + + for i in range(100): + b1 = torch.randn(16, 8, 32, device='cuda').half() + o1 = mlp(b1) + assert o1.dtype == torch.float16 + if threshold > 0: assert mlp.fc1.state.idx is not None + if threshold > 0: assert mlp.fc2.state.idx is not None + assert mlp.fc1.weight.dtype == torch.int8 + assert mlp.fc2.weight.dtype == torch.int8 + assert mlp.fc1.weight.device.type == 'cuda' + assert mlp.fc2.weight.device.type == 'cuda' + + mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(torch.float16).to('cuda') + + for i in range(100): + b1 = torch.randn(16, 8, 32, device='cuda').half() + o1 = mlp(b1) + assert o1.dtype == torch.float16 + if threshold > 0: assert mlp.fc1.state.idx is not None + if threshold > 0: assert mlp.fc2.state.idx is not None + assert mlp.fc1.weight.dtype == torch.int8 + assert mlp.fc2.weight.dtype == torch.int8 + assert mlp.fc1.weight.device.type == 'cuda' + assert mlp.fc2.weight.device.type == 'cuda' diff --git a/tests/test_optim.py b/tests/test_optim.py index c80fe51..b173eaa 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -1,12 +1,9 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. import os import time import shutil import uuid import pytest +import ctypes import torch import bitsandbytes as bnb import bitsandbytes.functional as F @@ -14,7 +11,9 @@ import bitsandbytes.functional as F from os.path import join from itertools import product -import apex +#import apex + +k = 20 def get_temp_dir(): path = '/tmp/autoswap/{0}'.format(str(uuid.uuid4())) @@ -26,55 +25,47 @@ def rm_path(path): str2optimizers = {} str2optimizers['adam_pytorch'] = (None, torch.optim.Adam, bnb.optim.Adam) -str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam) -str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam) +#str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam) +#str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam) str2optimizers['momentum_pytorch'] = (None, lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), bnb.optim.Adam) -str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam) -str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam) +#str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam) +#str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam) str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam) -str2optimizers['adamw'] = (torch.optim.AdamW, bnb.optim.AdamW) -str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam) +#str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam) str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False)) str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9)) -str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB) +#str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB) str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False)) -str2optimizers['adagrad'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad(pxx, 0.01, block_wise=False)) str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False)) str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False)) str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False)) -str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit) +#str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit) str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9)) str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True)) -str2optimizers['adamw8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.AdamW8bit(pxx, block_wise=True)) str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True)) str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True)) -str2optimizers['adagrad8bit_blockwise'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad8bit(pxx, 0.01, block_wise=True)) str2statenames = {} str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')] -str2statenames['adamw'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')] str2statenames['momentum'] = [('momentum_buffer', 'state1')] str2statenames['lars'] = [('momentum_buffer', 'state1')] str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')] str2statenames['rmsprop'] = [('square_avg', 'state1')] -str2statenames['adagrad'] = [('sum', 'state1')] str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')] str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')] str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')] -str2statenames['adamw8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')] str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')] str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')] str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')] str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')] str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')] -str2statenames['adagrad8bit_blockwise'] = [('sum', 'state1', 'qmap1', 'absmax1')] dim1 = [1024] dim2 = [32, 1024, 4097, 1] gtype = [torch.float32, torch.float16] -optimizer_names = ['adam', 'adamw', 'momentum', 'rmsprop', 'lars', 'lamb', 'adagrad'] +optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb'] values = list(product(dim1,dim2, gtype, optimizer_names)) names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @@ -89,12 +80,12 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): bnb_optimizer = str2optimizers[optim_name][1]([p2]) if gtype == torch.float32: - atol, rtol = 2e-6, 1e-5 + atol, rtol = 1e-6, 1e-5 else: atol, rtol = 1e-4, 1e-3 - for i in range(50): + for i in range(k): g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01 p1.grad = g.clone().float() p2.grad = g.clone() @@ -107,7 +98,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol) - if i % 10 == 0 and i > 0: + if i % (k//5) == 0 and i > 0: path = get_temp_dir() torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt')) del bnb_optimizer @@ -148,7 +139,6 @@ def test_global_config(dim1, dim2, gtype): eps = 1e-8 bnb.optim.GlobalOptimManager.get_instance().initialize() - bnb.optim.GlobalOptimManager.get_instance().override_config(p2, 'skip_zeros', True) bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8) bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3]) @@ -163,8 +153,6 @@ def test_global_config(dim1, dim2, gtype): else: atol, rtol = 1e-4, 1e-3 - original_p2 = p2[mask].clone() - for i in range(50): g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 @@ -173,38 +161,17 @@ def test_global_config(dim1, dim2, gtype): p2.grad = g2 p3.grad = g3 - if i > 30 and i % 10 == 0: - g1.data[mask] = 0.0 - g2.data[mask] = 0.0 - p1.grad = g1 - p2.grad = g2 - original_p1 = p1[mask].clone() - original_p2 = p2[mask].clone() - og_s1 = adam2.state[p2]['state1'][mask].clone() - og_s2 = adam2.state[p2]['state2'][mask].clone() - og_s11 = adam2.state[p1]['state1'][mask].clone() - og_s21 = adam2.state[p1]['state2'][mask].clone() - adam2.step() assert adam2.state[p3]['state1'].dtype == torch.uint8 assert adam2.state[p3]['state2'].dtype == torch.uint8 - if i > 30 and i % 10 == 0: - torch.testing.assert_allclose(original_p2, p2[mask]) - torch.testing.assert_allclose(adam2.state[p2]['state1'][mask], og_s1) - torch.testing.assert_allclose(adam2.state[p2]['state2'][mask], og_s2) - assert ((p1[mask]- original_p1)==0.0).sum() < p1.numel() - assert ((adam2.state[p1]['state1'][mask]- og_s11)==0.0).sum() == 0.0 - assert ((adam2.state[p1]['state2'][mask]- og_s21)==0.0).sum() == 0.0 - - dim1 = [1024] dim2 = [32, 1024, 4097] gtype = [torch.float32, torch.float16] -optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'adamw8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise', 'adagrad8bit_blockwise'] +optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise'] values = list(product(dim1,dim2, gtype, optimizer_names)) names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @@ -370,13 +337,12 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): if dim1 == 1 and dim2 == 1: return p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 - bnb_optimizer = str2optimizers[optim_name][1]([p1]) g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01 p1.grad = g - for i in range(5000): - if i == 500: + for i in range(k): + if i == k//5: # 100 iterations for burn-in torch.cuda.synchronize() t0 = time.time() @@ -386,23 +352,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): torch.cuda.synchronize() s = time.time()-t0 print('') - params = 4500*4096*4096 + params = (k-k//5)*dim1*dim2 print(optim_name, gtype, s/params) #assert s < 3.9 - -def test_str_betas(): - betas = (0.80, 0.95) - strbetas = '(0.80, 0.95)' - - layer = torch.nn.Linear(10, 10) - - base = bnb.optim.Adam(layer.parameters(), betas=betas) - strbase = bnb.optim.Adam(layer.parameters(), betas=strbetas) - assert base.defaults['betas'][0] == 0.8 - assert base.defaults['betas'][1] == 0.95 - assert strbase.defaults['betas'][0] == 0.8 - assert strbase.defaults['betas'][1] == 0.95 - - -- cgit v1.2.3