author    Tim Dettmers <tim.dettmers@gmail.com>    2022-07-22 14:41:05 -0700
committer Tim Dettmers <tim.dettmers@gmail.com>    2022-07-22 14:41:05 -0700
commit    c771b3a75a6ebbfbfc398a028a477246b0799cf0 (patch)
tree      158353d531766ed133be34d3c5085da6e8a4d01e
parent    4cd7ea62b2f51c68aacde2f62e7141765e476111 (diff)
Most tests passing.
-rw-r--r--  bitsandbytes/__init__.py              3
-rw-r--r--  bitsandbytes/autograd/__init__.py     0
-rw-r--r--  bitsandbytes/autograd/_functions.py   307
-rw-r--r--  bitsandbytes/cextension.py            2
-rw-r--r--  bitsandbytes/functional.py            869
-rw-r--r--  bitsandbytes/nn/__init__.py           2
-rw-r--r--  bitsandbytes/nn/modules.py            124
-rw-r--r--  csrc/kernels.cu                       874
-rw-r--r--  csrc/kernels.cuh                      12
-rw-r--r--  csrc/ops.cu                           406
-rw-r--r--  csrc/ops.cuh                          104
-rw-r--r--  csrc/pythonInterface.c                127
-rw-r--r--  tests/test_autograd.py                270
-rw-r--r--  tests/test_functional.py              1763
-rw-r--r--  tests/test_modules.py                 478
-rw-r--r--  tests/test_optim.py                   87
16 files changed, 5269 insertions(+), 159 deletions(-)
diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py
index 02ca804..3c3affa 100644
--- a/bitsandbytes/__init__.py
+++ b/bitsandbytes/__init__.py
@@ -4,12 +4,13 @@
# LICENSE file in the root directory of this source tree.
from .nn import modules
+from .autograd._functions import mm_cublas, bmm_cublas, matmul_cublas, matmul, MatmulLtState
from .cextension import COMPILED_WITH_CUDA
if COMPILED_WITH_CUDA:
from .optim import adam
-__pdoc__ = {'libBitsNBytes': False,
+__pdoc__ = {'libbitsandbytes': False,
'optim.optimizer.Optimizer8bit': False,
'optim.optimizer.MockArgs': False
}
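
A minimal usage sketch of the API this hunk exports (illustrative only, not part of the diff; assumes a CUDA device and the fp16 inputs that matmul requires):

    import torch
    import bitsandbytes as bnb

    A = torch.randn(4, 8, dtype=torch.float16, device='cuda', requires_grad=True)
    B = torch.randn(16, 8, dtype=torch.float16, device='cuda')  # matmul computes A @ B^T
    out = bnb.matmul(A, B, threshold=6.0)  # (4, 16) fp16 output
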
diff --git a/bitsandbytes/autograd/__init__.py b/bitsandbytes/autograd/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/bitsandbytes/autograd/__init__.py
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
new file mode 100644
index 0000000..815a4f1
--- /dev/null
+++ b/bitsandbytes/autograd/_functions.py
@@ -0,0 +1,307 @@
+import torch
+import bitsandbytes as bnb
+import bitsandbytes.functional as F
+
+from dataclasses import dataclass
+
+tensor = torch.Tensor
+
+class GlobalOutlierPooler(object):
+ '''
+ This class pools outlier dimensions across layers.
+ This is particularly important for small models where outlier features
+ are less systematic and occur with low frequency.
+ '''
+ _instance = None
+
+ def __init__(self):
+ raise RuntimeError('Call get_instance() instead')
+
+ def initialize(self):
+ self.outliers = set()
+ self.model_dim = None
+
+ @classmethod
+ def get_instance(cls):
+ if cls._instance is None:
+ cls._instance = cls.__new__(cls)
+ cls._instance.initialize()
+ return cls._instance
+
+ def add_outliers(self, outlier_idx, feature_dim):
+ if self.model_dim is None: self.model_dim = feature_dim
+ if feature_dim != self.model_dim: return # we do not encode outliers for the 2nd FFN layer
+
+ self.outliers.update(outlier_idx.tolist())
+
+ def get_current_outlier_idx(self):
+ return torch.Tensor(list(self.outliers)).to(torch.int64)
+
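+# Usage sketch (illustrative, not part of this commit): the pooler is a
+# process-wide singleton, so all layers share one outlier set:
+#   pool = GlobalOutlierPooler.get_instance()
+#   pool.add_outliers(torch.tensor([3, 17]), feature_dim=1024)
+#   idx = pool.get_current_outlier_idx()  # int64 tensor of pooled indices
+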
+class MatMul8bit(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, A, B, out=None, quant_type='vector', precision=[8, 8, 8]):
+
+ if precision[0] != 8:
+ with torch.no_grad():
+ output = torch.matmul(A, B)
+ else:
+ if len(B.shape) == 2: dim = 0
+ else: dim = 1
+ qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
+ qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
+ iout = F.igemm(qA, qB)
+ output = F.vectorwise_mm_dequant(iout, SA, SB, A.dtype, quant_type)
+
+ if A.requires_grad or B.requires_grad:
+ ctx.save_for_backward(A, B)
+
+ ctx.quant_type = quant_type
+ ctx.precision = precision
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ A, B = ctx.saved_tensors
+ quant_type = ctx.quant_type
+ precision = ctx.precision
+ grad_A = grad_B = None
+
+ if B.requires_grad:
+ if len(A.shape) == 3:
+ dims = [0, 1]
+ # bsi -> ibs
+ permute_dim = [0, 2, 1]
+ else:
+ dims = [0]
+ # bs -> sb
+ permute_dim = [1, 0]
+
+ if precision[1] != 8:
+ with torch.no_grad():
+ grad_B = torch.matmul(A.permute(permute_dim), grad_output)
+ else:
+ if len(B.shape) == 2 and len(A.shape) == 3:
+ if not grad_output.is_contiguous(): grad_output = grad_output.contiguous()
+ qgrad_output, S1 = F.vectorwise_quant(grad_output.view(-1, grad_output.shape[2]), dim=0, quant_type=quant_type)
+ if not A.is_contiguous(): A = A.contiguous()
+ qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
+ igrad_B = F.igemm(qA.t(), qgrad_output)
+ grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
+ else:
+ qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
+ qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
+ igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
+ grad_B = F.vectorwise_mm_dequant(igrad_B, S2.permute(permute_dim), S1, grad_output.dtype, quant_type)
+
+ if A.requires_grad:
+ if len(grad_output.shape) == 3: dims = [2]
+ else: dims = [1]
+
+ if len(B.shape) == 3:
+ # bio -> boi
+ permute_dim = [0, 2, 1]
+ dim_B = dims
+ else:
+ # io -> oi
+ permute_dim = [1, 0]
+ dim_B = [1]
+
+ if precision[2] != 8:
+ with torch.no_grad():
+ grad_A = torch.matmul(grad_output, B.permute(permute_dim))
+ else:
+ qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
+ qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
+ igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
+ grad_A = F.vectorwise_mm_dequant(igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type)
+
+ return grad_A, grad_B, None, None, None
+
+
+mm_cublas = MatMul8bit.apply
+bmm_cublas = MatMul8bit.apply
+matmul_cublas = MatMul8bit.apply
+
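+# Usage sketch (illustrative): by default both operands are quantized
+# vector-wise; passing precision=[16, 8, 8] skips quantization in the forward
+# pass (it falls back to torch.matmul) while both backward GEMMs stay 8-bit:
+#   C = matmul_cublas(A, B, None, 'vector', [16, 8, 8])
+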
+@dataclass
+class MatmulLtState:
+ CB = None
+ CxB = None
+ SB = None
+ SCB = None
+
+ CxBt = None
+ SBt = None
+ CBt = None
+
+ subB = None
+
+ outlier_pool = None
+ has_accumulated_gradients = False
+ threshold = 0.0
+ idx = None
+ is_training = True
+ has_fp16_weights = True
+ use_pool = False
+ formatB = F.get_special_format_str()
+
+ def reset_grads(self):
+ self.CB = None
+ self.CxB = None
+ self.SB = None
+ self.SCB = None
+
+ self.CxBt = None
+ self.SBt = None
+ self.CBt = None
+
+
+class MatMul8bitLt(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, A, B, out=None, state=MatmulLtState()):
+ # 1. Quantize A
+ # 2. Quantize B
+ # 3. Matmul
+ # 4. Mixed-precision decomposition matmul
+ # 5. Save state
+ requires_gradA = A.requires_grad
+ requires_gradB = B.requires_grad
+ formatB = state.formatB
+ input_shape = A.shape
+ if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance()
+ assert A.dtype == torch.float16, f'The input data type needs to be fp16 but {A.dtype} was found!'
+
+ # 1. Quantize A
+ if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous()
+ CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)
+
+ if state.threshold > 0.0 and coo_tensorA is not None:
+ if state.has_fp16_weights:
+ idx = torch.unique(coo_tensorA.colidx).long()
+ CA[:, idx] = 0
+ CAt[:, idx] = 0
+ subA = A[:, idx]
+ state.subB = B[:, idx].t().contiguous()
+ state.idx = idx
+ else:
+ if state.CxB is None:
+ # B is in 8-bit row-major; we can transform it back to 16-bit to extract outlier dimensions
+ # we also need to convert it to the turing/ampere format
+ state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+ if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
+ # generate outlier index and subB
+ outlier_idx = torch.unique(coo_tensorA.colidx).long()
+ state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
+ if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
+ # do not use pool for 2nd FFN layer
+ state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
+ else:
+ state.idx = outlier_idx
+ state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()
+
+ if state.idx is not None:
+ # extract outliers
+ CA[:, state.idx] = 0
+ CAt[:, state.idx] = 0
+ subA = A[:, state.idx]
+ else:
+ subA = None
+ else:
+ if not state.has_fp16_weights and state.CxB is None:
+ state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+ subA = None
+
+ C32A, SA = F.transform(CA, 'col32')
+
+ # 2. Quantize B
+ if state.has_fp16_weights:
+ has_grad = getattr(B, 'grad', None) is not None
+ is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
+ if is_transposed: B = B.contiguous()
+
+ if (state.is_training and not has_grad) or state.CxB is None:
+ state.reset_grads()
+ CB, state.CBt, state.SCB, state.SCBt, coo_tensorB = F.double_quant(B)
+ state.CxB, state.SB = F.transform(CB, to_order=formatB)
+ else:
+ has_grad = False
+
+ shapeB = state.SB[0]
+
+ if len(input_shape) == 3:
+ output_shape = (input_shape[0], input_shape[1], shapeB[0])
+ else:
+ output_shape = (input_shape[0], shapeB[0])
+
+ # 3. Matmul
+ out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
+ output = F.mm_dequant(out32, Sout32, SCA, state.SCB)
+
+ # 4. Mixed-precision decomposition matmul
+ if state.threshold > 0.0 and coo_tensorA is not None and subA is not None:
+ output += torch.matmul(subA, state.subB)
+
+ # 5. Save state
+ ctx.state = state
+
+ ctx.formatB = formatB
+ ctx.grad_shape = input_shape
+ ctx.req_grads = [requires_gradA, requires_gradB]
+
+ if requires_gradA or requires_gradB:
+ ctx.tensors = (CAt, subA)
+ ctx.tensor_states = (SCAt, state.idx)
+ else:
+ ctx.tensors = [None, None]
+ ctx.tensor_states = (None, None)
+ ctx.save_for_backward(None, None)
+
+ #clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
+ clone_func = torch.clone
+ return clone_func(output.view(output_shape))
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ req_gradA, req_gradB = ctx.req_grads
+ CAt, subA = ctx.tensors
+ SCAt, idx = ctx.tensor_states
+ formatB = ctx.formatB
+ state = ctx.state
+ assert state.has_fp16_weights, 'Backprop only supported for fp16 weights.'
+
+ if len(grad_output.shape) == 3:
+ grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()
+
+ grad_A = grad_B = None
+
+ Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
+ if req_gradB:
+ CxAt, SAt = F.transform(CAt, formatB, transpose=True)
+ C32grad, Sgrad = F.transform(Cgradt, 'col32', transpose=True)
+ gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
+ grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
+ if state.threshold > 0.0 and subA is not None:
+ grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
+
+ if req_gradA:
+ C32grad, Sgrad = F.transform(Cgrad, 'col32')
+ if state.CxBt is None:
+ state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
+ gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
+ grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
+
+ return grad_A, grad_B, None, None, None, None, None
+
+
+def matmul(A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0):
+ state = state or MatmulLtState()
+ if threshold > 0.0:
+ state.threshold = threshold
+ return MatMul8bitLt.apply(A, B, out, state)
+
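
A hedged sketch of how MatmulLtState is meant to be threaded through repeated calls (illustrative; assumes a CUDA device):

    import torch
    import bitsandbytes as bnb

    state = bnb.MatmulLtState()
    state.threshold = 6.0                # enable mixed-precision decomposition
    A = torch.randn(4, 8, dtype=torch.float16, device='cuda')
    W = torch.randn(16, 8, dtype=torch.float16, device='cuda')
    out = bnb.matmul(A, W, state=state)  # quantizes W and stores CxB/SB/SCB on state

Linear8bitLt below wires this up automatically, keeping one MatmulLtState per layer.
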
diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index 63d627e..2374c35 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -6,6 +6,8 @@ lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')
try:
lib.cadam32bit_g32
+ lib.get_context.restype = ct.c_void_p
+ lib.get_cusparse.restype = ct.c_void_p
COMPILED_WITH_CUDA = True
except AttributeError:
warn("The installed version of bitsandbytes was compiled without GPU support. "
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index ab4e565..806c254 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -36,9 +36,51 @@ if COMPILED_WITH_CUDA:
str2optimizer8bit_blockwise['rmsprop'] = (lib.crmsprop_8bit_blockwise_fp32, lib.crmsprop_8bit_blockwise_fp16)
str2optimizer8bit_blockwise['adagrad'] = (lib.cadagrad_8bit_blockwise_fp32, lib.cadagrad_8bit_blockwise_fp16)
-optimal_normal = [-0.9939730167388916, -0.8727636337280273, -0.8097418546676636, -0.7660024166107178, -0.7318882346153259, -0.6793879270553589, -0.657649040222168, -0.6385974884033203, -0.6211113333702087, -0.5901028513908386, -0.5762918591499329, -0.5630806684494019, -0.5509274005889893, -0.5394591689109802, -0.5283197164535522, -0.517780065536499, -0.5074946284294128, -0.4980469048023224, -0.48867011070251465, -0.48003149032592773, -0.47125306725502014, -0.4629971981048584, -0.4547359049320221, -0.446626216173172, -0.43902668356895447, -0.43158355355262756, -0.4244747757911682, -0.4173796474933624, -0.41038978099823, -0.4055633544921875, -0.4035947024822235, -0.39701032638549805, -0.39057496190071106, -0.38439232110977173, -0.3782760500907898, -0.3721940815448761, -0.3661896586418152, -0.3604033589363098, -0.354605108499527, -0.34892538189888, -0.34320303797721863, -0.3376772701740265, -0.3323028087615967, -0.3269782066345215, -0.32166096568107605, -0.316457599401474, -0.3112771809101105, -0.3061025142669678, -0.30106794834136963, -0.2961243987083435, -0.2912728488445282, -0.28644347190856934, -0.28165507316589355, -0.2769731283187866, -0.2722635865211487, -0.26779335737228394, -0.26314786076545715, -0.2586647868156433, -0.2541804611682892, -0.2496625930070877, -0.24527113139629364, -0.24097171425819397, -0.23659978806972504, -0.23218469321727753, -0.22799566388130188, -0.22380566596984863, -0.21965542435646057, -0.2154538631439209, -0.2113603949546814, -0.20735277235507965, -0.20334717631340027, -0.19932441413402557, -0.19530178606510162, -0.19136647880077362, -0.18736697733402252, -0.18337111175060272, -0.17951400578022003, -0.1757056713104248, -0.17182783782482147, -0.1680615097284317, -0.16431649029254913, -0.16053077578544617, -0.15685945749282837, -0.15298527479171753, -0.1493264138698578, -0.14566898345947266, -0.14188314974308014, -0.13819937407970428, -0.1344561129808426, -0.1306886374950409, -0.1271020770072937, -0.12346585839986801, -0.11981867253780365, -0.11614970862865448, -0.11256207525730133, -0.10889036953449249, -0.10525048524141312, -0.1016591489315033, -0.09824034571647644, -0.09469068050384521, -0.0911419615149498, -0.08773849159479141, -0.08416644483804703, -0.08071305602788925, -0.07720902562141418, -0.07371306419372559, -0.07019119709730148, -0.06673648208379745, -0.06329209357500076, -0.059800852090120316, -0.0564190037548542, -0.05296570807695389, -0.049522045999765396, -0.04609023034572601, -0.04262964054942131, -0.039246633648872375, -0.03577171266078949, -0.03236335143446922, -0.028855687007308006, -0.02542758360505104, -0.022069433704018593, -0.018754752352833748, -0.015386369079351425, -0.01194947212934494, -0.008439815603196621, -0.004995611496269703, -0.0016682245768606663, 0.0, 0.0015510577941313386, 0.005062474869191647, 0.008417150937020779, 0.011741090565919876, 0.015184164978563786, 0.018582714721560478, 0.02204744517803192, 0.025471193715929985, 0.02889077737927437, 0.0323684960603714, 0.03579240292310715, 0.039281025528907776, 0.0427563451230526, 0.04619763046503067, 0.04968220740556717, 0.05326594039797783, 0.05679265409708023, 0.060245808213949203, 0.06372645497322083, 0.06721872836351395, 0.0706876739859581, 0.0742349922657013, 0.07774098962545395, 0.08123527467250824, 0.08468879014253616, 0.08810535818338394, 0.09155989438295364, 0.09498448669910431, 0.0985206812620163, 0.10206405073404312, 0.10563778132200241, 0.10921968519687653, 0.11284469068050385, 0.11653254181146622, 0.12008969485759735, 0.12368203699588776, 0.1272617131471634, 
0.13089501857757568, 0.134552001953125, 0.1382799744606018, 0.14194637537002563, 0.14563234150409698, 0.14930322766304016, 0.15303383767604828, 0.1567956507205963, 0.16050070524215698, 0.16431072354316711, 0.16813558340072632, 0.17204202711582184, 0.1758781224489212, 0.17973239719867706, 0.1836014688014984, 0.18753431737422943, 0.19138391315937042, 0.19535475969314575, 0.19931404292583466, 0.20333819091320038, 0.20738255977630615, 0.21152682602405548, 0.21568812429904938, 0.21978361904621124, 0.22393859922885895, 0.22814159095287323, 0.23241068422794342, 0.23675410449504852, 0.24123944342136383, 0.24569889903068542, 0.2500703036785126, 0.25904011726379395, 0.26349544525146484, 0.2682226300239563, 0.272907555103302, 0.2774306833744049, 0.28220856189727783, 0.2869136929512024, 0.2916390895843506, 0.29649388790130615, 0.30142995715141296, 0.3065022826194763, 0.3114383816719055, 0.31648796796798706, 0.3216581642627716, 0.32700115442276, 0.3322487473487854, 0.33778008818626404, 0.3431521952152252, 0.3487405776977539, 0.3543166518211365, 0.3601346015930176, 0.36605337262153625, 0.37217751145362854, 0.378179669380188, 0.3843980133533478, 0.3906566798686981, 0.39714935421943665, 0.40357843041419983, 0.4104187488555908, 0.4171563684940338, 0.42418959736824036, 0.43136918544769287, 0.4389212429523468, 0.44673123955726624, 0.45457619428634644, 0.4627031683921814, 0.47130417823791504, 0.4798591434955597, 0.48897242546081543, 0.4979848861694336, 0.5, 0.5076631307601929, 0.5177803635597229, 0.5282770991325378, 0.5392990112304688, 0.5506287813186646, 0.5632893443107605, 0.5764452815055847, 0.5903191566467285, 0.6051878333091736, 0.6209936141967773, 0.6382884979248047, 0.6573970913887024, 0.6795773506164551, 0.7037051916122437, 0.7327037453651428, 0.7677436470985413, 0.8111193776130676, 0.875165581703186, 1.0]
-optimal_half_normal = [0.0025565922260284424, 0.005811259150505066, 0.00961565226316452, 0.010822802782058716, 0.013123787939548492, 0.014242202043533325, 0.0143156498670578, 0.016469404101371765, 0.017666727304458618, 0.01773911714553833, 0.0199756920337677, 0.0210941880941391, 0.021161124110221863, 0.02451971173286438, 0.024580076336860657, 0.02685210108757019, 0.028012827038764954, 0.030198264867067337, 0.0302925705909729, 0.03136435151100159, 0.03374280035495758, 0.03487399220466614, 0.035243816673755646, 0.037192340940237045, 0.03822284936904907, 0.04164902865886688, 0.04173608124256134, 0.04401407018303871, 0.04508155584335327, 0.047482021152973175, 0.04756556823849678, 0.050963032990694046, 0.05196474492549896, 0.055417388677597046, 0.05793146416544914, 0.05799369141459465, 0.05887940526008606, 0.05895659327507019, 0.062420234084129333, 0.06493274495005608, 0.06499008461833, 0.06935599446296692, 0.07197384163737297, 0.07201516255736351, 0.07276943325996399, 0.07283210754394531, 0.07550075277686119, 0.07975354790687561, 0.07980883121490479, 0.08257630094885826, 0.0867777168750763, 0.08682405948638916, 0.08967285975813866, 0.09323835000395775, 0.09386616945266724, 0.09735457599163055, 0.09739077091217041, 0.10092401504516602, 0.10444298386573792, 0.10447832942008972, 0.10770941898226738, 0.10803905129432678, 0.11161200702190399, 0.1151546835899353, 0.11520349979400635, 0.11875157058238983, 0.11879390478134155, 0.1222602017223835, 0.122351735830307, 0.12240418791770935, 0.12594850733876228, 0.12597402930259705, 0.12602100148797035, 0.12960633635520935, 0.1296597123146057, 0.12966342642903328, 0.13227657973766327, 0.13325360417366028, 0.1333133578300476, 0.13691483438014984, 0.1371927298605442, 0.14066261053085327, 0.14088113978505135, 0.1447291411459446, 0.14805573225021362, 0.148526418954134, 0.15170684456825256, 0.15178103744983673, 0.15225710347294807, 0.1554398238658905, 0.15609459951519966, 0.15618794038891792, 0.1592724472284317, 0.1629735231399536, 0.16382690146565437, 0.16676269471645355, 0.16873238794505596, 0.17066434025764465, 0.17068277299404144, 0.1717144437134266, 0.17558929696679115, 0.17827065289020538, 0.17835864424705505, 0.18222273886203766, 0.18353315070271492, 0.18604370951652527, 0.18611834943294525, 0.1876586265861988, 0.18996606767177582, 0.19170701876282692, 0.19398853182792664, 0.19786442816257477, 0.19795633852481842, 0.20195159316062927, 0.2058800607919693, 0.2099103182554245, 0.2122517265379429, 0.21410366892814636, 0.21819619834423065, 0.22221362590789795, 0.22233009338378906, 0.22500130906701088, 0.2251257635653019, 0.22638091444969177, 0.23067741096019745, 0.23368822410702705, 0.2348879873752594, 0.2382080741226673, 0.2390350103378296, 0.2391497790813446, 0.24253453686833382, 0.24265171959996223, 0.2470107562839985, 0.24764248728752136, 0.24777774512767792, 0.2516774423420429, 0.256104726344347, 0.2564055472612381, 0.2607169933617115, 0.265461727976799, 0.26985861361026764, 0.2701106257736683, 0.2702729292213917, 0.274574413895607, 0.2750340588390827, 0.27919672429561615, 0.283704474568367, 0.28386808931827545, 0.28953738883137703, 0.2896753139793873, 0.29320384562015533, 0.29451676085591316, 0.295327290892601, 0.29802779853343964, 0.29818175733089447, 0.29972871020436287, 0.30290623009204865, 0.30305664241313934, 0.30486901476979256, 0.31299956142902374, 0.31518544629216194, 0.31790371239185333, 0.3205283172428608, 0.3230419009923935, 0.32595496252179146, 0.32612212374806404, 0.3282426446676254, 0.3283906430006027, 0.33146094158291817, 
0.3316439874470234, 0.33365286886692047, 0.33723779395222664, 0.3390095978975296, 0.3427443392574787, 0.34853987768292427, 0.34869300201535225, 0.35457711294293404, 0.35537679493427277, 0.3604113645851612, 0.36124424636363983, 0.3665340431034565, 0.36667295172810555, 0.3727492541074753, 0.3729033060371876, 0.37888188660144806, 0.37907837703824043, 0.3792510814964771, 0.38557394221425056, 0.38573457673192024, 0.39108292758464813, 0.39911722019314766, 0.40589402988553047, 0.40604450181126595, 0.410498782992363, 0.4106704741716385, 0.4129834659397602, 0.4131447561085224, 0.4172855168581009, 0.4202354736626148, 0.4204071946442127, 0.43538858368992805, 0.4355536885559559, 0.4432900734245777, 0.44603554904460907, 0.4461968094110489, 0.451409537345171, 0.4598204083740711, 0.46002377942204475, 0.46178819239139557, 0.46868549659848213, 0.46995367109775543, 0.4868385046720505, 0.48702501133084297, 0.4958047419786453, 0.4960057884454727, 0.5051481872797012, 0.506847757846117, 0.5148334950208664, 0.5150565356016159, 0.5174009390175343, 0.5249751061201096, 0.5283288545906544, 0.5355450958013535, 0.539984006434679, 0.5467876642942429, 0.5522958822548389, 0.5584012717008591, 0.5706631988286972, 0.5836620181798935, 0.5836880058050156, 0.5942088551819324, 0.5975865572690964, 0.6102624125778675, 0.6124880760908127, 0.6286389082670212, 0.646102175116539, 0.6471664495766163, 0.665437325835228, 0.6687244363129139, 0.687017485499382, 0.6932839937508106, 0.7115348428487778, 0.7218200154602528, 0.7219699807465076, 0.7747527211904526, 0.7749756425619125, 0.8192005604505539, 0.8194110840559006, 0.8830635994672775, 0.9217727445065975, 0.9245667457580566, 0.947742685675621, 0.9674464613199234, 0.9890814647078514, 0.9891453236341476, 0.9925699159502983]
+class CUBLAS_Context(object):
+ _instance = None
+
+ def __init__(self):
+ raise RuntimeError('Call get_instance() instead')
+
+ def initialize(self):
+ self.context = {}
+ #prev_device = torch.cuda.current_device()
+ #for i in range(torch.cuda.device_count()):
+ # torch.cuda.set_device(torch.device('cuda', i))
+ # self.context.append(ct.c_void_p(lib.get_context()))
+ #torch.cuda.set_device(prev_device)
+
+ @classmethod
+ def get_instance(cls):
+ if cls._instance is None:
+ cls._instance = cls.__new__(cls)
+ cls._instance.initialize()
+ return cls._instance
+
+ def get_context(self, device):
+ if device.index not in self.context:
+ prev_device = torch.cuda.current_device()
+ torch.cuda.set_device(device)
+ self.context[device.index] = ct.c_void_p(lib.get_context())
+ torch.cuda.set_device(prev_device)
+ return self.context[device.index]
+
+class Cusparse_Context(object):
+ _instance = None
+
+ def __init__(self):
+ raise RuntimeError('Call get_instance() instead')
+
+ def initialize(self):
+ self.context = ct.c_void_p(lib.get_cusparse())
+
+ @classmethod
+ def get_instance(cls):
+ if cls._instance is None:
+ cls._instance = cls.__new__(cls)
+ cls._instance.initialize()
+ return cls._instance
def create_linear_map(signed=True):
if signed:
@@ -89,6 +131,16 @@ def create_dynamic_map(signed=True, n=7):
data.sort()
return Tensor(data)
+def get_special_format_str():
+ major, minor = torch.cuda.get_device_capability()
+ if major < 7:
+ print(f'Device with CUDA capability {major}.{minor} is not supported for 8-bit matmul. Device has no tensor cores!')
+ assert major >= 7
+
+ if major == 7: return 'col_turing'
+ elif major == 8: return 'col_ampere'
+ else: return 'col_turing'
+
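+# Example (illustrative): an A100 (compute capability 8.0) yields 'col_ampere';
+# a T4 or RTX 2080 (capability 7.5) yields 'col_turing'.
+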
def get_ptr(A: Tensor) -> ct.c_void_p:
'''
Get the ctypes pointer from a PyTorch Tensor.
@@ -105,6 +157,105 @@ def get_ptr(A: Tensor) -> ct.c_void_p:
if A is None: return None
else: return ct.c_void_p(A.data.storage().data_ptr())
+def pre_call(device):
+ prev_device = torch.cuda.current_device()
+ torch.cuda.set_device(device)
+ return prev_device
+
+def post_call(prev_device):
+ torch.cuda.set_device(prev_device)
+
+def get_transform_func(dtype, orderA, orderOut, transpose=False):
+ name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}'
+ if not hasattr(lib, name):
+ raise ValueError(f'Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose} (missing symbol {name})')
+ else:
+ return getattr(lib, name)
+
+class GlobalData(object):
+ _instance = None
+
+ def __init__(self):
+ raise RuntimeError('Call get_instance() instead')
+
+ def initialize(self):
+ self.data = {}
+
+ @classmethod
+ def get_instance(cls):
+ if cls._instance is None:
+ cls._instance = cls.__new__(cls)
+ cls._instance.initialize()
+ return cls._instance
+
+
+def get_transform_buffer(shape, dtype, device, to_order, from_order='row', transpose=False):
+ #init_func = torch.empty
+ init_func = torch.zeros
+ dims = len(shape)
+
+ if dims == 2:
+ rows = shape[0]
+ elif dims == 3:
+ rows = shape[0]*shape[1]
+ cols = shape[-1]
+
+ state = (shape, to_order)
+ if transpose:
+ # swap dims
+ rows, cols = cols, rows
+ state = (shape[::-1], to_order)
+
+ if to_order == 'row' or to_order == 'col':
+ return init_func(shape, dtype=dtype, device=device), state
+ elif to_order == 'col32':
+ # blocks of 32 columns (padded)
+ cols = 32*((cols+31)//32)
+ return init_func((rows, cols), dtype=dtype, device=device), state
+ elif to_order == 'col_turing':
+ # blocks of 32 columns and 8 rows
+ cols = 32*((cols+31)//32)
+ rows = 8*((rows+7)//8)
+ return init_func((rows, cols), dtype=dtype, device=device), state
+ elif to_order == 'col_ampere':
+ # blocks of 32 columns and 32 rows
+ cols = 32*((cols+31)//32)
+ rows = 32*((rows+31)//32)
+ return init_func((rows, cols), dtype=dtype, device=device), state
+ else:
+ raise NotImplementedError(f'To_order not supported: {to_order}')
+
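+# Example (illustrative): for a (3, 70) int8 tensor, to_order='col32' pads
+# columns to the next multiple of 32, giving a (3, 96) buffer; 'col_turing'
+# additionally pads rows to a multiple of 8, giving (8, 96).
+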
+def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
+ if state is None: state = (A.shape, from_order)
+ else: from_order = state[1]
+ if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1])
+ else: new_state = (state[0], to_order) # (shape, order)
+ func = get_transform_func(A.dtype, from_order, to_order, transpose)
+
+ shape = state[0]
+ if len(shape) == 2:
+ dim1 = ct.c_int32(shape[0])
+ dim2 = ct.c_int32(shape[1])
+ elif ld is not None:
+ n = math.prod(shape)
+ dim1 = math.prod([shape[i] for i in ld])
+ dim2 = ct.c_int32(n//dim1)
+ dim1 = ct.c_int32(dim1)
+ else:
+ dim1 = ct.c_int32(shape[0]*shape[1])
+ dim2 = ct.c_int32(shape[2])
+
+ ptr = CUBLAS_Context.get_instance().get_context(A.device)
+ ptrA = get_ptr(A)
+ ptrOut = get_ptr(out)
+ func(ptr, ptrA, ptrOut, dim1, dim2)
+
+ return out, new_state
+
def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tensor:
'''
Estimates 256 equidistant quantiles on the input tensor eCDF.
@@ -544,3 +695,717 @@ def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor,
maxdim1 = ct.c_int32(histogram.shape[0])
n = ct.c_int32(index1.numel())
lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)
+
+def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
+ if not torch.cuda.is_initialized(): torch.cuda.init()
+ if A.dtype != expected_type or B.dtype != expected_type:
+ raise TypeError(f'Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}')
+
+ sA = A.shape
+ sB = B.shape
+ tA = transposed_A
+ tB = transposed_B
+
+ correct = True
+
+ if len(sA) == 2 and len(sB) == 2:
+ if not tA and not tB and A.shape[1] != B.shape[0]: correct = False
+ elif tA and not tB and A.shape[0] != B.shape[0]: correct = False
+ elif tA and tB and A.shape[0] != B.shape[1]: correct = False
+ elif not tA and tB and A.shape[1] != B.shape[1]: correct = False
+ elif len(sA) == 3 and len(sB) == 2:
+ if not tA and not tB and A.shape[2] != B.shape[0]: correct = False
+ elif tA and not tB and A.shape[1] != B.shape[0]: correct = False
+ elif tA and tB and A.shape[1] != B.shape[1]: correct = False
+ elif not tA and tB and A.shape[2] != B.shape[1]: correct = False
+ elif len(sA) == 3 and len(sB) == 3:
+ if not tA and not tB and A.shape[2] != B.shape[1]: correct = False
+ elif tA and not tB and A.shape[1] != B.shape[1]: correct = False
+ elif tA and tB and A.shape[1] != B.shape[2]: correct = False
+ elif not tA and tB and A.shape[2] != B.shape[2]: correct = False
+
+ if out is not None:
+ sout = out.shape
+ # special case common in backprop
+ if not correct and len(sA) == 3 and len(sB) == 3:
+ if (sout[0] == sA[2] and sout[1] == sB[2] and
+ sA[0] == sB[0] and sA[1] == sB[1]):
+ correct = True
+ else:
+ if len(sA) == 2 and len(sB) == 2:
+ if not tA and not tB: sout = (sA[0], sB[1])
+ elif tA and tB: sout = (sA[1], sB[0])
+ elif tA and not tB: sout = (sA[1], sB[1])
+ elif not tA and tB: sout = (sA[0], sB[0])
+ elif len(sA) == 3 and len(sB) == 2:
+ if not tA and not tB: sout = (sA[0], sA[1], sB[1])
+ elif tA and tB: sout = (sA[0], sA[2], sB[0])
+ elif tA and not tB: sout = (sA[0], sA[2], sB[1])
+ elif not tA and tB: sout = (sA[0], sA[1], sB[0])
+ elif len(sA) == 3 and len(sB) == 3:
+ if not tA and not tB: sout = (sA[0], sA[1], sB[2])
+ elif tA and tB: sout = (sA[0], sA[2], sB[1])
+ elif tA and not tB: sout = (sA[0], sA[2], sB[2])
+ elif not tA and tB: sout = (sA[0], sA[1], sB[1])
+
+
+ if not correct:
+ raise ValueError(f'Tensor dimensions incorrect for matrix multiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}.')
+
+ return sout
+
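+# Example (illustrative): A of shape (2, 3, 4) times B of shape (4, 5) with no
+# transposes passes the check and yields sout == (2, 3, 5).
+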
+def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed_B=False):
+ sout = check_matmul(A, B, out, transposed_A, transposed_B)
+ if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
+ if len(A.shape) == 3 and len(B.shape) == 3:
+ if A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1]:
+ return batched_igemm(A, B, out)
+
+ sA = A.shape
+ sB = B.shape
+ if transposed_A and len(sA) == 2: sA = (sA[1], sA[0])
+ elif transposed_A and len(sA) == 3: sA = (sA[0], sA[2], sA[1])
+ if transposed_B and len(sB) == 2: sB = (sB[1], sB[0])
+ elif transposed_B and len(sB) == 3: sB = (sB[0], sB[2], sB[1])
+ # this is a mess: cuBLAS expects column-major, but PyTorch is row-major.
+ # So to perform the matrix multiplication, we have to treat A, B, and C
+ # as transposed (the transpose of row-major is column-major).
+ # This means we compute B^T A^T = C^T and explicitly swap the dimensions
+ # of each of these matrices in the arguments passed to cuBLAS.
+ # column major: A @ B = C: [m, k] @ [k, n] = [m, n]
+ # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
+ # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
+ if len(sB) == 2:
+ if B.stride()[0] == B.shape[1]: transposed_B = False
+ elif B.stride()[1] == B.shape[0]: transposed_B = True
+ if len(A.shape) == 2:
+ if A.stride()[0] == A.shape[1]: transposed_A = False
+ elif A.stride()[1] == A.shape[0]: transposed_A = True
+ else:
+ if A.stride()[1] == A.shape[2]: transposed_A = False
+ elif A.stride()[2] == A.shape[1]: transposed_A = True
+
+ if len(sA) == 2:
+ n = sA[0]
+ ldb = A.stride()[1 if transposed_A else 0]
+ elif len(sA) == 3 and len(sB) == 2:
+ n = sA[0]*sA[1]
+ ldb = sA[2]
+
+
+ m = sB[1]
+ k = sB[0]
+ lda = B.stride()[(1 if transposed_B else 0)]
+ ldc = sB[1]
+ elif len(sB) == 3:
+ # special case
+ assert len(sA) == 3
+ if not (sA[0] == sB[0] and sA[1] == sB[1]):
+ raise ValueError(f'Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}')
+
+ transposed_A = True
+ transposed_B = False
+
+ m = sB[2]
+ n = sA[2]
+ k = sB[0]*sB[1]
+
+ lda = m
+ ldb = sA[2]
+ ldc = m
+
+
+ ptr = CUBLAS_Context.get_instance().get_context(A.device)
+
+ # B^T @ A^T = C^T
+ # [km, nk -> mn]
+ lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
+ get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc))
+ return out
+
+
+def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed_B=False):
+ if not len(A.shape) == 3 or not len(B.shape) == 3:
+ raise ValueError(f'Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}')
+ sout = check_matmul(A, B, out, transposed_A, transposed_B)
+ if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
+
+ if B.is_contiguous():
+ lda = B.stride()[1]
+ transposed_A = False
+ else:
+ s = B.stride()
+ if s[0] != B.shape[0]:
+ B = B.contiguous()
+ lda = B.stride()[1]
+ elif s[2] == B.shape[1]:
+ transposed_A = True
+ lda = B.stride()[2]
+ else:
+ if s[2] == 1:
+ B = B.contiguous()
+ lda = B.stride()[1]
+ elif s[1] == 1:
+ B = B.contiguous()
+ lda = B.stride()[1]
+ else:
+ B = B.contiguous()
+ lda = B.stride()[1]
+
+ if A.is_contiguous():
+ ldb = A.stride()[1]
+ transposed_B = False
+ else:
+ s = A.stride()
+ if s[0] != A.shape[0]:
+ A = A.contiguous()
+ ldb = A.stride()[1]
+ transposed_B = False
+ elif s[2] == A.shape[1]:
+ ldb = A.stride()[2]
+ transposed_B = True
+ else:
+ A = A.contiguous()
+ ldb = A.stride()[1]
+ transposed_B = False
+
+ # this is a mess: cuBLAS expects column-major, but PyTorch is row-major.
+ # So to perform the matrix multiplication, we have to treat A, B, and C
+ # as transposed (the transpose of row-major is column-major).
+ # This means we compute B^T A^T = C^T and explicitly swap the dimensions
+ # of each of these matrices in the arguments passed to cuBLAS.
+
+ # column major: A @ B = C: [batch, m, k] @ [batch, k, n] = [batch, m, n]
+ # row major: B^T @ A^T = C^T: [batch, m, k] @ [batch, k, n] = [batch, m, n]
+ # column major with row major layout: B^T @ A^T = C^T: [batch, k, m] @ [batch, n, k] = [batch, n, m]
+ num_batch = A.shape[0]
+ n = A.shape[1]
+ m = B.shape[2]
+ k = B.shape[1]
+
+ ldc = m
+
+ strideA = B.shape[1]*B.shape[2]
+ strideB = A.shape[1]*A.shape[2]
+ strideC = A.shape[1]*B.shape[2]
+
+ ptr = CUBLAS_Context.get_instance().get_context(A.device)
+
+ lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
+ get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc),
+ ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
+ return out
+
+def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32):
+ shapeA = SA[0]
+ shapeB = SB[0]
+ dimsA = len(shapeA)
+ dimsB = len(shapeB)
+ if dimsA == 2:
+ m = shapeA[0]
+ elif dimsA == 3:
+ m = shapeA[0]*shapeA[1]
+
+ if dimsB == 2:
+ rows = n = shapeB[0]
+ elif dimsB == 3:
+ rows = n = shapeB[0]*shapeB[1]
+
+ if dimsA == 2 and out is None:
+ out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, 'col32', 'row')
+ elif dimsA == 3 and out is None:
+ out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, 'col32', 'row')
+
+ if row_scale is not None: assert row_scale.numel() == out.shape[0]
+ assert dimsB != 3, 'len(B.shape)==3 not supported'
+ assert A.device.type == 'cuda'
+ assert B.device.type == 'cuda'
+ assert A.dtype == torch.int8
+ assert B.dtype == torch.int8
+ assert out.dtype == dtype
+ assert SA[1] == 'col32'
+ assert SB[1] in ['col_turing', 'col_ampere']
+ assert Sout[1] == 'col32'
+ assert shapeA[-1] == shapeB[-1], f'Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}'
+ formatB = SB[1]
+ prev_device = A.device
+ torch.cuda.set_device(A.device)
+
+ ptr = CUBLAS_Context.get_instance().get_context(A.device)
+ ptrA = get_ptr(A)
+ ptrB = get_ptr(B)
+ ptrC = get_ptr(out)
+ ptrRowScale = get_ptr(row_scale)
+
+ k = shapeA[-1]
+ lda = ct.c_int32(m*32)
+ if formatB == 'col_turing':
+ # turing: tiles with rows filled up to multiple of 8 rows by 32 columns
+ # n = rows
+ ldb = ct.c_int32(((rows+7)//8)*8*32)
+ else:
+ # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
+ # n = rows
+ ldb = ct.c_int32(((rows+31)//32)*32*32)
+
+ ldc = ct.c_int32(m*32)
+ m = ct.c_int32(m)
+ n = ct.c_int32(n)
+ k = ct.c_int32(k)
+
+ has_error = 0
+ if formatB == 'col_turing':
+ if dtype == torch.int32:
+ has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+ elif row_scale is None:
+ has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+ else:
+ has_error = lib.cigemmlt_turing_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+ elif formatB == 'col_ampere':
+ if dtype == torch.int32:
+ has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+ elif row_scale is None:
+ has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+ else:
+ has_error = lib.cigemmlt_ampere_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+
+ if has_error == 1:
+ raise Exception('cublasLt ran into an error!')
+
+ torch.cuda.set_device(prev_device)
+
+
+ return out, Sout
+
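+# End-to-end sketch (illustrative): the intended int8 inference pipeline chains
+# double_quant -> transform -> igemmlt -> mm_dequant:
+#   CA, CAt, SCA, SCAt, _ = double_quant(A)         # A: fp16, shape (m, k)
+#   CB, CBt, SCB, SCBt, _ = double_quant(B)         # B: fp16, shape (n, k)
+#   C32A, SA = transform(CA, 'col32')
+#   CxB, SB = transform(CB, to_order=get_special_format_str())
+#   out32, Sout32 = igemmlt(C32A, CxB, SA, SB)
+#   out = mm_dequant(out32, Sout32, SCA, SCB)       # fp16, shape (m, n)
+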
+
+def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None):
+ assert A.dtype == torch.int32
+ out_shape = quant_state[0]
+ if len(out_shape) == 3: out_shape = (out_shape[0]*out_shape[1], out_shape[2])
+
+ if out is None: out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
+ if new_row_stats is None: new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
+ if new_col_stats is None: new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
+ assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}"
+ assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}"
+
+ ptrA = get_ptr(A)
+ ptrOut = get_ptr(out)
+ ptrRowStats = get_ptr(row_stats)
+ ptrColStats = get_ptr(col_stats)
+ ptrNewRowStats = get_ptr(new_row_stats)
+ ptrNewColStats = get_ptr(new_col_stats)
+ numRows = ct.c_int32(out_shape[0])
+ numCols = ct.c_int32(out_shape[1])
+
+ lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols)
+
+ return out
+
+
+def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0):
+ assert A.dtype == torch.float16
+ device = A.device
+
+ cols = A.shape[-1]
+ if len(A.shape) == 3:
+ rows = A.shape[0]*A.shape[1]
+ else:
+ rows = A.shape[0]
+
+ col_tiles = (cols+255)//256
+ tiled_rows = ((rows+15)//16)*16
+ if row_stats is None: row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0)
+ if col_stats is None: col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0)
+
+ if nnz_block_ptr is None and threshold > 0.0: nnz_block_ptr = torch.zeros(((tiled_rows*col_tiles)+1,), dtype=torch.int32, device=device)
+
+ ptrA = get_ptr(A)
+ ptrRowStats = get_ptr(row_stats)
+ ptrColStats = get_ptr(col_stats)
+ ptrNnzrows = get_ptr(nnz_block_ptr)
+ rows = ct.c_int32(rows)
+ cols = ct.c_int32(cols)
+
+ prev_device = pre_call(A.device)
+ lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols)
+ post_call(prev_device)
+
+
+ if threshold > 0.0:
+ nnz_block_ptr.cumsum_(0)
+
+
+ return row_stats, col_stats, nnz_block_ptr
+
+class COOSparseTensor(object):
+ def __init__(self, rows, cols, nnz, rowidx, colidx, values):
+ assert rowidx.dtype == torch.int32
+ assert colidx.dtype == torch.int32
+ assert values.dtype == torch.float16
+ assert values.numel() == nnz
+ assert rowidx.numel() == nnz
+ assert colidx.numel() == nnz
+
+ self.rows = rows
+ self.cols = cols
+ self.nnz = nnz
+ self.rowidx = rowidx
+ self.colidx = colidx
+ self.values = values
+
+class CSRSparseTensor(object):
+ def __init__(self, rows, cols, nnz, rowptr, colidx, values):
+ assert rowptr.dtype == torch.int32
+ assert colidx.dtype == torch.int32
+ assert values.dtype == torch.float16
+ assert values.numel() == nnz
+ assert colidx.numel() == nnz
+ assert rowptr.numel() == rows+1
+
+ self.rows = rows
+ self.cols = cols
+ self.nnz = nnz
+ self.rowptr = rowptr
+ self.colidx = colidx
+ self.values = values
+
+class CSCSparseTensor(object):
+ def __init__(self, rows, cols, nnz, colptr, rowidx, values):
+ assert colptr.dtype == torch.int32
+ assert rowidx.dtype == torch.int32
+ assert values.dtype == torch.float16
+ assert values.numel() == nnz
+ assert rowidx.numel() == nnz
+ assert colptr.numel() == cols+1
+
+ self.rows = rows
+ self.cols = cols
+ self.nnz = nnz
+ self.colptr = colptr
+ self.rowidx = rowidx
+ self.values = values
+
+def coo2csr(cooA):
+ values, counts = torch.unique(cooA.rowidx, return_counts=True)
+ values.add_(1)
+ rowptr = torch.zeros((cooA.rows+1, ), dtype=torch.int32, device=cooA.rowidx.device)
+ rowptr.scatter_(index=values.long(), src=counts.int(), dim=0)
+ rowptr.cumsum_(0)
+ return CSRSparseTensor(cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values)
+
+def coo2csc(cooA):
+ val, col2rowidx = torch.sort(cooA.colidx)
+ rowidx = cooA.rowidx[col2rowidx]
+ values = cooA.values[col2rowidx]
+ colvalues, counts = torch.unique(val, return_counts=True)
+ colvalues.add_(1)
+ colptr = torch.zeros((cooA.cols+1, ), dtype=torch.int32, device=cooA.colidx.device)
+ colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0)
+ colptr.cumsum_(0)
+ return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values)
+
+def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
+ rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
+ colidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
+ values = torch.zeros((nnz,), dtype=dtype, device=device)
+ return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)
+
+
+def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
+ device = A.device
+ assert A.dtype == torch.half
+ assert device.type == 'cuda'
+ prev_device = pre_call(A.device)
+
+ cols = A.shape[-1]
+ if len(A.shape) == 3:
+ rows = A.shape[0]*A.shape[1]
+ else:
+ rows = A.shape[0]
+
+ if row_stats is None or col_stats is None:
+ row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)
+
+ if out_col is None: out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
+ if out_row is None: out_row = torch.zeros(A.shape, device=device, dtype=torch.int8)
+
+ coo_tensor = None
+ ptrA = get_ptr(A)
+ ptrColStats = get_ptr(col_stats)
+ ptrRowStats = get_ptr(row_stats)
+ ptrOutCol = get_ptr(out_col)
+ ptrOutRow = get_ptr(out_row)
+
+ if threshold > 0.0:
+ nnz = nnz_row_ptr[-1].item()
+ if nnz > 0:
+ coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz, device)
+ ptrRowIdx = get_ptr(coo_tensor.rowidx)
+ ptrColIdx = get_ptr(coo_tensor.colidx)
+ ptrVal = get_ptr(coo_tensor.values)
+ ptrRowPtr = get_ptr(nnz_row_ptr)
+
+ lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, ptrRowIdx, ptrColIdx, ptrVal, ptrRowPtr, ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols))
+ val, idx = torch.sort(coo_tensor.rowidx)
+ coo_tensor.rowidx = val
+ coo_tensor.colidx = coo_tensor.colidx[idx]
+ coo_tensor.values = coo_tensor.values[idx]
+ else:
+ lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, None, None, None, None, ct.c_float(0.0), ct.c_int32(rows), ct.c_int32(cols))
+ else:
+ lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, None, None, None, None, ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols))
+ post_call(prev_device)
+
+ return out_row, out_col, row_stats, col_stats, coo_tensor
+
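+# Note (illustrative): with threshold > 0, entries with magnitude above the
+# threshold are zeroed in both int8 outputs and returned separately in the COO
+# tensor; MatMul8bitLt consumes that COO tensor for its fp16 outlier matmul.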
+
+def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
+ if state is None: state = (A.shape, from_order)
+ else: from_order = state[1]
+ if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
+ else: new_state = (state[0], to_order) # (shape, order)
+
+ shape = state[0]
+ if len(shape) == 2:
+ dim1 = ct.c_int32(shape[0])
+ dim2 = ct.c_int32(shape[1])
+ else:
+ dim1 = ct.c_int32(shape[0]*shape[1])
+ dim2 = ct.c_int32(shape[2])
+
+ ptrA = get_ptr(A)
+ ptrOut = get_ptr(out)
+ if to_order == 'col32':
+ if transpose:
+ lib.ctransform_row2col32T(ptrA, ptrOut, dim1, dim2)
+ else:
+ lib.ctransform_row2col32(ptrA, ptrOut, dim1, dim2)
+ elif to_order == 'col_turing':
+ if transpose:
+ lib.ctransform_row2turingT(ptrA, ptrOut, dim1, dim2)
+ else:
+ lib.ctransform_row2turing(ptrA, ptrOut, dim1, dim2)
+ elif to_order == 'col_ampere':
+ if transpose:
+ lib.ctransform_row2ampereT(ptrA, ptrOut, dim1, dim2)
+ else:
+ lib.ctransform_row2ampere(ptrA, ptrOut, dim1, dim2)
+ elif to_order == 'row':
+ if from_order == 'col_turing':
+ lib.ctransform_turing2row(ptrA, ptrOut, dim1, dim2)
+ elif from_order == 'col_ampere':
+ lib.ctransform_ampere2row(ptrA, ptrOut, dim1, dim2)
+ else:
+ raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')
+
+ return out, new_state
+
+def spmm_coo(cooA, B, out=None):
+ if out is None: out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype)
+ nnz = cooA.nnz
+ assert cooA.rowidx.numel() == nnz
+ assert cooA.colidx.numel() == nnz
+ assert cooA.values.numel() == nnz
+ assert cooA.cols == B.shape[0]
+
+ transposed_B = not B.is_contiguous()
+
+ ldb = B.stride()[(1 if transposed_B else 0)]
+ ldc = B.shape[1]
+
+ ptr = Cusparse_Context.get_instance().context
+
+ ptrRowidx = get_ptr(cooA.rowidx)
+ ptrColidx = get_ptr(cooA.colidx)
+ ptrValues = get_ptr(cooA.values)
+ ptrB = get_ptr(B)
+ ptrC = get_ptr(out)
+ cnnz = ct.c_int32(cooA.nnz)
+ crowsA = ct.c_int32(cooA.rows)
+ ccolsA = ct.c_int32(cooA.cols)
+ ccolsB = ct.c_int32(B.shape[1])
+ cldb = ct.c_int32(ldb)
+ cldc = ct.c_int32(ldc)
+
+ lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B))
+
+ return out
+
+def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
+ if out is None: out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype)
+ nnz = cooA.nnz
+ assert cooA.rowidx.numel() == nnz
+ assert cooA.colidx.numel() == nnz
+ assert cooA.values.numel() == nnz
+ assert cooA.cols == B.shape[0], f'{cooA.cols} vs {B.shape}'
+
+ transposed_B = not B.is_contiguous()
+
+ ldb = B.stride()[(1 if transposed_B else 0)]
+ ldc = B.shape[1]
+
+ values, counts = torch.unique(cooA.rowidx, return_counts=True)
+ offset = counts.cumsum(0).int()
+ max_count, max_idx = torch.sort(counts, descending=True)
+ max_idx = max_idx.int()
+ max_count = max_count.int()
+ assert max_count[0] <= 32, f'The kernel supports at most 32 nonzeros per row, but found {max_count[0]}.'
+ assert B.dtype in [torch.float16, torch.int8]
+ ptrOffset = get_ptr(offset)
+ ptrMaxCount = get_ptr(max_count)
+ ptrMaxIdx = get_ptr(max_idx)
+
+ ptrRowidx = get_ptr(cooA.rowidx)
+ ptrColidx = get_ptr(cooA.colidx)
+ ptrValues = get_ptr(cooA.values)
+ ptrB = get_ptr(B)
+ ptrC = get_ptr(out)
+ ptrDequantStats = get_ptr(dequant_stats)
+ cnnz_rows = ct.c_int32(counts.numel())
+ cnnz = ct.c_int32(cooA.nnz)
+ crowsA = ct.c_int32(cooA.rows)
+ ccolsA = ct.c_int32(cooA.cols)
+ crowsB = ct.c_int32(B.shape[1])
+ ccolsB = ct.c_int32(B.shape[1])
+ cldb = ct.c_int32(ldb)
+ cldc = ct.c_int32(ldc)
+ #print(cooA.rowidx[:64])
+ #print(cooA.colidx[:64].sort()[0])
+
+ if B.dtype == torch.float16:
+ lib.cspmm_coo_very_sparse_naive_fp16(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB)
+ elif B.dtype == torch.int8:
+ lib.cspmm_coo_very_sparse_naive_int8(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB)
+ #else: assertion error
+
+ return out
+
+
+C = 127.0
+
+def vectorwise_quant(x, dim=1, quant_type='vector'):
+ if quant_type == 'linear':
+ max1 = torch.abs(x).max().float()
+ xq = torch.round(x/max1*127).to(torch.int8)
+ return xq, max1
+ elif quant_type in ['vector', 'row']:
+ max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
+ xq = torch.round(x*(C/max1)).to(torch.int8)
+ return xq, max1
+ elif quant_type == 'zeropoint':
+ dtype = x.dtype
+ x = x.float()
+ dyna = x.max() - x.min()
+ if dyna == 0: dyna = 1
+ qx = 255./dyna
+ minx = x.min()
+ zpx = torch.round(minx* qx)
+ x = torch.round(qx*x - zpx) + zpx
+ return x, qx
+ elif quant_type in ['vector-zeropoint', 'row-zeropoint']:
+ dtype = x.dtype
+ x = x.float()
+ dyna = (torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True))
+ dyna[dyna==0] = 1
+ qx = 255./dyna
+ minx = torch.amin(x, dim=dim, keepdim=True)
+ zpx = torch.round(minx* qx)
+ x = torch.round(qx*x - zpx) + zpx
+ return x, qx
+ elif quant_type == 'truncated-vector':
+ with torch.no_grad():
+ absx = torch.abs(x)
+ max1 = torch.amax(absx, dim=dim, keepdim=True)
+ max1 = max1*0.7
+ idx = (absx > max1.expand_as(absx))
+ sign = torch.sign(x[idx])
+ x[idx] = max1.expand_as(absx)[idx]*sign
+ xq = torch.round(x/max1*C).to(torch.int8)
+ return xq, max1
+ else: return None
+
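+# Round-trip sketch (illustrative), for quant_type='vector':
+#   xq, max1 = vectorwise_quant(x, dim=-1)   # int8 codes, per-row absmax
+#   x_hat = vectorwise_dequant(xq, max1)     # fp32 approximation of x
+# which reconstructs x up to roughly max1/127 absolute error per element.
+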
+def vectorwise_dequant(xq, max1, quant_type='vector'):
+ if quant_type == 'vector':
+ x = (xq/C*max1).to(torch.float32)
+ return x
+ else: return None
+
+def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type='vector'):
+ if quant_type == 'linear':
+ norm = S1*S2/(C*C)
+ # double cast needed to prevent overflows
+ return (xq.float()*norm).to(dtype)
+ elif quant_type == 'zeropoint':
+ norm = 1.0/(S1*S2)
+ return (xq.float()*norm).to(dtype)
+ elif quant_type == 'row-zeropoint':
+ norm = 1.0/(S1*S2)
+ x = xq.float()
+ if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
+ if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
+ if len(S1.shape) == 2:
+ x *= norm
+ else:
+ x *= norm
+ return x.to(dtype)
+ elif quant_type == 'vector-zeropoint':
+ x = xq.float()
+ if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
+ if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
+ if len(S1.shape) == 2:
+ x *= 1.0/S1
+ else:
+ x *= 1.0/S1
+ x *= 1.0/S2.t()
+ return x.to(dtype)
+ elif quant_type == 'row':
+ x = xq.float()
+ if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
+ if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
+ if len(S1.shape) == 2:
+ x *= S1*S2/(C*C)
+ else:
+ x *= S1*S2/(C*C)
+ return x.to(dtype)
+ elif quant_type in ['truncated-vector', 'vector']:
+ x = xq.float()
+ if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
+ if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
+ if len(S1.shape) == 2:
+ x *= S1/C
+ else:
+ x *= S1/C
+ x *= S2/C
+ return x.to(dtype)
+ else: return None
+
+
+def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
+ offset = B.float().t().sum(0)*(SA[0]+SA[1])
+ x = xq.float()
+ if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0)
+ if len(SB.shape) == 2:
+ x *= SB.t()/127
+ else:
+ x *= SB/127
+ x *= SA[1]/127
+ x += offset
+ return x.to(dtype)
diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py
index 27ad6ca..03b4655 100644
--- a/bitsandbytes/nn/__init__.py
+++ b/bitsandbytes/nn/__init__.py
@@ -2,4 +2,4 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from .modules import StableEmbedding, Embedding
+from .modules import StableEmbedding, Linear8bit, Linear8bitLt, Int8Params
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index c5460fb..5013d0b 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -3,14 +3,19 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
+import bitsandbytes as bnb
-from typing import Optional
+from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict
-from torch import Tensor
+from torch import Tensor, device, dtype
+from torch import nn
+from torch.nn.parameter import Parameter
+import torch.distributed as dist
import torch.nn.functional as F
from bitsandbytes.optim import GlobalOptimManager
+T = TypeVar('T', bound='torch.nn.Module')
+
class StableEmbedding(torch.nn.Embedding):
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
@@ -70,3 +75,118 @@ class Embedding(torch.nn.Embedding):
self.norm_type, self.scale_grad_by_freq, self.sparse)
return emb
+
+class Int8Params(torch.nn.Parameter):
+ def __new__(cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None):
+ if data is None:
+ data = torch.empty(0)
+ obj = torch.Tensor._make_subclass(cls, data, requires_grad)
+ # set instance attributes rather than class attributes so that
+ # separate parameters do not share quantization state
+ obj.has_fp16_weights = has_fp16_weights
+ obj.CB = CB
+ obj.SCB = SCB
+ return obj
+
+ def cuda(self, device):
+ if self.has_fp16_weights:
+ return super().cuda(device)
+ else:
+ # we store the 8-bit row-major weight
+ # we convert this weight to the turing/ampere format during the first inference pass
+ B = self.data.contiguous().half().cuda(device)
+ CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
+ del CBt
+ del SCBt
+ self.data = CB
+ self.CB = CB
+ self.SCB = SCB
+
+ return self
+
+ @overload
+ def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ...,
+ non_blocking: bool = ...) -> T:
+ ...
+
+ @overload
+ def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
+ ...
+
+ @overload
+ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
+ ...
+
+ def to(self, *args, **kwargs):
+ device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
+
+ if device is not None and device.type == 'cuda' and self.data.device.type == 'cpu': return self.cuda(device)
+ else:
+ new_param = Int8Params(super().to(device=device, dtype=dtype, non_blocking=non_blocking), requires_grad=self.requires_grad, has_fp16_weights=self.has_fp16_weights)
+ new_param.CB = self.CB
+ new_param.SCB = self.SCB
+
+ return new_param
+
+
+
+class Linear8bitLt(nn.Linear):
+ def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, threshold=0.0, index=None):
+ super(Linear8bitLt, self).__init__(input_features, output_features, bias)
+ self.state = bnb.MatmulLtState()
+ self.index = index
+
+ self.state.threshold = threshold
+ self.state.has_fp16_weights = has_fp16_weights
+ if threshold > 0.0 and not has_fp16_weights:
+ self.state.use_pool = True
+
+ self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
+
+ def init_8bit_state(self):
+ self.state.CB = self.weight.CB
+ self.state.SCB = self.weight.SCB
+ self.weight.CB = None
+ self.weight.SCB = None
+
+ def forward(self, x):
+ self.state.is_training = self.training
+
+ if self.weight.CB is not None: self.init_8bit_state()
+ #assert not self.state.has_fp16_weights
+ #if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None
+
+ out = bnb.matmul(x, self.weight, state=self.state)
+
+ if self.bias is not None:
+ out += self.bias.unsqueeze(0).expand_as(out)
+
+ if not self.state.has_fp16_weights and self.state.CB is not None:
+ # we converted 8-bit row major to turing/ampere format in the first inference pass
+ # we no longer need the row-major weight
+ del self.state.CB
+ self.weight.data = self.state.CxB
+
+ return out
+
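+# Hypothetical usage sketch for Linear8bitLt (sizes and threshold are assumptions):
+#
+#   fc = Linear8bitLt(1024, 1024, bias=False, has_fp16_weights=False, threshold=6.0).cuda()
+#   x = torch.randn(8, 1024, dtype=torch.float16, device='cuda')
+#   y = fc(x)   # the first call moves weight.CB/SCB into self.state; afterwards the
+#               # row-major int8 weight is replaced by the turing/ampere weight (state.CxB)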
+class Linear8bit(nn.Linear):
+ def __init__(self, input_features, output_features, bias=True, quant_type='vector', index=None, args=None, sparse_decomp=False):
+ super(Linear8bit, self).__init__(input_features, output_features, bias)
+ self.quant_type = quant_type
+ self.index = index
+ self.args = args
+ self.iter = 0
+
+ def forward(self, x):
+ self.iter += 1
+ if self.args is not None and self.iter % self.args.clip_freq == 0:
+ with torch.no_grad():
+ maxval, maxidx = torch.topk(torch.abs(self.weight.flatten()), k=self.args.clip_idx)
+ if not dist.is_initialized() or dist.get_rank() == 0:
+ print('clip', maxval[-1].item())
+ self.weight.clip_(-maxval[-1], maxval[-1])
+
+
+ if self.args is not None:
+ out = bnb.nn.functional.sparse_decomposed_linear8bit(x, self.weight, self.bias, qval=self.args.sparse_decomp_val, quant_type=self.args.quant_type)
+ else:
+ out = bnb.nn.functional.linear8bit(x, self.weight, self.bias, quant_type=self.quant_type)
+
+ return out
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index d0aabff..1c3e723 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -1737,10 +1737,884 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
}
}
+template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols)
+{
+ // 0. reset stats to -FLT_MAX
+ // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD)
+ // 2. compute col max (per thread); store in smem due to register pressure
+ // 3. compute row max (per block); store in smem to accumulate a full global mem transaction
+ // 4. store data via atomicMax
+
+ // each block loads TILE_COLS columns and TILE_ROWS rows
+ // after reading a tile the row counter increases by TILE_ROWS
+ // the col counter resets after reading TILE_COLS elements
+ const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
+ // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
+ const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
+ const int base_idx = (base_row*cols) + base_col;
+ const int items_per_load = ITEMS_PER_THREAD*THREADS;
+
+ typedef cub::BlockLoad<T, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadT;
+ typedef cub::BlockReduce<float, THREADS> BlockRowReduce;
+ typedef cub::BlockReduce<int, THREADS> BlockRowSum;
+ typedef cub::BlockExchange<float, THREADS, ITEMS_PER_THREAD> BlockExchange;
+
+ __shared__ union {
+ typename BlockExchange::TempStorage exchange;
+ typename BlockRowReduce::TempStorage rowreduce;
+ typename BlockRowSum::TempStorage rowsum;
+ typename LoadT::TempStorage loadt;
+ } temp_storage;
+
+ __shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS];
+ __shared__ int smem_row_nnz_values[TILE_ROWS];
+ //__shared__ float smem_col_absmax_values[ITEMS_PER_THREAD*THREADS];
+
+ half local_data[ITEMS_PER_THREAD];
+ float local_data_fp32[ITEMS_PER_THREAD];
+ float local_col_absmax_values[ITEMS_PER_THREAD];
+ int local_row_nnz_count = 0;
+ float row_absmax = -FLT_MAX;
+
+ // 0. reset stats to -FLT_MAX
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ {
+ //smem_col_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX;
+ smem_row_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX;
+ if(threadIdx.x + (j*THREADS) < TILE_ROWS) // guard: smem_row_nnz_values only has TILE_ROWS entries
+ smem_row_nnz_values[threadIdx.x + (j*THREADS)] = 0;
+ }
+
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ local_col_absmax_values[j] = -FLT_MAX;
+
+ __syncthreads();
+
+ int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col;
+ int i = base_idx;
+ // we load row after row from the base_position
+ // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD)
+ for(int row = 0; row < TILE_ROWS; row++)
+ {
+ if(base_row+row >= rows){ break; }
+ local_row_nnz_count = 0;
+ i = base_idx + ((row)*cols);
+ // each thread gets data from the same column
+ __syncthreads();
+ LoadT(temp_storage.loadt).Load(&(A[i]), local_data, valid_items, __float2half(0.0f));
+
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ local_data[j] = fabsf(local_data[j]);
+
+
+ if(SPARSE_DECOMP)
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ {
+ if((float)local_data[j] >= nnz_threshold)
+ {
+ local_row_nnz_count += 1;
+ local_data[j] = 0.0f;
+ }
+ }
+
+ // 2. compute col max (per thread); store in smem due to register pressure
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ // take the col max for this row
+ // we use shared memory because register pressure is too high if we do this locally
+ //smem_col_absmax_values[threadIdx.x + (j*THREADS)] = fmaxf(smem_col_absmax_values[threadIdx.x + (j*THREADS)], __half2float(local_data[j]));
+ local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j]));
+
+ // 3. compute row max (per block); store in smem to accumulate a full global mem transaction
+ __syncthreads();
+
+ // this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units)
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ local_data_fp32[j] = local_data[j];
+
+ row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, cub::Max());
+ if(SPARSE_DECOMP)
+ {
+ __syncthreads();
+ local_row_nnz_count = BlockRowSum(temp_storage.rowsum).Sum(local_row_nnz_count);
+ }
+ // we store the data temporarily in shared memory so we
+ // can execute a full atomic block transaction into global memory later
+ // we use a striped arrangement [0, 8, 16, 24, ..] for t0 for faster stores
+ if(threadIdx.x == 0)
+ {
+ smem_row_absmax_values[(row % ITEMS_PER_THREAD) + ((row/ITEMS_PER_THREAD)*ITEMS_PER_THREAD)] = row_absmax;
+ // each blockIdx.x processes 16 rows and 64*4=256 columns -> we sum nnz over 256 columns and have 16 values per block
+ smem_row_nnz_values[row] = local_row_nnz_count;
+ }
+
+ __syncthreads();
+
+ }
+
+ // 4. store data via atomicMax
+ // to store col data efficiently we need to rewrite the smem blocked data [0, 1, 2, 3...] for t0
+ // into a striped arrangement: [0, 8, 16, 24, ..] for t0
+ __syncthreads();
+ BlockExchange(temp_storage.exchange).BlockedToStriped(local_col_absmax_values);
+
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ if(base_col+threadIdx.x+(j*THREADS) < cols)
+ {
+ float val = colStats[base_col+(threadIdx.x+(j*THREADS))];
+ if(val < local_col_absmax_values[j])
+ atomicMax(&colStats[base_col+(threadIdx.x+(j*THREADS))], local_col_absmax_values[j]);
+ }
+
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ if(base_row+threadIdx.x+(j*THREADS) < rows)
+ {
+ float val = rowStats[base_row+(threadIdx.x+(j*THREADS))];
+ if(val < smem_row_absmax_values[threadIdx.x+(j*THREADS)])
+ atomicMax(&rowStats[base_row+(threadIdx.x+(j*THREADS))], smem_row_absmax_values[threadIdx.x+(j*THREADS)]);
+ }
+
+ if(SPARSE_DECOMP)
+ if(threadIdx.x < TILE_ROWS)
+ nnz_count_row[blockIdx.x*TILE_ROWS+threadIdx.x+1] = smem_row_nnz_values[threadIdx.x];
+
+}
+
+template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 0>(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
+template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 1>(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
+
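+// The stats computed above feed the quantization kernels below (a restatement of the
+// math, not new behavior): with rowStats_i = max_j |A_ij| and colStats_j = max_i |A_ij|,
+// kDoubleRowColQuant emits q_ij = round(127*A_ij/rowStats_i) row-wise and
+// q_ij = round(127*A_ij/colStats_j) column-wise.
+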
+#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f)
+
+template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n)
+{
+
+ // Strategy: To dequantize we need to load col/row statistics. This can be very expensive
+ // since different row/col stats need to be loaded with each thread.
+ // (1, bad algorithm) Loading 32 items per thread would only incur 1 row load, but this increases register pressure
+ // and would lead to low global load utilization.
+ // (2, bad algorithm) If each thread loads some columns and multiple rows one needs to do a lot of row loads
+ // for each thread and this is duplicated by a factor of 32/num-cols-per-thread.
+ // (3, good algorithm) Combining (1) and (2) we use sub-tiles of size 32xk in shared memory per threadblock.
+ // This allows for efficient row/col loading from shared memory within the tile.
+ // We can run for example 32x128 sub-tiles and warp-strided loads of 4 elements so that each thread has
+ // the same col statistic but needs to load 4 row stats from shared memory. To prevent bank conflicts
+ // we use a block-striped shared memory config [1, 31, 63, 95] so no bank conflicts happen during the
+ // shared memory loads.
+
+ // data is in 32 column-tile major with tile width 32 columns and numRows rows
+ // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
+ // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3])
+ // C1. Compute val(row_stat*col_stat)/(127*127) (load 1/(127*127) into a register)
+ // C2. Compute normalization values and store col values in register
+ // S1. Store C1 into 16-bit output
+ // S2. Store col/row statistics of new buffer in shared memory
+
+ // We allow for sub-tiles to span multiple col32 tiles. This is okay
+ // since the items per thread only rely on a single column statistic.
+
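+ // Worked example (the values are assumptions): for A_ij = 16129 = 127*127 with
+ // rowStat_i = 2.0 and colStat_j = 3.0, out_ij = 16129*MM_DEQUANT_CONST*2.0*3.0 = 6.0,
+ // i.e. step C1 recovers A_int32*rowStat*colStat/(127*127) as fp16.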
+
+ const int n_out = numRows*numCols;
+
+ int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1);
+ // we have tiles of size numRows*32, thus col only increases every numRows
+ // num_row_tiles is the tiles after which the column increases by 32
+ // blockIdx.x is the index of the current tile
+ int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32));
+ // base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached
+ int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS);
+
+ // SUBTILE_ROWS, ITEMS_PER_THREAD, and THREADS are independent of one another
+ // subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD
+ // Total subtiles should be n/(32*SUBTILE_ROWS) where each subtile has SUBTILE_ROW*32/4 threads.
+ // For example for a 1024x1024 matrix with 128 SUBTILE_ROWS and 4 ITEMS_PER_THREAD we have
+ // 1024*1024/(128*32) = 256 tiles
+ // 256 tiles are 256*128*32/4 = 256*1024 threads
+
+ // 1. Figure out how index relates to the start of the sub-tile
+ // 2. Each thread < SUBTILE_ROWS calculates row index
+ // 3. Load striped and store in shared memory
+
+ int local_values[ITEMS_PER_THREAD];
+ half local_output[ITEMS_PER_THREAD];
+ float local_rowStats[ITEMS_PER_THREAD];
+ __shared__ float smem_rowStats[SUBTILE_ROWS];
+
+ typedef cub::BlockLoad<int, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_DIRECT> LoadInt32;
+ typedef cub::BlockExchange<int, THREADS, ITEMS_PER_THREAD> ExchangeInt32;
+ __shared__ typename LoadInt32::TempStorage loadint32;
+ __shared__ typename ExchangeInt32::TempStorage exchangeint32;
+
+
+ // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
+ float colStat = col >= numCols ? 0.0f : colStats[col];
+ // no block loads for rows for now -- keep it simple
+ for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x)
+ {
+ // todo: is this global mem access slow due to overlaps or does the L1 cache work well here?
+ int row = (base_row+j) % numRows; // wrap around
+ // each warp accesses the same element, for four consecutive elements
+ // todo: update description about striped shared memory, it is not needed
+ // rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consecutive elements
+ smem_rowStats[j] = rowStats[row];
+ }
+ __syncthreads();
+
+
+ // each block processes SUBTILE_ROWS*32 elements
+ const int items_per_load = THREADS*ITEMS_PER_THREAD;
+ const int rows_per_load = items_per_load/32;
+
+ int subtile_base_row = (threadIdx.x / 32)*ITEMS_PER_THREAD; // row within the tile
+ int row_offset = 0;
+ // subtile_idx starts at base_row*32, plus the total offset for each full numRows*32 tile already passed
+ int subtile_start = (blockIdx.x/num_row_tiles)*(numRows*32) + (base_row*32);
+ for(int subtile_idx = subtile_start; subtile_idx < subtile_start + (SUBTILE_ROWS*32); subtile_idx+=items_per_load)
+ {
+ int valid_rows = numRows - (base_row+row_offset) > rows_per_load ? rows_per_load : numRows - (base_row+row_offset);
+ int valid_items = valid_rows*32;
+ if(valid_items <= 0) // the sub-tile might have more elements than the tile itself
+ break;
+
+ // L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3])
+ LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0);
+ ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values);
+
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j];
+
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ local_output[j] = __float2half(local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat);
+ //absmax_col = fmax(fabsf(local_output[j]), absmax_col);
+
+ // we store data in row major
+ // to store data efficiently, we want to use block exchange: [0, 32, 64, 92] -> [0, 1, 2, 3]
+ // so that each thread holds ITEMS_PER_THREAD consecutive items for each row
+ // this way throughput into storage is increased by a factor of ~2x
+ // for now we use a simple store
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ {
+ int outIdx = col + ((base_row+subtile_base_row+row_offset+j)*numCols);
+ if(outIdx < n_out && col < numCols)
+ out[outIdx] = local_output[j];
+ }
+
+ row_offset += rows_per_load;
+ }
+}
+
+
+template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols)
+{
+ // assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD
+ // Each thread reads the same column but multiple rows
+ // Rows are loaded in shared memory and access is shared across the threadblock (broadcast)
+
+ // 0. Load row stats data into shared memory; load col stat (1 fixed per thread)
+ // 1. Load data row by row (should be at least with TILE_SIZE = 512)
+ // 2. quantize data with row/col stats
+ // 3. Store data (TILE_SIZE = 512 is a bit slow, but should still be close enough to good performance)
+
+ // each block loads TILE_COLS columns and TILE_ROWS rows
+ // after reading a tile the row counter increases by TILE_ROWS
+ // the col counter resets after reading TILE_COLS elements
+ const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
+ // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
+ const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
+ const int base_idx = (base_row*cols) + base_col;
+ const int items_per_load = ITEMS_PER_THREAD*THREADS;
+
+ typedef cub::BlockLoad<half, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadHalf;
+ __shared__ typename LoadHalf::TempStorage loadhalf;
+ typedef cub::BlockStore<char, THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_VECTORIZE> StoreInt8;
+ __shared__ typename StoreInt8::TempStorage storeint8;
+
+ __shared__ float smem_row_stats[TILE_ROWS];
+ __shared__ unsigned int smem_nnz_row_idx[TILE_ROWS];
+
+ half local_data[ITEMS_PER_THREAD];
+ float local_col_stats[ITEMS_PER_THREAD];
+ char local_quantized_data[ITEMS_PER_THREAD];
+
+ // 0. Load row stats data into shared memory; load col stat (1 fixed per thread)
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ if(base_col+(threadIdx.x*ITEMS_PER_THREAD) + j < cols)
+ local_col_stats[j] = __fdividef(127.0f, colStats[base_col+(threadIdx.x*ITEMS_PER_THREAD)+j]);
+
+ for(int i = threadIdx.x; i < TILE_ROWS; i+=blockDim.x)
+ {
+ if(base_row + i < rows)
+ smem_row_stats[i] = rowStats[base_row+i];
+
+ if(SPARSE_DECOMP)
+ smem_nnz_row_idx[i] = nnz_block_ptr[(TILE_ROWS*blockIdx.x) + i];
+ }
+ __syncthreads();
+
+ // we load row after row from the base_position
+ // 1. Load data row by row (should be at least with TILE_SIZE = 512)
+ for(int row = 0; row < TILE_ROWS; row++)
+ {
+ if(base_row + row >= rows){ break; }
+ int i = base_idx + (row*cols);
+ int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col;
+
+
+ LoadHalf(loadhalf).Load(&(A[i]), local_data, valid_items, 0.0f);
+ float row_stat = __fdividef(127.0f, smem_row_stats[row]);
+
+ // 2. quantize data with row/col stats
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ {
+ // we already pre-normalized the col/row stat:
+ // what this does is float/absmax*127 = int8
+ if(SPARSE_DECOMP)
+ {
+ if(fabsf((float)local_data[j]) >= threshold)
+ {
+ local_quantized_data[j] = 0;
+
+ int old_idx = atomicInc(&smem_nnz_row_idx[row], UINT_MAX);
+
+ rowidx[old_idx] = base_row+row;
+ colidx[old_idx] = base_col+(threadIdx.x*ITEMS_PER_THREAD)+j;
+ val[old_idx] = local_data[j];
+ }
+ else
+ {
+ local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat));
+ }
+ }
+ else
+ local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat));
+ }
+
+ StoreInt8(storeint8).Store(&(out_row_normed[i]), local_quantized_data, valid_items);
+
+ // 2. quantize data with row/col stats
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ {
+ // we already pre-normalized the col/row stat:
+ // what this does is float/absmax*127 = int8
+ local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*local_col_stats[j]));
+ }
+
+ __syncthreads();
+ StoreInt8(storeint8).Store(&(out_col_normed[i]), local_quantized_data, valid_items);
+
+ }
+}
+
+template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols)
+{
+
+ // 0. Load data into 32*32 shared memory tiles
+ // 1. transpose / reorder in shared memory
+ // 2. store
+
+ // COL32 FORMAT:
+ // rows*32 tiles
+
+ // TURING FORMAT:
+ // 8*32 tiles with 4*4 subtiles
+ // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements)
+ // the subsequent 4*4 subtiles are for all odd rows; if some rows/columns are empty the values are zero
+ // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
+ // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
+ // index increases by 32
+
+ // AMPERE FORMAT:
+ // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows:
+ // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
+ // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32]
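+ // Worked examples for element (row=3, col=5) in the first 32-column tile
+ // (illustrative; derived from the offset arithmetic in the switch below):
+ // COL32: 3*32 + 5 = 101
+ // COL_TURING: (3/8)*256 + 128 (odd row) + (5/4)*16 + (5%4) + ((3%8)-1)*2 = 149
+ // COL_AMPERE: (3/32)*1024 + (((3%8)/2)*8 + (3/8)*2 + (3%2))*32 + 5 = 9*32 + 5 = 293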
+
+
+ // To have efficient loads and stores if we transpose we need 128 consecutive bytes which at 1 byte are 128 values
+ // As such we need:
+ // at least 32*4 shared memory tiles for col32; preferably 32*32
+ // at least 32*6 shared memory tiles for col32_ampere: preferably 32*32
+ // at least 32*8 shared memory tiles for col4_turing: preferably 32*32
+ // for efficient loading of row major we need to load 128 elements and repeat this for 32 rows
+ // this would imply a 32x128 shared memory tile -> 4kb
+ // It is more efficient to have more than 1 warp, so with 64 threads we need 32x128 -> 8 kb
+ // we have 64k shared mem per SM in Turing which is 8 blocks per SM which is 2*8 = 16 warps,
+ // which is probably good enough occupancy on turing as well as A100 and RTX 30s / A40
+ // register pressure should be low with: 8 registers from local memory per block and 64 registers per SM
+ //
+ // to make the shared memory work with that occupancy we might need to union the block loads/stores
+
+ // each block loads TILE_COLS columns and TILE_ROWS rows
+ // after reading a tile the row counter increases by TILE_ROWS
+ // the col counter resets after reading TILE_COLS elements
+ const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
+ // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
+ const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
+ const int base_idx = (base_row*cols) + base_col;
+
+ // we load 128 bytes per warp with
+ // 32 rows for transposes that fill col32 types
+ // so that we can have contiguous stores
+ __shared__ char smem_data[32*33*ITEMS_PER_THREAD];
+ char local_data[ITEMS_PER_THREAD];
+ typedef cub::BlockExchange<char, THREADS, ITEMS_PER_THREAD> BlockExchange;
+ __shared__ typename BlockExchange::TempStorage temp_storage;
+
+ // we load row after row from the base_position
+ // Load data row by row
+ int warps = blockDim.x/32;
+ int warp_id = threadIdx.x/32;
+ int warp_lane = threadIdx.x % 32;
+ int offset = 0;
+
+ int smem_row = 0;
+ // each warp loads one row of 128 bytes
+ for(int row = warp_id; row < TILE_ROWS; row+=warps)
+ {
+ int i = base_idx + (row*cols);
+ // we load up to 128 bytes/items per load
+ int valid_items = cols - base_col > 32*ITEMS_PER_THREAD ? 32*ITEMS_PER_THREAD : cols - base_col;
+
+ // 0. Load data into 32*32 shared memory tiles
+ if(base_row + row < rows)
+ {
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ {
+ int col_idx = warp_lane+(j*32);
+ if(col_idx < valid_items)
+ local_data[j] = A[i+col_idx];
+ else
+ local_data[j] = 0;
+ }
+ }
+ else
+ {
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ local_data[j] = 0;
+ }
+
+ if(TRANSPOSE)
+ {
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ {
+ int local_col = (32*j)+warp_lane;
+ //int local_row = row;
+ // store as 256x32
+ smem_data[(local_col*33) + row] = local_data[j];
+ }
+ }
+ else
+ {
+ // treat smem as 32x256, that is 32 rows and 256 columns
+ #pragma unroll ITEMS_PER_THREAD
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ smem_data[row*32*ITEMS_PER_THREAD + (warp_lane) + (j*32)] = local_data[j];
+ }
+
+
+
+ smem_row += warps;
+
+ // 1. transpose / reorder in shared memory
+ if(smem_row % 32 == 0)
+ {
+ smem_row = 0;
+ __syncthreads();
+
+ for(int subrow = warp_id; subrow < 32; subrow+=warps)
+ {
+ for(int j = 0; j < ITEMS_PER_THREAD; j++)
+ {
+
+ switch(FORMAT)
+ {
+ case COL32:
+ if(TRANSPOSE)
+ {
+ // data lies in shared memory in the following way:
+ // row0 [col0 col1 ... col31]
+ // row1 [col0 col1 ... col31]
+ // ...
+ //
+ // As such we read consecutive entries with 256 threads (8 rows x 32 columns)
+ // as j increases, the row increases by a factor of 8
+ // We load 8 rows per subrow loop, and subrow increases by 8 per loop
+ // so we have an offset of 8 rows every loop or (subrow/warps)*8 = (subrow/8)*8
+ const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
+ const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
+ //const int local_row = warp_id; // each warp_id is one row
+ //const int block_row = base_col; // block offset for row
+ //const int local_col = warp_lane
+ //const int global_col = base_row; // block offset for col
+ if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
+ {
+ // each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem
+ char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
+
+ // each 32 columns we have new tile
+ // each tile has size outRows*32 and base_row is done in increments of 32
+ offset = base_row*outRows;
+ out[offset + (base_col + jrow + subrow_loop_row)*32 + threadIdx.x] = data;
+ }
+ }
+ else
+ {
+ if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
+ {
+ offset = (base_col/32)*(32*rows);
+ char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
+ out[offset+(base_row+subrow)*32 + ((j)*rows*32)+warp_lane] = data;
+ }
+ }
+ break;
+ case COL_TURING:
+ // TURING FORMAT:
+ // 8*32 tiles with 4*4 subtiles
+ // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements)
+ // the subsequent 4*4 subtiles are for all odd rows; if some rows/columns are empty the values are zero
+ // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
+ // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
+ // index increases by 32
+ //
+ // [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...]
+ if(TRANSPOSE)
+ {
+ const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
+ const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
+ //const int local_row = warp_id; // each warp_id is one row
+ //const int block_row = base_col; // block offset for row
+ //const int local_col = warp_lane
+ //const int global_col = base_row; // block offset for col
+ if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
+ {
+ // each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem
+ char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
+
+ // each 32 columns we have new tile
+ // each tile has size 8*32 = 256 elements offset
+ // for each row offset of 8 we increase the tile first
+ // after all rows are exhausted, we increase the col
+ int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/8)*256; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows
+
+ // we increase by row_tile_column every 32 columns
+ // base_row increase in increments of 32
+ //int row_tile_column = 256*outRows/8; // there are outRows/8 row tiles, and each tile is 256 elements
+ //int col_offset = (base_row/32)*row_tile_column;
+ // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
+ // 256*outRows/8*base_row/32 = outRows*base_row
+ int col_offset = outRows*base_row;
+
+ offset = row_offset+col_offset;
+
+ // since we process an even number of rows with each j (8) and with each subrow (8j) we can determine
+ // odd or even rows with the warp_id (each warp processes one row)
+ // the col is warp_lane (max 32 columns per row) and the row warp_id
+ if(warp_id % 2 == 1)
+ // odd
+ offset += 128 + (warp_lane/4)*16 + (warp_lane%4) + (((warp_id%8)-1)*2);
+ else
+ // even
+ offset += 0 + (warp_lane/4)*16 + (warp_lane%4) + ((warp_id%8)*2);
+
+ out[offset] = data;
+ }
+ }
+ else
+ {
+ if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
+ {
+ char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
+ // the offset designates the tile offset among the 8*32 tiles
+ // we first increase rows and then columns. Since we load 128 columns at once
+ // we increase the offset by outRows*32 every 32 columns
+ // additionally, we increase the offset by 8*32=256 every 8 rows
+ offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/8)*256); // global offset (8x32 tile)
+ // first 4 rows are reserved for even rows, [0, 2, 4, 6], the next 4 for odd
+ // each of these has 32 values in total for 32*4 = 128 as offset if odd
+ // every set of 4 columns increases the total offset by 16
+ // each even row increases the offset by 4, for example row 2 is offset by 4, row 4 by 8 etc, so: (subrow/2)*4 = subrow*2
+ // this happens every 8 rows anew (subrow % 8)
+ // one writes 4 columns at once, that is (col % 4) for the particular index in the subtile
+ int subcol = warp_lane;
+
+ // add local offset (4x4 sub-tile)
+ if(subrow % 2 == 1)
+ // odd
+ offset += 128 + (subcol/4)*16 + (subcol%4) + (((subrow%8)-1)*2);
+ else
+ // even
+ offset += 0 + (subcol/4)*16 + (subcol%4) + ((subrow%8)*2);
+
+ out[offset] = data;
+ }
+ }
+ break;
+ case COL_AMPERE:
+ // AMPERE FORMAT:
+ // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows:
+ // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
+ // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32]
+ if(TRANSPOSE)
+ {
+ const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
+ const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
+ //const int local_row = warp_id; // each warp_id is one row
+ //const int block_row = base_col; // block offset for row
+ //const int local_col = warp_lane
+ //const int global_col = base_row; // block offset for col
+ if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
+ {
+ // each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem
+ char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
+
+ // each 32 columns we have new tile
+ // each tile has size 32*32 = 1024 elements offset
+ // for each row offset of 32 we increase the tile first
+ // after all rows are exhausted, we increase the col
+ int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/32)*1024; // global_row+jrow+subrow_loop_row+local_row, increase tile(=1024) every 32 rows
+
+ // we increase by row_tile_column every 32 columns
+ // base_row increase in increments of 32
+ //int row_tile_column = 1024*outRows/32; // there are outRows/32 row tiles, and each tile is 1024 elements
+ //int col_offset = (base_row/32)*row_tile_column;
+ // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
+ // 1024*outRows/32*base_row/32 = outRows*base_row
+ int col_offset = outRows*base_row;
+
+ offset = row_offset+col_offset;
+
+
+ // same as in the non-transpose case (see below)
+ // the difference is that now rows = cols
+ // in this case warp_id = subrow
+
+ // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
+ // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc
+ // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row
+ // every 2 rows, the offset increases by two [0, 1, 8, 9...]
+ // every 2 rows, the row index increase by 8 [0, 1, 8, 9...]
+ int local_row = (jrow + warp_id) % 32; // offset for row > 32 is already calculated into row_offset
+ int ampere_row = ((local_row % 8)/2)*8 + (local_row/8)*2 + (local_row % 2);
+
+ // global offset + row with 32 cols each + 32 cols per j + col_idx=warp_lane
+ out[offset + (ampere_row*32) + warp_lane] = data;
+ }
+ }
+ else
+ {
+ if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
+ {
+ char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
+
+ // the offset designates the tile offset among the 32*32 tiles
+ // we first increase rows and then columns. Since we load 128 columns at once
+ // we increase the offset by outRows*32 every 32 columns
+ // additionally, we increase the offset by 32*32=1024 every 32 rows
+ offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/32)*1024); // global offset (32x32 tile)
+
+ // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
+ // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc
+ // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row
+ // every 2 rows, the offset increases by two [0, 1, 8, 9...]
+ // every 2 rows, the row index increase by 8 [0, 1, 8, 9...]
+ int local_row = ((subrow % 8)/2)*8 + (subrow/8)*2 + (subrow % 2);
+
+ // global offset + row with 32 cols each + 32 cols per j + col_idx
+ out[offset + (local_row*32) + warp_lane] = data;
+ }
+ }
+ break;
+ }
+ }
+ }
+ }
+ }
+}
+
+#define C (1.0f/127.0f)
+#define MAX_SPARSE_COUNT 32
+#define SMEM_SIZE (8*256)
+template <typename T, int SPMM_ITEMS, int BITS>
+__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
+{
+
+ // 0. load balancing: We process the rows with the most columns first (count_vec) and we process one row per block
+ // If a block finishes, the next one is scheduled. Since the last blocks likely have fewer
+ // elements they finish faster, "filling up" the gaps left by larger blocks
+
+ // without tensor cores
+ // 1. use rowidx_length to find what to load (as many blocks as there are rows)
+ // 2. Load A into registers
+ // 3. each warp loads all required rows of B but each warp is offset by k
+ // 4. Do mma operations that accumulate into registers
+ // 5. Each warp stores its output row into matrix C
+
+ const int count = max_count[blockIdx.x];
+ const int local_max_idx = max_idx[blockIdx.x];
+ const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1];
+ const int local_row_idx = rowidx[offset];
+
+ const int warp_id = threadIdx.x / 32;
+ const int warp_idx = threadIdx.x % 32;
+ const int warp_offset = (warp_id*32)*SPMM_ITEMS;
+ const int num_items = 8; // 8 values per 128-bit load for both int8 (float2) and fp16 (float4)
+ int idx_col_B = warp_offset;
+ int local_idx_col_B_offset = 0;
+
+ half local_valA[MAX_SPARSE_COUNT];
+ int local_colidxA[MAX_SPARSE_COUNT];
+ half local_valC[SPMM_ITEMS];
+ T local_valsB[num_items];
+ half local_valOut[num_items];
+ // 128 byte loads per warp == 4 bytes per thread
+
+ // 2. Load A into registers
+ for(int j = 0; j < MAX_SPARSE_COUNT; j++)
+ {
+ local_valA[j] = j < count ? values[offset+j] : __float2half(0.0f);
+ local_colidxA[j] = j < count ? colidx[offset+j] : 0;
+ }
+
+ // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=8192
+ // we expect each warp to be SPMM_ITEMS*32 apart
+ // we have a total of 128 bytes for the bank with a bank size of 4 bytes
+ // added 3 bytes = 6 values between warps should reduce bank conflicts
+ __shared__ half smem_dequant_stats[SMEM_SIZE];
+
+
+ while(idx_col_B < colsB)
+ {
+
+ if(dequant_stats != NULL)
+ {
+ for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x)
+ if((idx_col_B+i-local_idx_col_B_offset) < colsB)
+ smem_dequant_stats[i] = __ldg(&dequant_stats[idx_col_B+i-local_idx_col_B_offset]);
+
+ __syncthreads();
+ }
+
+ #pragma unroll SPMM_ITEMS
+ for(int j = 0; j < SPMM_ITEMS; j++)
+ local_valC[j] = 0.0f;
+
+ #pragma unroll
+ for(int i = 0; i < count; i++)
+ {
+ // 3. each warp loads all required rows of B but each warp is offset by k
+ int row_offset = colsB*local_colidxA[i];
+
+ #pragma unroll SPMM_ITEMS
+ for(int j = 0; j < SPMM_ITEMS; j+=num_items)
+ {
+ // 4. Multiply the tile -> accumulate outputs in registers until 128 bytes are reached
+ int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j;
+ if(idx >= colsB){ break; }
+ //printf("%i %i\n", (row_offset+idx) % num_items, row_offset+idx);
+ if((idx+num_items < colsB))
+ {
+ if(BITS == 8)
+ reinterpret_cast<float2(&)[num_items]>(local_valsB)[0] = reinterpret_cast<float2*>(B)[(row_offset+ idx)/num_items];
+ else
+ reinterpret_cast<float4(&)[num_items]>(local_valsB)[0] = reinterpret_cast<float4*>(B)[(row_offset+ idx)/num_items];
+ }
+ else
+ {
+ #pragma unroll num_items
+ for(int k = 0; k < num_items; k++)
+ if(idx+k < colsB)
+ local_valsB[k] = B[row_offset+idx+k];
+ else
+ local_valsB[k] = 0.0f;
+ }
+ #pragma unroll num_items
+ for(int k = 0; k < num_items; k++)
+ {
+ //if((float)local_valsB[k] != 0.0)
+ // printf("%f %i %i %i\n", (float)local_valsB[k], k, idx, colsB);
+ if(BITS == 8 && dequant_stats != NULL)
+ // we do texture cache reads (__ldg) on dequant_stats which should be super fast
+ {
+ float valB = local_valsB[k];
+ float valA = local_valA[i];
+ if(valB != 0.0 && valA != 0.0)
+ local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*C*valB*valA;
+ }
+ else
+ local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i];
+ }
+ }
+ }
+
+ int idx_row_C = (colsB*local_row_idx);
+
+ #pragma unroll SPMM_ITEMS
+ for(int j = 0; j < SPMM_ITEMS; j+=num_items)
+ {
+ //int idx_col_C = idx_col_B + (32*j) + warp_idx;
+ int idx_col_C = idx_col_B + warp_idx*SPMM_ITEMS + j;
+ int idx_val = idx_col_C + idx_row_C;
+
+ if(idx_col_C +num_items < colsB)
+ {
+
+ // load outputs to do inplace addition
+ reinterpret_cast<float4(&)[num_items/4]>(local_valOut)[0] = reinterpret_cast<float4*>(out)[idx_val/num_items];
+
+ #pragma unroll num_items
+ for(int k = 0; k < num_items; k++)
+ local_valC[j + k] = (float)local_valC[j + k] + (float)local_valOut[k]; // index j+k to match the float4 store below
+
+ reinterpret_cast<float4*>(out)[idx_val/num_items] = reinterpret_cast<float4(&)[num_items]>(local_valC)[j/num_items];
+ }
+ else
+ {
+ #pragma unroll num_items
+ for(int k = 0; k < num_items; k++)
+ if(idx_col_C + k < colsB)
+ out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k];
+ }
+ }
+
+ idx_col_B += blockDim.x*SPMM_ITEMS;
+ local_idx_col_B_offset += blockDim.x*SPMM_ITEMS;
+ }
+}
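+
+// Note on the 8-bit path above (a restatement of the math, not new behavior): with
+// per-column absmax stats, B_fp ~= B_int8 * dequant_stat / 127, so each accumulated
+// product is valA * valB * smem_dequant_stats[...] * C, with C = 1/127 defined above.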
+
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
+template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+
+template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+
+template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n);
+
+template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
+template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
+
template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x);
template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x);
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index 0a3676c..cbfbeba 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -106,6 +106,18 @@ template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileCl
__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n);
+
+template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+
+template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(
+ int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats,
+ half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n);
+
+template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
+template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
+
+template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+
#endif
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 40c185c..8946015 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -8,6 +8,7 @@
#include <cub/device/device_scan.cuh>
#include <limits>
#include <BinSearch.h>
+#include <cassert>
#include <common.h>
@@ -188,11 +189,416 @@ template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step,
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
+void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
+{
+ const int falpha = 1;
+ const int fbeta = 0;
+ const void * alpha = &falpha;
+ const void * beta = &fbeta;
+ cublasStatus_t status;
+
+ status = cublasGemmEx(context->m_handle,
+ transposeA ? CUBLAS_OP_T : CUBLAS_OP_N,
+ transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
+ m, n, k,
+ alpha, A, CUDA_R_8I, lda, B, CUDA_R_8I, ldb, beta,
+ C, CUDA_R_32I, ldc,
+ CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+
+ if (status != CUBLAS_STATUS_SUCCESS)
+ {
+ std::cout << "CUBLAS ERROR: Status " << status << std::endl;
+ }
+
+}
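+
+// Call sketch (the argument values are assumptions): computes C_int32 = op(A_int8) x op(B_int8)
+// with int32 accumulation; alpha=1 and beta=0 are ints to match the CUDA_R_32I compute type, e.g.
+//   gemmex(context, false, false, m, n, k, A, B, C, lda, ldb, ldc);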
+
+void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
+ long long int strideA, long long int strideB, long long int strideC, int batchCount)
+{
+ const int falpha = 1;
+ const int fbeta = 0;
+ const void * alpha = &falpha;
+ const void * beta = &fbeta;
+ cublasStatus_t status;
+
+ //cout << transposeA << transposeB << endl;
+ //printf("%i %i %i\n", m,n,k);
+ //printf("%i %i %i\n", lda,ldb,ldc);
+ //printf("%i %i %i\n", strideA, strideB, strideC);
+ //printf("%i\n", batchCount);
+
+ status = cublasGemmStridedBatchedEx(context->m_handle,
+ transposeA ? CUBLAS_OP_T : CUBLAS_OP_N,
+ transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
+ m, n, k,
+ alpha, A, CUDA_R_8I, lda, (long long int)strideA, B, CUDA_R_8I, ldb, (long long int)strideB, beta,
+ C, CUDA_R_32I, ldc, (long long int)strideC, batchCount,
+ CUDA_R_32I, CUBLAS_GEMM_DEFAULT);
+
+ if (status != CUBLAS_STATUS_SUCCESS)
+ {
+ std::cout << "CUBLAS ERROR: Status " << status << std::endl;
+ }
+
+}
+
+int roundoff(int v, int d) {
+ return (v + d - 1) / d * d;
+}
+
+
+template<int ORDER> cublasLtOrder_t get_order()
+{
+ switch(ORDER)
+ {
+ case ROW:
+ return CUBLASLT_ORDER_ROW;
+ break;
+ case COL:
+ return CUBLASLT_ORDER_COL;
+ break;
+ case COL32:
+ return CUBLASLT_ORDER_COL32;
+ break;
+ case COL_TURING:
+ return CUBLASLT_ORDER_COL4_4R2_8C;
+ break;
+ case COL_AMPERE:
+ return CUBLASLT_ORDER_COL32_2R_4R4;
+ break;
+ }
+}
+
+template cublasLtOrder_t get_order<ROW>();
+template cublasLtOrder_t get_order<COL>();
+template cublasLtOrder_t get_order<COL32>();
+template cublasLtOrder_t get_order<COL_TURING>();
+template cublasLtOrder_t get_order<COL_AMPERE>();
+
+
+template<int ORDER> int get_leading_dim(int dim1, int dim2)
+{
+ switch(ORDER)
+ {
+ case ROW:
+ return dim2;
+ break;
+ case COL:
+ return dim1;
+ break;
+ case COL32:
+ // 32*row tiles
+ return dim1*32;
+ break;
+ case COL_TURING:
+ return 32*roundoff(dim1, 8);
+ break;
+ case COL_AMPERE:
+ // 32*32 tiles
+ return 32*roundoff(dim1, 32);
+ break;
+ }
+}
+
+template int get_leading_dim<ROW>(int dim1, int dim2);
+template int get_leading_dim<COL>(int dim1, int dim2);
+template int get_leading_dim<COL32>(int dim1, int dim2);
+
+template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
+{
+
+ cublasLtOrder_t orderA = get_order<SRC>();
+ cublasLtOrder_t orderOut = get_order<TARGET>();
+ int ldA = get_leading_dim<SRC>(dim1, dim2);
+ int ldOut = get_leading_dim<TARGET>(dim1, dim2);
+
+ cublasLtMatrixLayout_t A_desc = NULL, out_desc = NULL;
+ cublasLtMatrixTransformDesc_t A2Out_desc = NULL;
+ cublasOperation_t opTranspose = CUBLAS_OP_T;
+ float transformAlpha = 1.0f, transformBeta = 0.0f;
+
+
+ if(DTYPE == 8)
+ {
+ checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_8I, dim1, dim2, ldA));
+ checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_8I, dim1, dim2, ldOut));
+ }
+ else if(DTYPE == 32)
+ {
+ checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_32I, dim1, dim2, ldA));
+ checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_32I, dim1, dim2, ldOut));
+ }
+ else
+ {
+ printf("ERROR WRONG TYPE FOR TRANSFORM: %i\n", DTYPE);
+ return; // the layout descriptors were never created; do not fall through and use them
+ }
+
+ checkCublasStatus(cublasLtMatrixLayoutSetAttribute(A_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA)));
+ checkCublasStatus(cublasLtMatrixLayoutSetAttribute(out_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut)));
+
+ checkCublasStatus(cublasLtMatrixTransformDescCreate(&A2Out_desc, CUDA_R_32F));
+
+ if(transpose){ checkCublasStatus(cublasLtMatrixTransformDescSetAttribute(A2Out_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); }
+
+ checkCublasStatus(cublasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, NULL, NULL, out, out_desc, 0));
+
+ if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc));
+ if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc));
+ if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc));
+}
+
+template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int8_t, ROW, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int8_t, ROW, COL32, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int32_t, ROW, COL32, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
+template void transform<int8_t, ROW, COL_TURING, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int8_t, ROW, COL_AMPERE, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int8_t, COL32, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
+template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
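+
+// Usage sketch (dim values are assumptions): transform<int8_t, ROW, COL32, false, 8>(ltHandle, A, out, dim1, dim2)
+// re-tiles a row-major dim1 x dim2 int8 matrix into the COL32 layout that igemmlt below expects.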
+
+template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+{
+ int has_error = 0;
+ cublasLtMatmulDesc_t matmulDesc = NULL;
+ cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
+ cublasOperation_t opT = CUBLAS_OP_T;
+ cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO;
+ cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
+ cublasLtOrder_t col_turing = CUBLASLT_ORDER_COL4_4R2_8C;
+ cublasLtOrder_t col_ampere = CUBLASLT_ORDER_COL32_2R_4R4;
+
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, m, k, lda));
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, n, k, ldb));
+
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
+ if(FORMATB == COL_TURING)
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing)));
+ else
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere)));
+
+ if(DTYPE_OUT == 32)
+ {
+ has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32I));
+ has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc));
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
+ int alpha = 1, beta = 0;
+ has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, NULL, NULL, 0, 0));
+ }
+ else
+ {
+ has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32F));
+ has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_8I, m, n, ldc));
+ has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
+ if(!SCALE_ROWS)
+ {
+ float alpha = 1.0f, beta = 0.0f;
+ has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
+ }
+ else
+ {
+ has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec)));
+ has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, NULL, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
+ }
+ }
+
+
+ if (Cdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Cdesc));
+ if (Bdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Bdesc));
+ if (Adesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Adesc));
+ if (matmulDesc) has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc));
+ if(has_error == 1)
+ printf("error detected\n");
+
+ return has_error;
+}
+
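+// Call sketch (an illustration, not exercised here): A and C are expected in COL32 layout
+// and B in the layout selected by FORMATB, e.g. for int32 output with Turing weights:
+//   igemmlt<COL_TURING, 32, 0>(ltHandle, m, n, k, A_col32, B_turing, C_col32, NULL, lda, ldb, ldc);
+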
+int fill_up_to_nearest_multiple(int value, int multiple)
+{
+ return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple)));
+}
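+// e.g. fill_up_to_nearest_multiple(1027, 32) == 1056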
+
+void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols)
+{
+ int threads = 512;
+ int tileCols = fill_up_to_nearest_multiple(numCols, 32);
+ int n = numRows*tileCols;
+ int subtile_rows = 128;
+ int tilesize = 32*subtile_rows;
+ int num_blocks = numRows/subtile_rows;
+ num_blocks += (numRows % subtile_rows == 0) ? 0 : 1;
+ num_blocks = num_blocks*(tileCols/32);
+ assert(threads <= tilesize);
+
+ //cout << num_blocks << " blocks" << endl;
+
+ kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
+
+#define STATS_THREADS 64
+#define STATS_ITEMS 4
+#define STATS_ROWS 16
+void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
+{
+ int tile_cols = STATS_THREADS*STATS_ITEMS;
+ int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
+ int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
+ int num_blocks = (tiledCols/tile_cols) * (tiledRows/STATS_ROWS);
+
+ if(nnz_threshold == 0.0)
+ kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
+ else
+ kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 1><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+
+}
+
+void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols)
+{
+ int threads = 64;
+ int items_per_thread = 4;
+ int tile_cols = threads*items_per_thread;
+ int tile_rows = 16;
+ int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
+ int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
+ int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
+
+ //cout << cols << " " << tiledCols << " " << tiledRows << endl;
+ //cout << "num blocks " << num_blocks << endl;
+
+ //cout << A << " " << out_col_normed << endl;
+ if(threshold > 0.0f)
+ kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
+ else
+ kDoubleRowColQuant<64, 4, 16, 64*4, 0><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
+
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
+
+template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols)
+{
+ int threads = 256;
+ int items_per_thread = 8;
+ // we load 128 column values per warp
+ int tile_cols = 32*items_per_thread;
+ int tile_rows = 32;
+ int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
+ int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
+ int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
+ int outCols = fill_up_to_nearest_multiple(cols, 32);
+ int outRows = fill_up_to_nearest_multiple(rows, 32);
+ if(FORMAT == COL_TURING)
+ {
+ if(TRANSPOSE)
+ outRows = fill_up_to_nearest_multiple(cols, 8);
+ else
+ outRows = fill_up_to_nearest_multiple(rows, 8);
+ }
+ else if(FORMAT == COL_AMPERE)
+ {
+ if(TRANSPOSE)
+ outRows = fill_up_to_nearest_multiple(cols, 32);
+ else
+ outRows = fill_up_to_nearest_multiple(rows, 32);
+ }
+ else
+ {
+ if(TRANSPOSE)
+ {
+ outCols = fill_up_to_nearest_multiple(rows, 32);
+ outRows = cols;
+ }
+ }
+
+ kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
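
For the non-transposed cases, the output buffer sizes follow directly from the outRows/outCols computation above: COL32 pads columns to a multiple of 32, COL_TURING uses 8x32 tiles, COL_AMPERE 32x32 tiles. A sketch of the implied padded element counts:

    def padded_numel(rows, cols, fmt):
        up = lambda v, m: ((v + m - 1) // m) * m  # round up to a multiple
        if fmt == 'col32':
            return up(rows, 32) * up(cols, 32)    # rows and cols padded to 32
        if fmt == 'col_turing':
            return up(rows, 8) * up(cols, 32)     # 8x32 tiles
        if fmt == 'col_ampere':
            return up(rows, 32) * up(cols, 32)    # 32x32 tiles
        raise ValueError(fmt)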
+
+void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
+{
+
+ cusparseSpMatDescr_t descA;
+ cusparseDnMatDescr_t descB, descC;
+
+ float alpha = 1.0f;
+ float beta = 0.0f;
+ void *dBuffer = NULL;
+ size_t bufferSize = 0;
+
+ CHECK_CUSPARSE( cusparseCreateCoo(&descA, A_rows, A_cols, A_nnz,
+ A_rowidx, A_colidx, A_vals,
+ CUSPARSE_INDEX_32I,
+ CUSPARSE_INDEX_BASE_ZERO, CUDA_R_16F) );
+ // Create dense matrix C
+ CHECK_CUSPARSE( cusparseCreateDnMat(&descC, A_rows, B_cols, ldc, C,
+ CUDA_R_16F, CUSPARSE_ORDER_ROW) );
+ // Create dense matrix B
+ if(transposed_B)
+ {
+ int tmp = A_cols;
+ A_cols = B_cols;
+ B_cols = tmp;
+ }
+
+ CHECK_CUSPARSE( cusparseCreateDnMat(&descB, A_cols, B_cols, ldb, B,
+ CUDA_R_16F, CUSPARSE_ORDER_ROW) );
+ // allocate an external buffer if needed
+ CHECK_CUSPARSE( cusparseSpMM_bufferSize(
+ handle,
+ CUSPARSE_OPERATION_NON_TRANSPOSE,
+ transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE,
+ &alpha, descA, descB, &beta, descC, CUDA_R_32F,
+ CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize) );
+ CUDA_CHECK_RETURN( cudaMalloc(&dBuffer, bufferSize) );
+
+ // execute SpMM
+ CHECK_CUSPARSE( cusparseSpMM(handle,
+ CUSPARSE_OPERATION_NON_TRANSPOSE,
+ transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE,
+ &alpha, descA, descB, &beta, descC, CUDA_R_32F,
+ CUSPARSE_SPMM_ALG_DEFAULT, dBuffer));
+
+ // destroy matrix/vector descriptors
+ CHECK_CUSPARSE( cusparseDestroySpMat(descA) );
+ CHECK_CUSPARSE( cusparseDestroyDnMat(descB) );
+ CHECK_CUSPARSE( cusparseDestroyDnMat(descC) );
+ CUDA_CHECK_RETURN( cudaFree(dBuffer) );
+}
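
The call above computes C = A_sparse x B (optionally with B transposed) using fp16 inputs/outputs and fp32 accumulation. A dense-equivalence sketch in PyTorch, assuming the same 32-bit COO indices as the descriptor above (widened to int64 for torch):

    import torch

    def spmm_coo_ref(rowidx, colidx, vals, A_shape, B, transposed_B=False):
        indices = torch.stack([rowidx, colidx]).long()
        A = torch.sparse_coo_tensor(indices, vals.float(), size=A_shape)
        Bm = (B.t() if transposed_B else B).float()
        # fp32 accumulate, fp16 result: mirrors CUDA_R_16F in/out with CUDA_R_32F compute
        return torch.sparse.mm(A, Bm).half()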
+
+template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
+{
+
+ kspmm_coo_very_sparse_naive<T, 8, BITS><<<nnz_rows, 256>>>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB);
+ CUDA_CHECK_RETURN(cudaPeekAtLastError());
+}
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
+template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
+template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
+
+template int igemmlt<COL_TURING, 32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+template int igemmlt<COL_TURING, 8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+template int igemmlt<COL_TURING, 8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+template int igemmlt<COL_AMPERE, 32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+template int igemmlt<COL_AMPERE, 8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+template int igemmlt<COL_AMPERE, 8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+
+template void transformRowToFormat<COL32, 0>(char * A, char *out, int rows, int cols);
+template void transformRowToFormat<COL32, 1>(char * A, char *out, int rows, int cols);
+template void transformRowToFormat<COL_TURING, 0>(char * A, char *out, int rows, int cols);
+template void transformRowToFormat<COL_TURING, 1>(char * A, char *out, int rows, int cols);
+template void transformRowToFormat<COL_AMPERE, 0>(char * A, char *out, int rows, int cols);
+template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows, int cols);
+
template void estimateQuantiles(half *A, float *code, float offset, int n);
template void estimateQuantiles(float *A, float *code, float offset, int n);
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index 8fb4cec..4e719df 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -14,6 +14,11 @@
#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
+#include <cublas_v2.h>
+#include <cublasLt.h>
+#include <cusparse.h>
+#include <vector>
+#include <functional>
#define CUDA_CHECK_RETURN(value) { \
cudaError_t _m_cudaStat = value; \
@@ -25,6 +30,34 @@
#define THREADS_PER_BLOCKS (512)
+#define CHECK_CUSPARSE(value) { \
+ cusparseStatus_t _m_cudaStat = value; \
+ if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) { \
+ fprintf(stderr, "Error %s at line %d in file %s\n", \
+ cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
+ exit(1); \
+ } }
+
+
+inline void checkCudaStatus(cudaError_t status) {
+ if (status != cudaSuccess) {
+ printf("cuda API failed with status %d: %s\n", status, cudaGetErrorString(status));
+ throw std::logic_error("cuda API failed");
+ }
+}
+
+inline int checkCublasStatus(cublasStatus_t status) {
+ if (status != CUBLAS_STATUS_SUCCESS) {
+ printf("cuBLAS API failed with status %d\n", status);
+ //throw std::logic_error("cuBLAS API failed");
+ return 1;
+ }
+ return 0;
+}
+
typedef enum Operations_t
{
ksmul = 0,
@@ -39,6 +72,57 @@ typedef enum Optimizer_t
ADAGRAD = 4,
} Optimizer_t;
+typedef enum Transform_t
+{
+ ROW = 0,
+ COL = 1,
+ COL32 = 2,
+ COL_TURING = 3,
+ COL_AMPERE = 4,
+} Transform_t;
+
+class Context
+{
+ public:
+ cublasHandle_t m_handle;
+
+ Context()
+ {
+ cublasHandle_t handle;
+ cublasCreate_v2(&handle);
+ m_handle = handle;
+ }
+
+};
+
+class ContextLt
+{
+ public:
+ cublasLtHandle_t m_handle;
+
+ ContextLt()
+ {
+ cublasLtHandle_t handle;
+ cublasLtCreate(&handle);
+ m_handle = handle;
+ }
+
+};
+
+class ContextCusparse
+{
+ public:
+ cusparseHandle_t m_handle;
+
+ ContextCusparse()
+ {
+ cusparseHandle_t handle;
+ cusparseCreate(&handle);
+ m_handle = handle;
+ }
+
+};
+
template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n);
@@ -70,4 +154,24 @@ template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step,
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n);
+void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
+void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
+ long long int strideA, long long int strideB, long long int strideC, int batchCount);
+
+
+template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
+
+template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2);
+void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
+void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols);
+void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols);
+void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed,
+ int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols);
+
+template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols);
+
+void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B);
+
+template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
+
#endif
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index c2fed6b..03c8d92 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -84,6 +84,52 @@ void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
#endif
+#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
+void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
+{ \
+ transform<dtype, src, target, transpose, bits>(ltHandle, A, out, dim1, dim2); \
+} \
+
+MAKE_FUNC_TRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8);
+MAKE_FUNC_TRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8);
+MAKE_FUNC_TRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8);
+MAKE_FUNC_TRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32);
+MAKE_FUNC_TRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8);
+MAKE_FUNC_TRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8);
+MAKE_FUNC_TRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8);
+MAKE_FUNC_TRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32);
+
+void transform_row2col32(char * A, char *out, int rows, int cols){ transformRowToFormat<COL32, 0>(A, out, rows, cols); }
+void transform_row2col32T(char * A, char *out, int rows, int cols){ transformRowToFormat<COL32, 1>(A, out, rows, cols); }
+void transform_row2turing(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_TURING, 0>(A, out, rows, cols); }
+void transform_row2turingT(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_TURING, 1>(A, out, rows, cols); }
+void transform_row2ampere(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 0>(A, out, rows, cols); }
+void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 1>(A, out, rows, cols); }
+
+ int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+ { return igemmlt<COL_TURING, 32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+
+ int igemmlt_turing_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+ { return igemmlt<COL_TURING, 8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+
+ int igemmlt_turing_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+ { return igemmlt<COL_TURING, 8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+
+ int igemmlt_ampere_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+ { return igemmlt<COL_AMPERE, 32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+
+ int igemmlt_ampere_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+ { return igemmlt<COL_AMPERE, 8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+
+ int igemmlt_ampere_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+ { return igemmlt<COL_AMPERE, 8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+
+void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
+{ spmm_coo_very_sparse_naive<half, 16>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
+
+void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
+{ spmm_coo_very_sparse_naive<signed char, 8>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
+
extern "C"
{
#if BUILD_CUDA
@@ -155,7 +201,86 @@ extern "C"
void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); }
- #endif
+ void cigemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
+ { gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc); }
+ void cbatched_igemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
+ long strideA, long strideB, long strideC, int batchCount)
+ { strided_gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc, strideA, strideB, strideC, batchCount); }
+
+ Context *get_context(){ return new Context(); }
+ ContextCusparse *get_cusparse(){ return new ContextCusparse(); }
+
+ int cigemmlt_turing_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+ { return igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+
+ int cigemmlt_turing_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+ { return igemmlt_turing_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+
+ int cigemmlt_turing_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+ { return igemmlt_turing_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+
+ int cigemmlt_ampere_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+ { return igemmlt_ampere_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+
+ int cigemmlt_ampere_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+ { return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+
+ int cigemmlt_ampere_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
+	{ return igemmlt_ampere_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
+
+ #define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
+ void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \
+ { \
+ transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose((cublasLtHandle_t) context->m_handle, A, out, dim1, dim2); \
+ } \
+
+ MAKE_FUNC_CTRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8)
+ MAKE_FUNC_CTRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8)
+ MAKE_FUNC_CTRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8)
+ MAKE_FUNC_CTRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32)
+ MAKE_FUNC_CTRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8)
+ MAKE_FUNC_CTRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8)
+ MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8)
+ MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32)
+
+ void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols)
+ { dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols); }
+ void cget_col_row_stats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
+ { getColRowStats(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols); }
+
+ void cdouble_rowcol_quant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_row_ptr, float threshold, int rows, int cols)
+ { doubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_row_ptr, threshold, rows, cols); }
+
+ void ctransform_row2col32(char * A, char *out, int rows, int cols)
+ { transform_row2col32(A, out, rows, cols); }
+
+ void ctransform_row2col32T(char * A, char *out, int rows, int cols)
+ { transform_row2col32T(A, out, rows, cols); }
+
+ void ctransform_row2turing(char * A, char *out, int rows, int cols)
+ { transform_row2turing(A, out, rows, cols); }
+
+ void ctransform_row2turingT(char * A, char *out, int rows, int cols)
+ { transform_row2turingT(A, out, rows, cols); }
+
+ void ctransform_row2ampere(char * A, char *out, int rows, int cols)
+ { transform_row2ampere(A, out, rows, cols); }
+
+ void ctransform_row2ampereT(char * A, char *out, int rows, int cols)
+ { transform_row2ampereT(A, out, rows, cols); }
+
+ void cspmm_coo(ContextCusparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
+ { spmm_coo((cusparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); }
+
+ void cspmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
+ { spmm_coo_very_sparse_naive_fp16(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
+
+ void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
+ { spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
+
+#endif
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); }
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); }
}
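
These extern "C" symbols are what the Python side binds through ctypes. A minimal, hypothetical sketch of driving one of them from Python; the library filename, tensor layouts, and leading-dimension values here are assumptions, not the library's actual binding code:

    import ctypes as ct
    import torch

    lib = ct.cdll.LoadLibrary('libbitsandbytes.so')  # assumed filename
    lib.get_context.restype = ct.c_void_p            # opaque Context* owning a cuBLAS handle
    ctx = lib.get_context()

    m = n = k = 32
    A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device='cuda')
    B = torch.randint(-128, 127, (n, k), dtype=torch.int8, device='cuda')
    C = torch.empty(m, n, dtype=torch.int32, device='cuda')
    # raw device pointers; A/B must already be in the col32/col_turing layouts
    err = lib.cigemmlt_turing_32(ct.c_void_p(ctx), m, n, k,
                                 ct.c_void_p(A.data_ptr()), ct.c_void_p(B.data_ptr()),
                                 ct.c_void_p(C.data_ptr()), None, m, n, m)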
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
new file mode 100644
index 0000000..d2b5d59
--- /dev/null
+++ b/tests/test_autograd.py
@@ -0,0 +1,270 @@
+import pytest
+
+import torch
+import bitsandbytes as bnb
+
+from itertools import product
+
+n = 1
+k = 25
+dim1 = torch.randint(16,64, size=(n,)).tolist()
+dim2 = torch.randint(32,96, size=(n,)).tolist()
+dim3 = torch.randint(32,96, size=(n,)).tolist()
+dim4 = torch.randint(32,96, size=(n,)).tolist()
+funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
+str_funcs = ['bmm', 'matmul']
+req_grad = [(False, False), (True, False), (True, True), (False, True)]
+req_grad_str = ['FF', 'TF', 'TT', 'FT']
+transpose = [(False, False), (False, True), (True, True), (True, False)]
+str_transpose = ['FF', 'FT', 'TT', 'TF']
+dtype = [torch.float32, torch.float16]
+values = list(product(dim1,dim2,dim3,dim4,funcs, dtype, req_grad, transpose))
+str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, str_transpose))
+names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}'.format(*vals) for vals in str_values]
+@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
+def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
+ dim2 = dim2 - (dim2 % 16)
+ dim3 = dim3 - (dim3 % 16)
+ dim4 = dim4 - (dim4 % 16)
+ for i in range(k):
+
+ # normal multiply
+ if funcs[0] in [torch.mm, torch.matmul]:
+ dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
+ dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
+ A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0])
+ B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1])
+ target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1])
+ torch.nn.init.xavier_uniform_(B)
+
+ if not transpose[0] and not transpose[1]:
+ out_torch = funcs[0](A, B)
+ out_bnb = funcs[1](A, B)
+ elif not transpose[0] and transpose[1]:
+ out_torch = funcs[0](A, B.t())
+ out_bnb = funcs[1](A, B.t())
+ elif transpose[0] and not transpose[1]:
+ out_torch = funcs[0](A.t(), B)
+ out_bnb = funcs[1](A.t(), B)
+ elif transpose[0] and transpose[1]:
+ out_torch = funcs[0](A.t(), B.t())
+ out_bnb = funcs[1](A.t(), B.t())
+
+ n = out_bnb.numel()
+ idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
+ assert (idx==0).sum().item() < n*0.0175
+ idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
+ assert (idx==0).sum().item() < n*0.001
+
+ if any(req_grad):
+ out_bnb.data.copy_(out_torch)
+ torch.cuda.synchronize()
+ loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+ loss_bnb.backward()
+ gradA1 = A.grad
+ gradB1 = B.grad
+ A.grad = None
+ B.grad = None
+
+ loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch.backward()
+ gradA2 = A.grad
+ gradB2 = B.grad
+ A.grad = None
+ B.grad = None
+
+ if req_grad[0]:
+ torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ if req_grad[1]:
+ n = gradB1.numel()
+ idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
+ assert (idx==0).sum().item() < n*0.1
+ idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
+ assert (idx==0).sum().item() < n*0.02
+ torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
+
+ # batched matrix multiply
+ if funcs[0] in [torch.bmm, torch.matmul]:
+ A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0])
+ B = torch.randn(size=(dim1, dim3, dim4), device='cuda', requires_grad=req_grad[1])
+ target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1])
+ torch.nn.init.xavier_uniform_(B)
+
+ out_torch = funcs[0](A, B)
+ out_bnb = funcs[1](A, B)
+
+ n = out_bnb.numel()
+ idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
+ assert (idx==0).sum().item() < n*0.01
+ torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2)
+
+ if any(req_grad):
+ out_bnb.data.copy_(out_torch)
+ torch.cuda.synchronize()
+ loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+ loss_bnb.backward()
+ gradA1 = A.grad
+ gradB1 = B.grad
+ A.grad = None
+ B.grad = None
+
+ loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch.backward()
+ gradA2 = A.grad
+ gradB2 = B.grad
+ A.grad = None
+ B.grad = None
+
+ if req_grad[0]:
+ torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ if req_grad[1]:
+ n = gradB1.numel()
+ idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
+ assert (idx==0).sum().item() < n*0.1
+ idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
+ assert (idx==0).sum().item() < n*0.02
+
+ if funcs[0] in [torch.matmul]:
+ dim1 = dim1 - (dim1 % 16)
+ A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0])
+ dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
+ B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1])
+ target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1])
+ torch.nn.init.xavier_uniform_(B)
+
+ if transpose[1]:
+ out_torch = funcs[0](A, B.t())
+ out_bnb = funcs[1](A, B.t())
+ else:
+ out_torch = funcs[0](A, B)
+ out_bnb = funcs[1](A, B)
+
+ n = out_bnb.numel()
+ idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
+ assert (idx==0).sum().item() < n*0.0175
+ idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
+ assert (idx==0).sum().item() < n*0.001
+
+ if any(req_grad):
+ out_bnb.data.copy_(out_torch)
+ torch.cuda.synchronize()
+ loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+ loss_bnb.backward()
+ gradA1 = A.grad
+ gradB1 = B.grad
+ A.grad = None
+ B.grad = None
+
+ loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch.backward()
+ gradA2 = A.grad
+ gradB2 = B.grad
+ A.grad = None
+ B.grad = None
+
+ if req_grad[0]:
+ torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ if req_grad[1]:
+ n = gradB1.numel()
+ idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
+ assert (idx==0).sum().item() < n*0.1
+ idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
+ assert (idx==0).sum().item() < n*0.02
+
+
+n = 1
+k = 3
+dim1 = torch.randint(16,64, size=(n,)).tolist()
+dim2 = torch.randint(32,96, size=(n,)).tolist()
+dim3 = torch.randint(32,96, size=(n,)).tolist()
+dim4 = torch.randint(32,96, size=(n,)).tolist()
+
+#dim1 = (17,)
+#dim2 = (7,)
+#dim3 = (37,)
+#dim4 = (23,)
+
+decomp = [0.0, 6.0]
+funcs = [(torch.matmul, bnb.matmul)]
+str_funcs = ['matmul']
+req_grad = [(False, False), (True, False), (True, True), (False, True)]
+req_grad_str = ['FF', 'TF', 'TT', 'FT']
+transpose = [(False, True), (False, False)]
+str_transpose = ['NT', 'NN']
+dtype = [torch.float16]
+has_fp16_weights = [True, False]
+values = list(product(dim1,dim2,dim3,dim4,funcs, dtype, req_grad, transpose, decomp, has_fp16_weights))
+str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, str_transpose, decomp, has_fp16_weights))
+names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}'.format(*vals) for vals in str_values]
+@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights", values, ids=names)
+def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights):
+ dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
+ dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
+ outlier_dim = torch.randint(0, dimA[1], size=(dimA[1]//8,), device='cuda')
+
+ for i in range(k):
+
+ # normal multiply
+ if funcs[0] in [torch.mm, torch.matmul]:
+ A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0], dtype=dtype)
+ if decomp == 6.0:
+ with torch.no_grad():
+ A[:, outlier_dim] = 6.0
+ B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1], dtype=dtype)
+ target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1], dtype=dtype)
+ torch.nn.init.xavier_uniform_(B)
+ B2 = B.clone()
+
+ state = bnb.MatmulLtState()
+ state.threshold = decomp
+ state.has_fp16_weights = has_fp16_weights
+ if not has_fp16_weights:
+ if not transpose[0] and not transpose[1]: B2 = B2.t().contiguous()
+ state.CB, CBt, state.SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B2)
+ B2 = state.CB
+
+ if not transpose[0] and transpose[1]:
+ out_torch = funcs[0](A, B.t())
+ out_bnb = funcs[1](A, B2, state=state)
+ elif not transpose[0] and not transpose[1]:
+ out_torch = funcs[0](A, B)
+ out_bnb = funcs[1](A, B2.t(), state=state)
+
+ n = out_bnb.numel()
+ err = torch.abs(out_bnb-out_torch).mean().item()
+ #print(f'abs error {err:.4f}')
+ idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
+ assert (idx==0).sum().item() < n*0.0175
+ idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
+ assert (idx==0).sum().item() < n*0.001
+
+ if has_fp16_weights:
+ if any(req_grad):
+ out_bnb.data.copy_(out_torch)
+ torch.cuda.synchronize()
+ loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+ loss_bnb.backward()
+ gradA1 = A.grad
+ gradB1 = B.grad
+ A.grad = None
+ B.grad = None
+
+ loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+ loss_torch.backward()
+ gradA2 = A.grad
+ gradB2 = B.grad
+ A.grad = None
+ B.grad = None
+
+ if req_grad[0]:
+ torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+ if req_grad[1]:
+ n = gradB1.numel()
+ assert torch.abs(gradB1).sum() > 0.0
+ assert torch.abs(gradB2).sum() > 0.0
+ idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
+ assert (idx==0).sum().item() < n*0.1
+ idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
+ assert (idx==0).sum().item() < n*0.02
+ torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
+
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 2a7d308..6cbe58f 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1,15 +1,76 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
import pytest
+import math
+import random
+import time
import torch
import bitsandbytes as bnb
+import einops
from itertools import product
from bitsandbytes import functional as F
+torch.set_printoptions(precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
+k = 20
+
+def assert_all_approx_close(a, b, rtol, atol, count):
+ idx = torch.isclose(a, b, rtol, atol)
+ sumval = (idx==0).sum().item()
+ if sumval > count:
+ print(f'Too many values not close: assert {sumval} < {count}')
+ torch.testing.assert_allclose(a, b, rtol, atol)
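
assert_all_approx_close tolerates up to count stragglers before escalating to a strict elementwise comparison, e.g.:

    a = torch.zeros(1000)
    b = a.clone(); b[:3] += 1.0                                  # three far-off elements
    assert_all_approx_close(a, b, rtol=0.1, atol=0.01, count=5)  # passes: 3 < 5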
+
+class FFN(torch.nn.Module):
+ def __init__(self, input_features, hidden_size, bias=True):
+ super(FFN, self).__init__()
+ self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
+ self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)
+
+ with torch.no_grad():
+ torch.nn.init.xavier_uniform_(self.fc1.weight)
+ torch.nn.init.xavier_uniform_(self.fc2.weight)
+
+ def forward(self, x):
+ x = torch.relu(self.fc1(x))
+ x = self.fc2(x)
+ return x
+
+class Timer(object):
+ def __init__(self):
+ self.starts = {}
+ self.ends = {}
+ self.agg = {}
+
+ def tick(self, name='default'):
+ if name not in self.starts:
+ self.starts[name] = torch.cuda.Event(enable_timing=True)
+ self.ends[name] = torch.cuda.Event(enable_timing=True)
+ self.starts[name].record()
+ else:
+ ms = self.tock(name, evict=True, print_ms=False)
+
+ def tock(self, name='default', evict=True, print_ms=True):
+ if name in self.ends:
+ self.ends[name].record()
+ torch.cuda.synchronize()
+ ms = self.starts[name].elapsed_time(self.ends[name])
+ if name not in self.agg: self.agg[name] = 0.0
+ self.agg[name] += ms
+ if evict:
+ self.starts.pop(name)
+ self.ends.pop(name)
+
+ if print_ms and name in self.agg:
+ print('{0} took: {1:.5f}s'.format(name, self.agg[name]/1000.0))
+
+ return self.agg[name]
+
+ def reset(self):
+ self.starts = {}
+ self.ends = {}
+ self.agg = {}
+ print('Resetting benchmark data')
+
def setup():
pass
@@ -64,8 +125,8 @@ def test_dynamic_quantization():
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
assert diff.mean().item() < 0.0135
- print(sum(diffs)/len(diffs))
- print(sum(reldiffs)/len(reldiffs))
+ #print(sum(diffs)/len(diffs))
+ #print(sum(reldiffs)/len(reldiffs))
for i in range(100):
A1 = torch.rand(1024, 1024, device='cuda')
@@ -88,8 +149,8 @@ def test_dynamic_blockwise_quantization():
diffs.append(diff.mean().item())
reldiffs.append(reldiff.mean().item())
assert diffs[-1] < 0.011
- print(sum(diffs)/len(diffs))
- print(sum(reldiffs)/len(reldiffs))
+ #print(sum(diffs)/len(diffs))
+ #print(sum(reldiffs)/len(reldiffs))
diffs = []
for i in range(100):
@@ -125,7 +186,7 @@ def test_percentile_clipping(gtype):
n = 4
step = 0
percentile=5
- for i in range(1000):
+ for i in range(k):
step += 1
g = torch.randn(n, n, dtype=gtype, device='cuda')
gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile)
@@ -145,69 +206,1653 @@ def test_percentile_clipping(gtype):
torch.testing.assert_allclose(gnorm1, gnorm2)
+def quant(x):
+ max1 = torch.abs(x).max()
+ x = torch.round(x/max1*127)
+ return max1, x.to(torch.int8)
+
+def dequant(c, maxC):
+ return c.float()*(maxC/127)
+
+def mm_dequant(maxA, maxB, C):
+ return C.float()*(maxA/127)*(maxB/127)
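
These three helpers implement symmetric absmax quantization: x maps to round(x/absmax*127) in int8, and an integer product C maps back by scaling with (maxA/127)*(maxB/127). A quick numeric sanity check using the helpers above (a sketch; per-tensor scaling is deliberately coarse):

    A, B = torch.randn(64, 64), torch.randn(64, 64)
    maxA, Ac = quant(A)           # Ac = round(A/maxA*127), int8
    maxB, Bc = quant(B)
    C = Ac.float() @ Bc.float()   # exact integer product, held in fp32
    approx = mm_dequant(maxA, maxB, C)
    assert (approx - A @ B).abs().mean() < 0.5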
+
+def quant_multi(x, dim):
+ max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
+ max1[max1==0] = 1.0
+ x = torch.round(x/max1*127)
+ return max1, x.to(torch.int8)
+
+def quant_multi_chunk(x, dim, chunk_size=32):
+ if dim==1:
+ x_chunked = einops.rearrange(x, '(c a) b -> c a b', c=chunk_size)
+ max1 = torch.amax(torch.abs(x_chunked), dim=dim+1, keepdim=True)
+ max1 = torch.tile(max1, (1, 1, x.shape[1]))
+ max1 = max1.view(x.shape)
+ elif dim==0:
+ x_chunked = einops.rearrange(x, 'a (b c) -> a b c', c=chunk_size)
+ max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
+ max1 = torch.tile(max1, (x.shape[0], 1, 1))
+ max1 = max1.view(x.shape)
+ max1[max1==0] = 1.0
+ x = torch.round(x/max1*127)
+ return max1, x.to(torch.int8)
+
+def quant_minmax(A):
+    minA = A.min()
+    maxA = A.max()
+    return minA, maxA
+
+def mean(xx):
+ return sum(xx)/float(len(xx))
+
+#dim1 = torch.randint(1,1024*4, size=(4,)).tolist()
+#dim2 = torch.randint(1,1024*4, size=(4,)).tolist()
+dim1 = [1024*2]
+dim2 = [1024*16]
+methods = [(lambda x, dim: quant(x), lambda x, dim: quant(x), dequant, dequant, mm_dequant)]
+methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant))
+#methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
+method_names = ['linear', 'vectorwise']
+batched = [False, True]
+values = list(product(dim1,dim2, methods, batched))
+values_names = list(product(dim1,dim2, method_names, batched))
+names = ['dim1_{0}_dim2_{1}_quant_{2}_batched_{3}'.format(*vals) for vals in values_names]
+@pytest.mark.parametrize("dim1, dim2, quant_methods, batched", values, ids=names)
+def test_approx_igemm(dim1, dim2, quant_methods, batched):
+ dim1 = dim1 - (dim1 % 32)
+ dim2 = dim2 - (dim2 % 32)
+ errors = []
+ relerrors = []
+ print('')
+ for i in range(5):
+ if batched:
+ A = torch.normal(0, 0.5, size=(32, dim1, dim2//32), device='cuda')
+ B = torch.normal(0, 0.5, size=(32, dim2//32, dim1), device='cuda')
+ maxA, Ac = quant_methods[0](A, 2)
+ maxB, Bc = quant_methods[1](B, 1)
+ else:
+ A = torch.normal(0, 0.5, size=(dim1, dim2), device='cuda')
+ B = torch.normal(0, 0.5, size=(dim2, dim1), device='cuda')
+ maxA, Ac = quant_methods[0](A, 1)
+ maxB, Bc = quant_methods[1](B, 0)
+ torch.testing.assert_allclose(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05)
+ if batched:
+ out2 = torch.bmm(A, B)
+ C = torch.bmm(Ac.float(), Bc.float())
+ else:
+ out2 = torch.mm(A, B)
+ C = F.igemm(Ac, Bc)
+ out = quant_methods[4](maxA, maxB, C)
+ std = out2.std()
+        out /= std
+        out2 /= std
+ err = torch.abs(out-out2)
+ relerr = err/torch.abs(out2)
+ errors.append(err.mean().item())
+ relerrors.append(relerr.mean().item())
+ print(mean(errors))
+ print(mean(relerrors))
+
+
def test_stable_embedding():
layer = bnb.nn.StableEmbedding(1024, 1024)
layer.reset_parameters()
-def test_dynamic_blockwise_quantization_cpu():
- #A1 = torch.randn(1024, 1024, device='cpu')
- #code = F.create_dynamic_map()
- #for i in range(1000):
- # C, S = F.quantize_blockwise(A1, code=code)
- # A2 = F.dequantize_blockwise(C, S)
- for i in range(10):
- # equivalence with GPU blockwise quantization
- A1 = torch.randn(1024, 1024, device='cpu')
- C1, S1 = F.quantize_blockwise(A1)
- C2, S2 = F.quantize_blockwise(A1.cuda())
- torch.testing.assert_allclose(S1[0], S2[0].cpu())
- # there seems to be some issues with precision in CUDA vs CPU
- # not all elements are usually close, with couple off elements in a million
- idx = torch.isclose(C1, C2.cpu())
- assert (idx==0).sum().item() < 15
+n = 2
+hidden_dim = torch.randint(32,256, size=(n,)).tolist()
+batch_dim = torch.randint(16,256, size=(n,)).tolist()
+seq_dim = torch.randint(16,256, size=(n,)).tolist()
+transpose = [(False, False), (False, True), (True, False), (True, True)]
+values = list(product(hidden_dim,batch_dim, transpose, seq_dim))
+names = ['hidden_dim_{0}_batch_dim_{1}_transpose_{2}_seq_dim_{3}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("hidden_dim, batch_dim, transpose, seq_dim", values, ids=names)
+def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
+ hidden_dim = hidden_dim - (hidden_dim % 32)
+ batch_dim = batch_dim - (batch_dim % 16)
+ seq_dim = seq_dim - (seq_dim % 16)
+ for i in range(k):
+ shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
+ shapeB = ((32*random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32*random.randint(1, 4)))
+ A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
+ B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8)
+ if not transpose[0] and not transpose[1]:
+ out2 = torch.matmul(A.float(), B.float())
+ out = F.igemm(A, B)
+ elif not transpose[0] and transpose[1]:
+ out2 = torch.matmul(A.float(), B.t().float())
+ out = F.igemm(A, B.t())
+ elif transpose[0] and not transpose[1]:
+ out2 = torch.matmul(A.t().float(), B.float())
+ out = F.igemm(A.t(), B)
+ elif transpose[0] and transpose[1]:
+ out2 = torch.matmul(A.t().float(), B.t().float())
+ out = F.igemm(A.t(), B.t())
+ torch.testing.assert_allclose(out.float(), out2)
- diffs = []
- reldiffs = []
+ for i in range(k):
+ shapeA = (batch_dim, seq_dim, hidden_dim)
+ shapeB = ((32*random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32*random.randint(1, 4)))
+ A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
+ B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8)
+ if not transpose[0] and not transpose[1]:
+ out2 = torch.matmul(A.float(), B.float())
+ out = F.igemm(A, B)
+ elif not transpose[0] and transpose[1]:
+ out2 = torch.matmul(A.float(), B.t().float())
+ out = F.igemm(A, B.t())
+
+ torch.testing.assert_allclose(out.float(), out2)
+
+
+n = 3
+seq_dim = torch.randint(32,512, size=(n,)).tolist()
+hidden_dim = torch.randint(32,1024*4, size=(n,)).tolist()
+batch_dim = torch.randint(2,16, size=(n,)).tolist()
+values = list(product(seq_dim,hidden_dim,batch_dim))
+names = ['seq_dim{0}_hidden_dim{1}_batch_dim{2}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
+def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
+ seq_dim = seq_dim - (seq_dim % 32)
+ hidden_dim = hidden_dim - (hidden_dim % 32)
+ batch_dim = batch_dim - (batch_dim % 2)
+ for i in range(25):
+ A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device='cuda').to(torch.int8)
+ B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device='cuda').to(torch.int8)
+ out2 = torch.einsum('bsi, bso->io', A.float(), B.float())
+ iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device)
+ out = F.igemm(A, B, out=iout)
+
+ torch.testing.assert_allclose(out.float(), out2)
+
+n = 2
+seq_dim = torch.randint(32,512, size=(n,)).tolist()
+hidden_dim = torch.randint(32,1024*4, size=(n,)).tolist()
+batch_dim = torch.randint(2,16, size=(n,)).tolist()
+transpose = [False, True]
+values = list(product(seq_dim,hidden_dim,batch_dim, transpose))
+names = ['seq_dim_{0}_hidden_dim_{1}_batch_dim_{2}_transpose_{3}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim, transpose", values, ids=names)
+def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
+
+ def min_max(x):
+ maxA = torch.amax(x, dim=2, keepdim=True)
+ minA = torch.amin(x, dim=2, keepdim=True)
+ scale = (maxA-minA)/2.0
+ return (127*(x-minA-scale)/scale).to(torch.int8), minA, scale
+
+ seq_dim = seq_dim - (seq_dim % 16)
+ hidden_dim = hidden_dim - (hidden_dim % 16)
+ batch_dim = batch_dim - (batch_dim % 2)
+ errs = []
+ relerrs = []
+ errs2 = []
+ relerrs2 = []
+ for i in range(k):
+ A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device='cuda')
+ if transpose:
+ B = torch.normal(0, 0.5, size=(256, hidden_dim), device='cuda')
+ else:
+ B = torch.normal(0, 0.5, size=(hidden_dim, 256), device='cuda')
+ Ac, minA, scale = min_max(A)
+ if transpose:
+ maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
+ out = F.igemm(Ac, Bc.t())
+ out2 = torch.matmul(A,B.t())
+ offset = B.t().sum(0)*(minA+scale)
+ out = out.float()
+ out = (out*maxB.t()*scale/(127*127))+offset
+
+ maxA, Ac = quant_multi(A, dim=2)
+ out3 = F.igemm(Ac, Bc.t())
+ out3 = mm_dequant(maxA, maxB.t(), out3)
+ else:
+ maxB, Bc = quant_multi(B, dim=0)
+ offset = B.sum(0)*(minA+scale)
+ out = F.igemm(Ac, Bc)
+ out2 = torch.matmul(A,B)
+ out = out.float()
+ out = (out*maxB*scale/(127*127))+offset
+
+ maxA, Ac = quant_multi(A, dim=2)
+ out3 = F.igemm(Ac, Bc)
+ out3 = mm_dequant(maxA, maxB, out3)
+
+ std = out2.std()
+ out2 /= std
+ out /= std
+ out3 /= std
+
+ err = torch.abs(out-out2)
+ relerr = err/(torch.abs(out2)+1e-7)
+
+ err2 = torch.abs(out3-out2)
+ relerr2 = err2/(torch.abs(out2)+1e-7)
+
+ errs.append(err.mean().item())
+ relerrs.append(relerr.mean().item())
+ errs2.append(err2.mean().item())
+ relerrs2.append(relerr2.mean().item())
+ #print(mean(errs))
+ #print(mean(relerrs))
+ #print(mean(errs2))
+ #print(mean(relerrs2))
+ assert mean(errs) < 0.015
+ assert mean(relerrs) < 0.3
+
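
The min-max path above relies on an identity that undoes the asymmetric shift via column sums of B: with per-row minA and scale = (maxA - minA)/2, Ac = 127*(A - minA - scale)/scale implies A = Ac*scale/127 + (minA + scale), hence A @ B = (Ac @ B)*scale/127 + (minA + scale)*B.sum(0); quantizing B column-wise contributes the extra maxB/127 factor used in the test. A standalone check of the offset identity (sketch):

    import torch

    A, B = torch.randn(8, 16), torch.randn(16, 4)
    minA = A.min(dim=1, keepdim=True).values
    scale = (A.max(dim=1, keepdim=True).values - minA) / 2.0
    Ac = torch.round(127 * (A - minA - scale) / scale)            # int range, kept fp
    recon = (Ac @ B) * (scale / 127) + (minA + scale) * B.sum(0)  # undo the shift
    assert torch.allclose(recon, A @ B, atol=0.25)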
+n = 2
+dim1 = torch.randint(1,64, size=(n,)).tolist()
+dim2 = torch.randint(32,128, size=(n,)).tolist()
+dim3 = torch.randint(32,256, size=(n,)).tolist()
+dim4 = torch.randint(32,256, size=(n,)).tolist()
+transpose = [(False, False), (True, False), (False, True), (True, True)]
+values = list(product(dim1,dim2,dim3,dim4,transpose))
+names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names)
+def test_ibmm(dim1, dim2, dim3, dim4, transpose):
+ dim2 = dim2 - (dim2 % 16)
+ dim3 = dim3 - (dim3 % 16)
+ dim4 = dim4 - (dim4 % 16)
+ for i in range(k):
+ shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
+ shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
+ A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8)
+ B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8)
+
+ if not transpose[0] and not transpose[1]:
+ out2 = torch.bmm(A.float(), B.float())
+ out = F.igemm(A, B)
+ elif not transpose[0] and transpose[1]:
+ out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float())
+ out = F.igemm(A, B.permute([0, 2, 1]))
+ elif transpose[0] and not transpose[1]:
+ out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
+ out = F.igemm(A.permute([0, 2, 1]), B)
+ elif transpose[0] and transpose[1]:
+ out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float())
+ out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
+ torch.testing.assert_allclose(out.float(), out2.float())
+
+n = 1
+dim1 = torch.randint(1,64, size=(n,)).tolist()
+dim2 = torch.randint(32,128, size=(n,)).tolist()
+dim3 = torch.randint(32,256, size=(n,)).tolist()
+values = list(product(dim1,dim2,dim3))
+names = ['dim1_{0}_dim2_{1}_dim3_{2}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
+def test_vector_quant(dim1, dim2, dim3):
+ dim2 = dim2 - (dim2 % 16)
+ dim3 = dim3 - (dim3 % 16)
+ for i in range(k):
+ A = torch.randn(size=(dim2, dim3), device='cuda')
+ qA, SA = F.vectorwise_quant(A, dim=0)
+ A1 = F.vectorwise_dequant(qA, SA)
+ torch.testing.assert_allclose(A1, A, atol=0.01, rtol=0.1)
+
+
+
+n = 2
+dim1 = torch.randint(2,256, size=(n,)).tolist()
+dim2 = torch.randint(2,256, size=(n,)).tolist()
+dim3 = torch.randint(2,256, size=(n,)).tolist()
+#dim1, dim2 = (256,), (256,)
+dtype = [torch.int8, torch.int32]
+a_order = ['row']
+out_order = ['col', 'row', 'col32']
+transpose = [False]
+dims = [2, 3]
+values = list(product(dim1,dim2,dim3, dims,dtype, a_order, out_order, transpose))
+
+names = ['dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names)
+def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
+    if dims == 3 and orderOut != 'col32': return
+    if dtype == torch.int32 and orderOut != 'col32': return
+ func = F.get_transform_func(dtype, orderA, orderOut, transpose)
+
+ if dims == 2:
+ A = torch.randint(-128, 127, size=(dim1, dim2), device='cuda').to(dtype)
+ elif dims == 3:
+ A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(dtype)
+
+ out, S = F.nvidia_transform(A, to_order=orderOut)
+
+ if orderOut == 'row':
+ torch.testing.assert_allclose(A.flatten(), out.flatten())
+ elif orderOut == 'col':
+ torch.testing.assert_allclose(A.t().flatten(), out.flatten())
+ elif orderOut == 'col32':
+ if dims == 2:
+ n = A.shape[0]*(A.shape[1] + (32 - (A.shape[1]%32)))
+ elif dims == 3:
+ n = A.shape[0]*A.shape[1]*(A.shape[2] + (32 - (A.shape[2]%32)))
+ assert out.numel() == n
+ elif orderOut == 'col_turing':
+ # 32 col 8 row tiles
+ n = (A.shape[0]+(8- A.shape[0]%8))*(A.shape[1] + (32 - (A.shape[1]%32)))
+ assert out.numel() == n
+ total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0)
+ for row in range(A.shape[0]):
+ for col in range(A.shape[1]):
+ i = row*A.shape[1]
+ j = col
+
+ coltile = (col // 32) + (1 if col % 32 != 0 else 0)
+ rowtile = ((row // 8) + (1 if row % 8 != 0 else 0))*total_coltile
+ offset = 32*8*(rowtile+coltile)
+ col2 = col % 32
+ row2 = (row%8)*32
+
+
+ assert A.flatten()[i+j] == A[row, col]
+ #assert A.flatten()[i+j] == out.flatten()[row2+col2]
+ #torch.testing.assert_allclose(A.flatten()[i+j], A[row, col])
+ #torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
+
+ if orderOut == 'col32':
+ out2, S = F.nvidia_transform(out, from_order=orderOut, to_order='row', state=S)
+ torch.testing.assert_allclose(A, out2)
+
+
+n = 1
+dim1 = torch.randint(1,256, size=(n,)).tolist()
+dim2 = torch.randint(32,512, size=(n,)).tolist()
+dim3 = torch.randint(32,1024, size=(n,)).tolist()
+dim4 = torch.randint(32,1024, size=(n,)).tolist()
+
+#dim1 = [2]
+#dim2 = [2]
+#dim3 = [2]
+#dim4 = [2]
+
+dims = (2,3)
+ldb = [0]
+#ldb = list(range(256, 1*1024, 256))
+values = list(product(dim1,dim2,dim3,dim4,dims, ldb))
+names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names)
+def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
+ for i in range(k):
+ if dims == 2:
+ A = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
+ elif dims == 3:
+ A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
+ B = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
+ C1 = torch.matmul(A.float(), B.t().float())
+
+ A2, SA = F.transform(A, 'col32')
+ B2, SB = F.transform(B, 'col_turing')
+ C2, SC = F.igemmlt(A2, B2, SA, SB)
+ C3, S = F.nvidia_transform(C2, 'row', state=SC)
+ torch.testing.assert_allclose(C1, C3.float())
+
+ # transpose
+ B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
+ C1 = torch.matmul(A.float(), B.float())
+
+ B2t, SBt = F.transform(B, 'col_turing', transpose=True)
+ C2, SC = F.igemmlt(A2, B2t, SA, SBt)
+ C3, S = F.nvidia_transform(C2, 'row', state=SC)
+ torch.testing.assert_allclose(C1, C3.float())
+
+dim1 = [32]
+dim2 = [32]
+dim3 = [32]
+dim4 = [32]
+
+dims = (2,)
+#ldb = list(range(256, 1*1024, 256))
+values = list(product(dim1,dim2,dim3,dim4,dims))
+names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names)
+def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
+ formatB = F.get_special_format_str()
+ for i in range(k):
+ if dims == 2:
+ A = torch.normal(0, 0.5, size=(dim1, dim3), device='cuda').half()
+ elif dims == 3:
+ A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device='cuda').half()
+ B = torch.randn((dim4, dim3), device='cuda').half()
+ torch.nn.init.xavier_uniform_(B)
+ C1 = torch.matmul(A, B.t())
+ C2 = bnb.matmul(A, B.t())
+
+ A = A.view(-1, A.shape[-1])
+
+ CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
+ CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
+ C32A, SA = F.transform(CA, 'col32')
+ CxB, SB = F.transform(CB, to_order=formatB)
+ out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB)
+ output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt)
+
+ #print('')
+ #print(output.flatten()[:10])
+ #print(C1.flatten()[:10])
+ #print(C2.flatten()[:10])
+
+
+ #torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
+
+ # transpose
+ #B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
+ #C1 = torch.matmul(A.float(), B.float())
+
+ #B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
+ #C2, SC = F.igemmlt(A2, B2t, SA, SBt)
+ #C3, S = F.transform(C2, 'row', state=SC)
+ #torch.testing.assert_allclose(C1, C3.float())
+
+batch_size = 2
+seqdim = 512
+#values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
+values = [(batch_size, seqdim, 4*1024, 3*4*1024),(batch_size, seqdim, 5120, 3*5120),(batch_size, seqdim, 12*1024, 4*12*1024)]
+
+
+#values = list(product(batch, seq, model, hidden))
+names = ['batch_{0}_seq_{1}_model_{2}_hidden_{3}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
+def test_bench_8bit_training(batch, seq, model, hidden):
+ formatB = F.get_special_format_str()
+ A = torch.randn(batch, seq, model, device='cuda').half()
+ grad = torch.randn(batch, seq, model, device='cuda').half()
+ w1 = torch.randint(-128, 127, size=(hidden, model), device='cuda').half()
+ w2 = torch.randint(-128, 127, size=(model, hidden), device='cuda').half()
+ print('')
+
+ #torch.cuda.synchronize()
+ ## warmup
+ #for i in range(100):
+ # torch.matmul(A, w1.t())
+ #torch.cuda.synchronize()
+
+ dtype = torch.int8
+ A = A.view(-1, A.shape[-1]).contiguous()
+ grad = grad.view(-1, grad.shape[-1]).contiguous()
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(k):
+
+ out1 = torch.matmul(A, w1.t()) # fc1
+ #out2 = torch.matmul(out1, w2.t())# fc2
+
+ #d1 = torch.matmul(grad, w2) # delta1
+ #d2 = torch.matmul(d1, w1) # delta2
+
+ #grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
+ #grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1
+
+ torch.cuda.synchronize()
+ t16 = time.time() - t0
+ print(t16)
+
+ #torch.cuda.empty_cache()
+
+ #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
+ #Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
+
+ #CTw1, Sw1 = F.transform2(Cw1, formatB)
+ #CTw2, Sw2 = F.transform2(Cw2, formatB)
+ #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
+ #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
+
+ #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
+ #C32A, SA = F.transform2(CA, 'col32')
+ ## fc1
+ #out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
+ ##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)
+
+ ## fc2
+ #Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
+ #C32out1, Sout1 = F.transform2(Cout1, 'col32')
+ #out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
+ ##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)
+
+ ## delta1
+ #Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
+ #C32grad, Sgrad = F.transform2(Cgrad, 'col32')
+ ##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
+ ##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)
+
+ ## delta2
+ #Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
+ #C32d1, Sd1 = F.transform2(Cd1, 'col32')
+ ##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
+ ##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)
+
+ ## grad1
+ #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
+ #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
+ ##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
+ ##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)
+
+ ## grad2
+ #C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
+ #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
+ ##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
+ ##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)
+
+ #Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
+
+ #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
+ #Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
+
+ #CTw1, Sw1 = F.transform2(Cw1, formatB)
+ #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
+ #CTw2, Sw2 = F.transform2(Cw2, formatB)
+ #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
+ #torch.cuda.synchronize()
+ #t0 = time.time()
+ #for i in range(k):
+ # #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
+ # #CTw1, Sw1 = F.transform2(Cw1, formatB)
+ # #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
+ # #CTw1, Sw1 = F.transform2(Cw1, formatB)
+
+ # #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=3.5)
+ # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
+ # #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
+ # #CTw2, Sw2 = F.transform2(Cw2, formatB)
+ # #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
+
+ # C32A, SA = F.transform2(CA, 'col32')
+
+ # # fc1
+ # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
+ # #out1dn = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
+
+ # #print(coo_tensor.nnz)
+ # #out1sp = F.spmm_coo(coo_tensor, w1.t())
+ # #print(w1.t().shape)
+ # #out1 = out1dn + out1sp
+
+ # # fc2
+ # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
+ # C32out1, Sout1 = F.transform2(Cout1, 'col32')
+ # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
+ # #out2 = F.mm_dequant(out2_32, Sout2_32, statsout1, statsw2)
+
+ # # delta1
+ # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
+ # C32grad, Sgrad = F.transform2(Cgrad, 'col32')
+ # d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
+ # #d1 = F.mm_dequant(d1_32, Sd1_32, statsgrad, statsw2t)
+
+ # # delta2
+ # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
+ # C32d1, Sd1 = F.transform2(Cd1, 'col32')
+ # d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
+ # #d2 = F.mm_dequant(d2_32, Sd2_32, statsd1, statsw1t)
+
+ # # grad1
+ # #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
+ # #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
+ # #grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
+ # #grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1t, statsgradt)
+
+ # ## grad2
+ # #C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
+ # #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
+ # #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
+ # #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)
+
+ #torch.cuda.synchronize()
+ #t8 = time.time() - t0
+ #print(t8)
+
+
+
+
+
+n = 2
+dim1 = torch.randint(64,256, size=(n,)).tolist()
+dim4 = torch.randint(64,1024, size=(n,)).tolist()
+
+#dim1 = [2*1024]
+#dim4 = [2*1024]
+
+#dim1 = [4]
+#dim4 = [4]
+
+dims = (2,)
+#ldb = list(range(256, 1*1024, 256))
+formatB = ['col_turing', 'col_ampere']
+values = list(product(dim1,dim4,dims, formatB))
+names = ['dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names)
+def test_dequant_mm(dim1, dim4, dims, formatB):
+ inner = torch.randint(1, 128, size=(1,)).item()
+ formatB = F.get_special_format_str()  # overrides the parametrized formatB with the GPU's native tile layout
+ for i in range(k):
+ A = torch.randn(dim1, inner, device='cuda')
+ B = torch.randn(dim4, inner, device='cuda')
+ C1 = torch.matmul(A.half(), B.t().half())
+
+ A1, maxA = F.vectorwise_quant(A, dim=1)
+ B1, maxB = F.vectorwise_quant(B, dim=1)
+
+ A2, SA = F.nvidia_transform(A1, 'col32')
+ B2, SB = F.nvidia_transform(B1, formatB)
+ C2, SC = F.igemmlt(A2, B2, SA, SB)
+
+ C3, S = F.nvidia_transform(C2, 'row', state=SC)
+ C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
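+ # sketch of the identity being checked (assuming symmetric int8 quantization
+ # with per-row absmax): A ≈ A1*maxA/127 and B ≈ B1*maxB/127, hence
+ # A @ B.t() ≈ (A1 @ B1.t()) * maxA * maxB.t() / (127*127), which is what
+ # vectorwise_mm_dequant reconstructs from the int32 GEMM result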
+
+ count = (torch.isclose(C1, C4, atol=0.01, rtol=0.1) == 0).sum().item()
+ n = C1.numel()
+ p = 0.06
+ assert count/n < p, f'error in more than {p:.0%} of elements: {count}/{n}={count/n:.4f}'
+
+ C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten())
+ torch.testing.assert_allclose(C5, C4)
+ #print(C2)
+
+
+
+n = 2
+dim1 = [1*1024]
+dim2 = [1*1024]
+#dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
+#dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
+
+dims = (2,)
+#ldb = list(range(256, 1*1024, 256))
+values = list(product(dim1,dim2,dims))
+names = ['dim1_{0}_dim2_{1}_dims_{2}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
+def test_colrow_absmax(dim1, dim2, dims):
+ for i in range(k):
+ threshold = 3.0
+ A = torch.randn(dim1, dim2, device='cuda').half()
+ A_truncated = A.clone()
+ A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0
+ if dims == 2:
+ row_stats1, _ = torch.abs(A.float()).max(1)
+ col_stats1, _ = torch.abs(A.float()).max(0)
+ row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1)
+ col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0)
+ else:
+ assert False
+
+ row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold)
+
+ A_blocked = einops.rearrange(torch.abs(A), '(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size', row_tiles=16, block_size=64*4)
+ nnz_rows1_counts = (torch.abs(A_blocked)>=threshold).sum(3).flatten()
+ nnz_block_ptr1 = torch.zeros(nnz_rows1_counts.shape[0]+1, dtype=nnz_rows1_counts.dtype, device=nnz_rows1_counts.device)
+ nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)
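+ # reference construction: count outliers (|a| >= threshold) per row segment
+ # within each 16-row x 256-column tile, then cumsum the counts behind a
+ # leading zero into a CSR-style pointer array; this is assumed to match the
+ # nnz_block_ptr layout produced by get_colrow_absmax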
+
+ torch.testing.assert_allclose(col_stats1_trunc, col_stats2)
+ torch.testing.assert_allclose(row_stats1_trunc, row_stats2)
+ torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2)
+
+ row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0)
+
+ torch.testing.assert_allclose(col_stats1, col_stats2)
+ torch.testing.assert_allclose(row_stats1, row_stats2)
+ assert nnz_block_ptr2 is None
+
+
+
+n = 2
+#dim1 = [8*1024]
+#dim2 = [4*1024]
+dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
+dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
+
+values = list(product(dim1,dim2))
+names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2", values, ids=names)
+def test_double_quant(dim1, dim2):
+ for i in range(k):
+ A = torch.randn(dim1, dim2, device='cuda').half()
+ out_col1, Scol = F.vectorwise_quant(A, dim=0)
+ out_row1, Srow = F.vectorwise_quant(A, dim=1)
+
+ CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
+
+ # max difference is 1 due to rounding differences
+ torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0)
+ torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0)
+
+
+ n = CAt.numel()
+ num_not_close_rows = (torch.isclose(CA, out_row1, atol=1)==0).sum().item()
+ num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1)==0).sum().item()
+
+ # allow up to 1 in 500 elements to differ due to rounding differences
+ max_error_rate = 1/500
+ if num_not_close_cols > (max_error_rate*n):
+ print(f'Error rate exceeded: {num_not_close_cols} column elements differ. Error: {num_not_close_cols/n:.4f}')
+ assert False
+ if num_not_close_rows > (max_error_rate*n):
+ print(f'Error rate exceeded: {num_not_close_rows} row elements differ. Error: {num_not_close_rows/n:.4f}')
+ assert False
+
+ torch.testing.assert_allclose(Srow.flatten(), statsA)
+ torch.testing.assert_allclose(Scol.flatten(), statsAt)
+
+
+n = 4
+dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
+dim4 = torch.randint(1,4*1024, size=(n,)).tolist()
+inner = torch.randint(1,4*1024, size=(n,)).tolist()
+
+# fixed small dims for this correctness check (overrides the random dims above)
+dim1 = [6]
+dim4 = [4]
+inner = [8]
+
+values = list(zip(dim1, dim4, inner))
+names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
+def test_integrated_igemmlt(dim1, dim4, inner):
+ for i in range(k):
+ A = torch.randn(dim1, inner, device='cuda').half()
+ B = torch.randn(dim4, inner, device='cuda').half()
+
+ out1 = torch.matmul(A.half(), B.t().half())
+
+ C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
+ C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
+ A1, maxA = F.vectorwise_quant(A, dim=1)
+ B1, maxB = F.vectorwise_quant(B, dim=1)
+
+ torch.testing.assert_allclose(maxA.flatten(), stats1a)
+ torch.testing.assert_allclose(maxB.flatten(), stats2a)
+ torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1)
+ torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1)
+
+ A2, SA = F.nvidia_transform(C1a, 'col32')
+ B2, SB = F.nvidia_transform(C2a, 'col_turing')
+ outC32, SC = F.igemmlt(A2, B2, SA, SB)
+ out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)
+
+ A2, SA = F.nvidia_transform(A1, 'col32')
+ B2, SB = F.nvidia_transform(B1, 'col_turing')
+ C2, SC = F.igemmlt(A2, B2, SA, SB)
+
+ C3, S = F.nvidia_transform(C2, 'row', state=SC)
+ out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t())
+
+ err1 = torch.abs(out1-out2).mean().item()
+ err2 = torch.abs(out1-out3).mean().item()
+ assert err2 <= err1*1.01
+
+
+n = 6
+dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
+dim4 = torch.randint(1,4*1024, size=(n,)).tolist()
+inner = torch.randint(1,4*1024, size=(n,)).tolist()
+
+values = list(zip(dim1, dim4, inner))
+names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
+def test_igemmlt_row_scale(dim1, dim4, inner):
+ formatB = F.get_special_format_str()
+ err1, err2, err3 = [], [], []
+ relerr1, relerr2 = [], []
+ scale = 1
+ for i in range(k):
+ A = torch.randn(dim1, inner, device='cuda').half()
+ B = torch.randn(dim4, inner, device='cuda').half()
+ torch.nn.init.xavier_uniform_(B)
+ C1 = torch.matmul(A, B.t())
+
+ out1 = torch.matmul(A.half(), B.t().half())
+
+
+ C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
+ CB, absmaxB = F.vectorwise_quant(B, quant_type='linear')
+ A2, SA = F.nvidia_transform(C1a, 'col32')
+ B2, SB = F.nvidia_transform(CB, formatB)
+ A1, maxA = F.vectorwise_quant(A, dim=1)
+
+ c = 10.0*inner*scale
+ row_scale = torch.ones_like(maxA)/c
+ outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
+ C3, S = F.nvidia_transform(outC32, 'row', state=SC)
+ maxval = torch.abs(C3).max()
+ # adapt c for the next iteration: back off if the int8 output saturated at
+ # 127, otherwise rescale so the maximum lands near 120
+ if maxval == 127:
+ scale = 1.5
+ else:
+ scale = maxval/120
+ out3 = C3*maxA*absmaxB*c/(127*127)
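+ # hedged reading of the row-scale path: assuming igemmlt scales the int32
+ # accumulators by row_scale (= 1/c) before rounding to int8, C3 ≈ (CA @ CB.t())/c,
+ # so multiplying back by maxA*absmaxB*c/(127*127) recovers A @ B.t() up to
+ # int8 rounding and clipping at +-127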
+
+ C4 = torch.matmul(C1a.float(), CB.float().t())
+
+
+ C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
+ B2, SB = F.nvidia_transform(C2a, formatB)
+ outC32, SC = F.igemmlt(A2, B2, SA, SB)
+ out2 = F.mm_dequant(outC32, SC, stats1a, stats2a)
+
+ CA, SA = F.vectorwise_quant(A, dim=1, quant_type='vector')
+ CB, SB = F.vectorwise_quant(B, dim=1, quant_type='linear')
+
+ C = torch.matmul(CA.float(), CB.t().float())
+ out4 = C*SA*SB/(127*127)
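+ # float reference for the same dequantization: with per-row scales SA and a
+ # single tensor-wide scale SB, out = (CA @ CB.t()) * SA * SB / (127*127)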
+ #out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127)
+
+ #print('='*80)
+ #print(out1)
+ #print(out2)
+ #print(out3)
+ err1.append(torch.abs(out1-out2).mean().item())
+ err2.append(torch.abs(out1-out3).mean().item())
+ err3.append(torch.abs(out1-out4).mean().item())
+
+ #assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10)
+ print('')
+ print(sum(err1)/len(err1))
+ print(sum(err2)/len(err2))
+ print(sum(err3)/len(err3))
+
+
+dim1 = [1024, 2048]
+inner = [12288*4, 4096*4]
+dim4 = [12288, 4096]
+
+values = list(zip(dim1, dim4, inner))
+names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
+def test_row_scale_bench(dim1, dim4, inner):
+ formatB = F.get_special_format_str()
+ err1, err2, err3 = [], [], []
+ relerr1, relerr2 = [], []
+ scale = 1
+ A = torch.randn(dim1, inner, device='cuda').half()
+ B = torch.randn(dim4, inner, device='cuda').half()
+ torch.nn.init.xavier_uniform_(B)
+ # warmup
+ for i in range(k):
+ C1 = torch.matmul(A, B.t())
+
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(k):
+ C1 = torch.matmul(A, B.t())
+ torch.cuda.synchronize()
+ print('16', time.time()-t0)
+
+ C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A)
+ CB, absmaxB = F.vectorwise_quant(B, quant_type='linear')
+ A2, SA = F.nvidia_transform(C1a, 'col32')
+ B2, SB = F.nvidia_transform(CB, formatB)
+ A1, maxA = F.vectorwise_quant(A, dim=1)
+
+ c = 10.0*inner*scale
+ row_scale = maxA/c
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(k):
+ outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
+ torch.cuda.synchronize()
+ print('row-wise', time.time()-t0)
+
+
+ C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B)
+ B2, SB = F.nvidia_transform(C2a, formatB)
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(k):
+ outC32, SC = F.igemmlt(A2, B2, SA, SB)
+ torch.cuda.synchronize()
+ print('vector-wise', time.time()-t0)
+
+
+
+
+n = 2
+dim1 = torch.randint(2,1024, size=(n,)).tolist()
+dim2 = torch.randint(2,1024, size=(n,)).tolist()
+#dim1 = [8*1024]
+#dim2 = [4*1024]
+
+dim3 = [0]
+dtype = [torch.int8]
+a_order = ['row']
+out_order = ['col32', 'col_turing', 'col_ampere']
+transpose = [False, True]
+dims = [2]
+values = list(product(dim1,dim2,dim3, dims,dtype, a_order, out_order, transpose))
+names = ['dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names)
+def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
+ for i in range(k):
+ if dims == 2:
+ A = torch.randint(10, 99, size=(dim1, dim2), device='cuda').to(dtype)
+ elif dims == 3:
+ A = torch.randint(10, 99, size=(dim1, dim2, dim3), device='cuda').to(dtype)
+
+ A.view(-1)[-1] = -1
+ if transpose:
+ At = A.t().contiguous()
+ out1, S1 = F.nvidia_transform(At, to_order=orderOut)
+ else:
+ out1, S1 = F.nvidia_transform(A, to_order=orderOut)
+ out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose)
+
+ assert S1[0][0] == S2[0][0]
+ assert S1[0][1] == S2[0][1]
+ #print(out1)
+ #print(out2)
+
+ torch.testing.assert_allclose(out1, out2)
+
+n = 2
+#dim1 = torch.randint(2,1024, size=(n,)).tolist()
+#dim2 = torch.randint(2,1024, size=(n,)).tolist()
+dim1 = [1]
+dim2 = [33]
+
+dtype = [torch.int8]
+#a_order = ['col_turing', 'col_ampere']
+a_order = ['col_turing']
+out_order = ['row']
+values = list(product(dim1,dim2,dtype, a_order, out_order))
+names = ['dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2, dtype, orderA, orderOut", values, ids=names)
+def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut):
+ for i in range(1):
+ A = torch.randint(-127, 127, size=(dim1, dim2), device='cuda').to(dtype)
+
+ out2, S2 = F.transform(A, to_order=orderA)
+ A2, S3 = F.transform(out2, from_order=orderA, to_order='row', state=S2)
+ assert A2.shape[0] == A.shape[0]
+ assert A2.shape[1] == A.shape[1]
+
+
+ print('')
+ print(A)
+ print(out2)
+ print(A2)
+
+
+ #torch.testing.assert_allclose(A, A2)
+
+
+
+
+def test_overflow():
+ formatB = F.get_special_format_str()
+ for i in range(2):
+ a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
+ b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1)
+
+ Ca, Sa = F.nvidia_transform(a, 'col32')
+ Cb, Sb = F.nvidia_transform(b, formatB)
+
+ c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8)
+ c2 = torch.matmul(a.float(), b.float().t())
+
+
+n = 2
+dim1 = torch.randint(1,4*1024, size=(n,)).tolist()
+dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
+#dim1 = [4]
+#dim2 = [5]
+
+values = list(product(dim1,dim2))
+names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2", values, ids=names)
+def test_coo_double_quant(dim1, dim2):
+ threshold = 3.00
+ for i in range(k):
+ A = torch.randn(dim1, dim2, device='cuda').half()
+
+ idx = (torch.abs(A) >= threshold)
+ CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
+ CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
+
+ if coo_tensor is not None:
+ A1 = A*idx
+ A2 = torch.zeros_like(A)
+ A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values
+ torch.testing.assert_allclose(A1, A2)
+
+ A1 = A*(idx==0)
+ A2 = (CA.float()*statsA.unsqueeze(1)/127).half()
+ torch.testing.assert_allclose(A1, A2, rtol=0.05, atol=1.5e-2)
+
+n = 2
+dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
+dim2 = torch.randint(1,1*1024, size=(n,)).tolist()
+#dim1 = [7]
+#dim2 = [11]
+transposed_B = [False, True]
+values = list(product(dim1,dim2, transposed_B))
+names = ['dim1_{0}_dim2_{1}_transposed_B_{2}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
+def test_spmm_coo(dim1, dim2, transposed_B):
+ threshold = 1.5
+ dim3 = torch.randint(32, 128, size=(1,)).item()
+ #dim3 = 17
+ for i in range(k):
+ A = torch.randn(dim1, dim2).cuda().half()
+ if transposed_B:
+ B = torch.randn(dim3, dim2).cuda().half()
+ else:
+ B = torch.randn(dim2, dim3).cuda().half()
+
+ idx = torch.abs(A) >= threshold
+ nnz = (idx == 1).sum().item()
+ rows, cols = torch.where(idx)
+ values = A[idx]
+ cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
+ A2 = A*idx
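+ # cooA stores only the |A| >= threshold entries as (rowidx, colidx, value)
+ # triplets; A2 is the equivalent dense matrix with sub-threshold entries
+ # zeroed, so spmm_coo(cooA, .) should match a dense matmul against A2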
+
+ if transposed_B:
+ out2 = F.spmm_coo(cooA, B.t())
+ out1 = torch.matmul(A2, B.t())
+ else:
+ out2 = F.spmm_coo(cooA, B)
+ out1 = torch.matmul(A2, B)
+
+ assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30)
+
+
+
+def test_spmm_bench():
+ batch = 2
+ model = 1024*1
+ hidden = model*4
+ seq = 1024
+ dim1 = batch*seq
+ dim2 = model
+ dim3 = hidden
+ threshold = 4
+ A = torch.randn(dim1, dim2, device='cuda').half()
+ B = torch.randn(dim2, dim3, device='cuda').half()
for i in range(10):
- A1 = torch.randn(1024, 1024, device='cpu')
- C, S = F.quantize_blockwise(A1)
- A2 = F.dequantize_blockwise(C, S)
- diff = torch.abs(A1-A2)
- reldiff = diff/torch.abs(A1+1e-8)
- diffs.append(diff.mean().item())
- reldiffs.append(reldiff.mean().item())
- assert diffs[-1] < 0.011
- #print(sum(diffs)/len(diffs))
- #print(sum(reldiffs)/len(reldiffs))
+ C1 = bnb.matmul(A, B)
+
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(k):
+ C1 = bnb.matmul(A, B)
+ torch.cuda.synchronize()
+ t8 = time.time()-t0
+
+ idx = torch.abs(A) >= threshold
+ nnz = (idx == 1).sum().item()
+ print(nnz/idx.numel())
+ rows, cols = torch.where(idx)
+ values = A[idx]
+ cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
- diffs = []
for i in range(10):
- A1 = torch.rand(1024, 1024, device='cpu')
- C, S = F.quantize_blockwise(A1)
- A2 = F.dequantize_blockwise(C, S)
- diff = torch.abs(A1-A2).mean().item()
- assert diff < 0.0033
- diffs.append(diff)
- torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
- #print(sum(diffs)/len(diffs))
+ out2 = F.spmm_coo(cooA, B)
+
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(k):
+ out2 = F.spmm_coo(cooA, B)
+ torch.cuda.synchronize()
+ tsp = time.time()-t0
+ print(tsp, t8)
+ print(tsp/t8)
+
+
+n = 2
+dim1 = torch.randint(256,1*1024, size=(n,)).tolist()
+dim2 = torch.randint(256,1*1024, size=(n,)).tolist()
+values = list(product(dim1,dim2))
+names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2", values, ids=names)
+def test_integrated_sparse_decomp(dim1, dim2):
+ threshold = 3.0
+ formatB = 'col_turing'
+ for i in range(k):
+ A = torch.randn(dim1, dim2).cuda().half()
+ w1 = torch.randn(dim1, dim2).cuda().half()
+ out1 = torch.matmul(A, w1.t())
+
+ Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
+ CTw1, Sw1 = F.transform(Cw1, formatB)
+
+ CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
+ C32A, SA = F.transform(CA, 'col32')
+
+ out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
+ out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
+
+ CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
+ C32A, SA = F.transform(CA, 'col32')
+
+ out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
+ out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
+
+ assert coo_tensor is not None
+
+ out4 = F.spmm_coo(coo_tensor, w1.t())
+ out5 = out3 + out4
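+ # sparse decomposition at work: out3 is the int8 product with the outliers
+ # (|a| >= threshold) stripped from A, out4 is an fp16 sparse matmul over just
+ # those outliers; their sum should track the fp16 reference more closely than
+ # the pure int8 product out2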
+
+ err1 = torch.abs(out1-out2).mean().item()
+ err2 = torch.abs(out1-out5).mean().item()
+ assert err2 < err1
+
+
+def test_matmuls():
+ a = torch.randn(256, 256).half().cuda()
+ b = torch.randn(256, 256).half().cuda()
+ c1 = torch.matmul(a, b)
+ c2 = bnb.matmul(a, b)
+ c3 = bnb.matmul(a, b)
+
+ err1 = torch.abs(c1-c2).mean().item()
+ err2 = torch.abs(c1-c3).mean().item()
+ assert err1 < 0.2
+ assert err2 < 0.2
+
+
+
+n = 2
+#dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
+#dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
+dim1 = [1*2048]
+dim2 = [12288]
+#dim1 = [32]
+#dim2 = [32]
+#dtype = [torch.float16, torch.int8]
+dtype = [torch.float16]
+out_function = ['zeros', 'ones']
+values = list(product(dim1,dim2, dtype, out_function))
+names = ['dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
+def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
+ out_func = getattr(torch, out_func)
+
+ threshold = 3.3
+ #threshold = 2.8
+ #threshold = 0.0
+ A = torch.randn(dim1, dim2, device='cuda').half()
+ if dtype == torch.float16:
+ B = torch.randn(dim2, dim2*4, device='cuda').half()
+ torch.nn.init.xavier_uniform_(B)
+ else:
+ B = torch.randn(dim2, dim2*4, device='cuda').half()
+ torch.nn.init.xavier_uniform_(B)
+ B, SB = F.vectorwise_quant(B, quant_type='linear')
+ #B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8)
+
+ print('')
+ idx = torch.abs(A) >= threshold
+ nnz = (idx == 1).sum().item()
+ rows, cols = torch.where(idx)
+ values = A[idx]
+ cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
+ A2 = A*idx
+ out1 = torch.matmul(A2.half(), B.half())
+ out = out_func(out1.shape, dtype=torch.float16, device=out1.device)
+ out1 += out.clone()
+ out2 = F.spmm_coo_very_sparse(cooA, B, out=out)
+ #print(B)
+ #print(out1)
+ #print(out2)
+ p = 200/(2048*12288*4)  # allow roughly 200 mismatched elements at the default shape
+ n = out1.numel()
+ count = math.ceil(p*n)
+ std = out1.std()
+ out1 /= std
+ out2 /= std
+ assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)
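+ # normalizing both outputs by out1's standard deviation makes the absolute
+ # tolerance scale-free: atol=3e-2 then means roughly 3% of one std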
+ #assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)
+
+ idx_col = torch.randint(0, A2.shape[-1], size=(15,))
+
+ #torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001)
+
+ #Bt = torch.randn(dim2*4, dim2, device='cuda').half()
+ #torch.cuda.synchronize()
+ #t0 = time.time()
+ #print(A2.shape, B.shape)
+ #for i in range(100):
+ # #out3 = F.spmm_coo(cooA, Bt.t())
+ # #out2 = F.spmm_coo(cooA, B)
+ # #out2 = F.spmm_coo_very_sparse(cooA, B)
+ # #out1 = torch.matmul(A, Bt.t())
+
+ #torch.cuda.synchronize()
+ #print(time.time() - t0)
+
+def test_layout():
+ a1 = torch.rand(16, 64, device='cuda', dtype=torch.float16)
+ a1 = torch.arange(16* 64, device='cuda').reshape(16, 64).byte()
+ a2, s2 = F.transform(a1, 'col_turing')
+ print(a2.shape)
+
+ print(a1.flatten()[8*64:8*64+32])
+ for i in range(4):
+ print(a2.flatten()[i*8*32:i*8*32+32], 0)
+
+
+def test_coo2csr():
+ threshold = 1
+ A = torch.randn(128, 128).half().cuda()
+ idx = torch.abs(A) >= threshold
+ nnz = (idx == 1).sum().item()
+ rows, cols = torch.where(idx)
+ values = A[idx]
+ cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
+ A2 = A*idx
+ csrA = F.coo2csr(cooA)
+ counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
+ assert counts.numel() == A.shape[0]
+
+ torch.testing.assert_allclose(counts, (A2!=0).sum(1))
+ idx = (A2!=0)
+ torch.testing.assert_allclose(A2[idx], csrA.values)
+
+
+def test_coo2csc():
+ threshold = 1
+ A = torch.randn(128, 128).half().cuda()
+ idx = torch.abs(A) >= threshold
+ nnz = (idx == 1).sum().item()
+ rows, cols = torch.where(idx)
+ values = A[idx]
+ cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
+ A2 = A*idx
+ cscA = F.coo2csc(cooA)
+ counts = cscA.colptr[1:] - cscA.colptr[:-1]
+ assert counts.numel() == A.shape[1]
+
+ torch.testing.assert_allclose(counts, (A2!=0).sum(0))
+ # torch uses row-major -> use transpose to transfer to col-major
+ idx = (A2.t()!=0)
+ torch.testing.assert_allclose(A2.t()[idx], cscA.values)
+
+
+
+n = 2
+#dim1 = torch.randint(1,1*1024, size=(n,)).tolist()
+#dim2 = torch.randint(1,4*1024, size=(n,)).tolist()
+dim1 = [1*2048]
+#dim2 = [12288]
+dim2 = [2048]
+#dim1 = [2]
+#dim2 = [2]
+dtype = [torch.int8]
+values = list(product(dim1,dim2, dtype))
+names = ['dim1_{0}_dim2_{1}_dtype_{2}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
+def test_spmm_coo_dequant(dim1, dim2, dtype):
+ threshold = 6.0
+ #threshold = 2.8
+ #threshold = 0.0
+ A = torch.randn(dim1, dim2, device='cuda').half()
+ B = torch.empty(dim2, dim2*4, device='cuda', dtype=torch.float16)
+ torch.nn.init.xavier_uniform_(B)
+ Bt = B.t().contiguous()
+
+
+ CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B)
+
+ colidx = torch.randint(0, A.shape[-1], size=(15,))
+
+ A[:, colidx] = 8.0  # plant outlier columns that exceed the threshold
+
+ idx = torch.abs(A) >= threshold
+ nnz = (idx == 1).sum().item()
+ rows, cols = torch.where(idx)
+ values = A[idx]
+ cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values)
+ A2 = A*idx
+ out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
+ out1 = torch.matmul(A2, B.half())
+ out3 = F.spmm_coo_very_sparse(cooA, CBt.half())
+ out3 = out3*statsBt.half()/127
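+ # out2 is assumed to fuse the dequantization into the kernel via dequant_stats,
+ # while out3 reproduces it manually: int8 columns of B scale back by statsBt/127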
+
+ values, counts = torch.unique(cooA.rowidx, return_counts=True)
+ offset = counts.cumsum(0).int()
+ max_count, max_idx = torch.sort(counts, descending=True)
+ print(torch.median(max_count.float()))
+
+ torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001)
+
+ p = 200/(2048*12288*4)  # allow roughly 200 mismatched elements at the default shape
+ n = out1.numel()
+ count = math.ceil(p*n)
+ assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count)
+
+
+
+ #torch.cuda.synchronize()
+ #t0 = time.time()
+ #for i in range(100):
+ # out2 = F.spmm_coo_very_sparse(cooA, B)
+ #torch.cuda.synchronize()
+ #print('fp16', time.time() - t0)
+
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(100):
+ out2 = F.spmm_coo(cooA, B)
+ torch.cuda.synchronize()
+ print('cusparse fp16', time.time() - t0)
+
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(100):
+ out2 = F.spmm_coo_very_sparse(cooA, CBt)
+ torch.cuda.synchronize()
+ print('int8', time.time() - t0)
+
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(100):
+ out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
+ torch.cuda.synchronize()
+ print('int8+dequant', time.time() - t0)
+
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(100):
+ out2 = torch.matmul(A, B)
+ torch.cuda.synchronize()
+ print('matmul', time.time() - t0)
+
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(100):
+ out1 = bnb.matmul(A, Bt)
+ out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt)
+ out = out1+out2
+ torch.cuda.synchronize()
+ print('sparse + matmul', time.time() - t0)
+
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(100):
+ out1 = bnb.matmul(A, Bt)
+ torch.matmul(A[:, colidx], Bt.t()[colidx], out=out1)
+ torch.cuda.synchronize()
+ print('partial matmul', time.time() - t0)
+
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(100):
+ out1 = bnb.matmul(A, Bt)
+ torch.cuda.synchronize()
+ print('bnb matmul', time.time() - t0)
+
+batch_size = 1
+seqdim = 2048
+values = []
+values.append((batch_size, seqdim, 768, 4*768))
+#values.append((batch_size, seqdim, 1024, 4*1024))
+#values.append((batch_size, seqdim, 1536, 4*1536))
+#values.append((batch_size, seqdim, 2048, 4*2048))
+#values.append((batch_size, seqdim, 2560, 4*2560))
+#values.append((batch_size, seqdim, 4096, 4*4096))
+#values.append((batch_size, seqdim, 5140, 4*5140))
+#values.append((batch_size, seqdim, 12288, 4*12288))
+names = ['batch_{0}_seq_{1}_model_{2}_hidden_{3}'.format(*vals) for vals in values]
+@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
+def test_bench_matmul(batch, seq, model, hidden):
+ formatB = F.get_special_format_str()
+
+ A = torch.randn(batch, seq, model, device='cuda').half()
+ B = torch.empty(hidden, model, dtype=torch.float16, device='cuda')
+ torch.nn.init.xavier_uniform_(B)
+
+ linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
+ linear8bit.eval()
+
+ outliers = torch.randint(0, model, size=(5,)).cuda()
+ A[:, :, outliers] = 8.0
+
+ linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
+ linearMixedBit.eval()
+
+ # warmup
+ for i in range(100):
+ torch.matmul(A, B.t())
+ torch.cuda.synchronize()
+ print('')
+
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(100):
+ torch.matmul(A, B.t())
+ torch.cuda.synchronize()
+ print(f'pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
+
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(100):
+ bnb.matmul(A, B)
+ torch.cuda.synchronize()
+ print(f'bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
+
+ CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
+ C32A, SA = F.transform(CA, 'col32')
+ CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
+ CxB, SB = F.transform(CB, to_order=formatB)
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(100):
+ out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
+ torch.cuda.synchronize()
+ print(f'igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
+
+ BA, statsB = F.vectorwise_quant(B, dim=1)
+ CxB, SB = F.nvidia_transform(CB, to_order=formatB)
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(100):
+ A2 = A.view(-1, A.shape[-1]).contiguous()
+ CA, statsA = F.vectorwise_quant(A2, dim=1)
+ C32A, SA = F.nvidia_transform(CA, 'col32')
+ out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
+ Cout, Sout = F.nvidia_transform(out32, 'row', state=Sout32)
+ F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
+ torch.cuda.synchronize()
+ print(f'vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
+
+ BA, statsB = F.vectorwise_quant(B, dim=1, quant_type='linear')
+ CxB, SB = F.nvidia_transform(CB, to_order=formatB)
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(100):
+ A2 = A.view(-1, A.shape[-1]).contiguous()
+ CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type='linear')
+ C32A, SA = F.nvidia_transform(CA, 'col32')
+ out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
+ Cout, Sout = F.nvidia_transform(out32, 'row', state=Sout32)
+ out = Cout*statsB*statsA*(1.0/(127*127))
+ torch.cuda.synchronize()
+ print(f'linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
+
+ linear8bit(A)
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(100):
+ linear8bit(A)
+ torch.cuda.synchronize()
+ print(f'bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
+
+
+ linearMixedBit(A)
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for i in range(100):
+ linearMixedBit(A)
+ torch.cuda.synchronize()
+ print(f'bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s')
+
+
+def test_zeropoint():
+ def min_max(x):
+ maxA = torch.amax(x, dim=1, keepdim=True)
+ minA = torch.amin(x, dim=1, keepdim=True)
+ midpoint = (maxA-minA)/2.0
+ dyna = 252/(maxA-minA)
+ #dyna *= 0.98
+ x = dyna*x
+ x = x - torch.round((dyna*(minA+midpoint)))
+ return x.to(torch.int8), minA, midpoint, dyna
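+ # min_max is an asymmetric (zero-point) quantizer sketch: each row is rescaled
+ # to a ~252-wide dynamic range and shifted by its midpoint, so the int8 grid is
+ # centered on the data rather than on zero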
+ batch = 2
+ seq = 2
+ model = 4
+ hidden = 2*model
+ #batch = 4
+ #seq = 2048
+ #model = 1024
+ #hidden = 8*model
+ A = torch.randn(batch*seq, model, device='cuda').half()-0.4
+ B = torch.nn.Parameter(torch.randn(model, hidden, device='cuda').half())
+
+ #A[0] = 0
+ #B[:, 0] = 0
+ #A = A*(A>0)
+ #A[0, 0] = 0
+ #A[0, 0] = 6.0
+
+ Ac, minA, midpoint, dyna = min_max(A)
+ #print(Ac[0, 0], 'zero')
+ #print(Ac, Ac.min(), Ac.max())
+ Bc, maxB = F.vectorwise_quant(B, quant_type='linear')
+ out = F.igemm(Ac, Bc)
+ out2 = torch.matmul(A,B)
+ offset = B.sum(0)*torch.round(dyna*(minA+midpoint))/dyna
+ out = out.float()
+ #print(out.shape, maxB.shape, scale.shape, offset.shape)
+ norm1 = maxB/127
+ C4 = (out/dyna)*norm1+offset
+
+
+ B1 = torch.nn.Parameter(B.clone())
+ B2 = torch.nn.Parameter(B.clone())
+ B3 = torch.nn.Parameter(B.clone())
+ B4 = torch.nn.Parameter(B.clone())
+
+
+ C1 = torch.matmul(A, B1)
+ C2 = bnb.matmul_cublas(A, B2, None, 'linear')
+ C3 = bnb.matmul_cublas(A, B3, None, 'zeropoint')
+ C4 = bnb.matmul_cublas(A, B4, None, 'vector-zeropoint')
+
+ err1 = torch.abs(C1-C2).mean().item()
+ err2 = torch.abs(C1-C3).mean().item()
+ err3 = torch.abs(C1-C4).mean().item()
+ print(err1, err2, err3)
+ #assert err1 > err2
+
+ loss1 = C1.mean()
+ loss2 = C2.mean()
+ loss3 = C3.mean()
+ loss4 = C4.mean()
+
+ loss1.backward()
+ loss2.backward()
+ loss3.backward()
+ loss4.backward()
+
+ print(B.grad)
+ print(B1.grad)
+ print(B2.grad)
+ print(B3.grad)
+ print(B4.grad)
+ err1 = torch.abs(B1.grad-B2.grad).mean().item()
+ err2 = torch.abs(B1.grad-B3.grad).mean().item()
+ err3 = torch.abs(B1.grad-B4.grad).mean().item()
+ print(err1, err2, err3)
+
+
+
+
+def test_zp():
+ def quant_zp(x):
+ dtype = x.dtype
+ x = x.float()
+ dyna = x.max() - x.min()
+ if dyna == 0: dyna = 1
+ qx = 254./dyna
+ minx = x.min()
+ #zpx = torch.round(minx* qx)
+ #zpx = 127 - torch.round(x.max()* qx)
+ zpx = torch.round(x.min()* qx) - 127
+ x = (qx*x) + zpx
+ return x, qx, zpx
+ batch = 2
+ seq = 512
+ model = 1024
+ hidden = 4*model
+ A = torch.randn(batch*seq, model, device='cuda').half()*0.1
+ B = torch.randn(model, hidden, device='cuda').half()*0.1
+
+
+ C0 = torch.matmul(A, B)
+
+
+ #A, SA = F.vectorwise_quant(A, quant_type='linear')
+ #B, SB = F.vectorwise_quant(B, quant_type='linear')
+ A = A.float()
+ B = B.float()
+
+ C1 = torch.matmul(A, B)
+ C3 = bnb.matmul(A.half(), B.t().contiguous().half())
+
+ zp = 1
+ #C2 = torch.matmul(A-zp, B)
+ #C2 += B.sum(0).view(1, -1)*zp
+ C2 = torch.matmul(A, B-zp)
+ C2 -= A.sum(1).view(-1, 1)*zp
+
+ ca, cqa, cza = quant_zp(A)
+ print(ca.min(), ca.max())
+ print((ca-cza).min(), (ca-cza).max())
+
+ zp = 1
+ scale = 2.0
+ C5 = torch.matmul((A*scale)-zp, B)
+ C5 += B.sum(0)*zp
+ C5 /= scale
+
+ CA, qa, zpa = quant_zp(A)
+ C4 = torch.matmul(CA, B)
+ C4 -= B.sum(0)*zpa
+ C4 /= qa
+ zpb = 1
+ zpa = 1
+ qa = 2
+ qb = 2
+ C6 = torch.matmul((A*qa)+zpa, (B*qb)+zpb)
+ C6 -= (qb*B.sum(0).view(1, -1)*zpa) + (qa*A.sum(1).view(-1, 1)*zpb)
+ C6 -= zpa*zpb*A.shape[1]
+ C6 /= qa*qb
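+ # zero-point algebra behind C6 (and C7 below), with K = A.shape[1]:
+ # (qa*A + zpa) @ (qb*B + zpb)
+ # = qa*qb*(A @ B) + qa*zpb*A.sum(1) + qb*zpa*B.sum(0) + zpa*zpb*K
+ # subtracting the two sum terms and the zpa*zpb*K constant, then dividing by
+ # qa*qb, recovers A @ B exactly up to quantization rounding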
-def test_histogram():
- dim1, dim2 = 32, 32
- source = torch.rand(dim1, dim2, device='cuda')
- idx1 = torch.randint(0, 255, size=(dim1, dim2), device='cuda').int()
- idx2 = torch.randint(0, 255, size=(dim1, dim2), device='cuda').int()
- histogram1 = torch.zeros((256, 256)).cuda()
- histogram2 = torch.zeros((256, 256)).cuda()
+ CA, qa, zpa = quant_zp(A)
+ CB, qb, zpb = quant_zp(B)
+ C7 = torch.matmul(CA, CB)
+ C7 -= (qb*B.sum(0).view(1, -1)*zpa) + (qa*A.sum(1).view(-1, 1)*zpb)
+ C7 -= zpa*zpb*A.shape[1]
+ C7 /= qa*qb
- F.histogram_scatter_add_2d(histogram2, idx1, idx2, source)
+ print('')
+ #print(C0.flatten()[:10])
+ print(C1.flatten()[:10])
+ print(C2.flatten()[:10])
+ print(C3.flatten()[:10])
+ print(C5.flatten()[:10])
+ print(C6.flatten()[:10])
+ print(C7.flatten()[:10])
+ err1 = torch.abs(C1-C2).mean().item()
+ err2 = torch.abs(C1-C3).mean().item()
+ err3 = torch.abs(C1-C4).mean().item()
+ err4 = torch.abs(C1-C5).mean().item()
+ err5 = torch.abs(C1-C6).mean().item()
+ err6 = torch.abs(C1-C7).mean().item()
+ print(err1, err2, err3, err4, err5, err6)
- for i in range(dim1):
- for j in range(dim2):
- histogram1[idx1[i, j].item(), idx2[i, j].item()] += source[i, j]
- torch.testing.assert_allclose(histogram1, histogram2)
- torch.testing.assert_allclose(histogram1.sum(), source.sum())
diff --git a/tests/test_modules.py b/tests/test_modules.py
index a0379cb..a2c950b 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -1,42 +1,470 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
import pytest
import torch
+
+import math
+import einops
+
+from itertools import product
+from torch import nn
+
import bitsandbytes as bnb
+class MockArgs(object):
+ def __init__(self, initial_data):
+ for key in initial_data:
+ setattr(self, key, initial_data[key])
+
+class MLP8bit(torch.nn.Module):
+ def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
+ super(MLP8bit, self).__init__()
+ self.fc1 = bnb.nn.Linear8bitLt(dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold)
+ self.fc2 = bnb.nn.Linear8bitLt(dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.fc2(x)
+ return x
+
+
+def get_args():
+ args = MockArgs([])
+ args.quant_type = 'vector'
+ args.use_8bit_training = 'full'
+ args.clip_freq = 9999
+ return args
+
+def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
+ idx = torch.isclose(a, b, rtol, atol)
+ sumval = (idx==0).sum().item()
+ if sumval > count:
+ print(f'Too many values not close: assert {sumval} < {count}')
+ torch.testing.assert_allclose(a, b, rtol, atol)
+
+class LinearFunction(torch.autograd.Function):
+
+ @staticmethod
+ def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
+ round_func = LinearFunction.round_stochastic if stochastic else torch.round
+ norm = math.sqrt(math.pi)/math.sqrt(2.0)
+ #std = torch.abs(x).mean()*norm
+ std = torch.std(x)
+ max1 = std*trim_value
+ x = x/max1*127
+ x = round_func(x)
+ x[x > 127] = 127
+ x[x < -127] = -127
+ x = x/127*max1
+
+ return x
+
+ @staticmethod
+ def quant(x, quant_type, dim=1):
+ if quant_type == 'linear':
+ max1 = torch.abs(x).max().float()
+ xq = torch.round(x/max1*127).to(torch.int8)
+ return xq, max1
+ elif quant_type == 'vector':
+ max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
+ xq = torch.round(x/max1*127).to(torch.int8)
+ return xq, max1
+ elif quant_type == 'min-max':
+ maxA = torch.amax(x, dim=dim, keepdim=True).float()
+ minA = torch.amin(x, dim=dim, keepdim=True).float()
+ scale = (maxA-minA)/2.0
+ xq = torch.round(127*(x-minA-scale)/scale).to(torch.int8)
+ return xq, (minA.float(), scale.float())
+ else: return None
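+ # the min-max branch is an asymmetric scheme: it shifts each slice by its
+ # midpoint (minA + scale) and divides by the half-range, mapping [minA, maxA]
+ # onto roughly [-127, 127]; dequant_min_max below undoes this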
+
+ @staticmethod
+ def dequant(xq, S1, S2, dtype, quant_type):
+ if quant_type == 'linear':
+ norm = S1*S2/(127*127)
+ # double cast needed to prevent overflows
+ return (xq.float()*norm).to(dtype)
+ elif quant_type == 'vector':
+ x = xq.float()
+ if len(xq.shape) == 2 and len(S1.shape) == 3: S1 = S1.squeeze(0)
+ if len(xq.shape) == 2 and len(S2.shape) == 3: S2 = S2.squeeze(0)
+ #print(x.shape, S1.shape, S2.shape)
+ if len(S1.shape) == 2:
+ x *= S1.t()/127
+ else:
+ x *= S1/127
+ x *= S2/127
+ return x.to(dtype)
+ else: return None
+
+ @staticmethod
+ def dequant_min_max(xq, A, B, SA, SB, dtype):
+ offset = B.float().t().sum(0)*(SA[0]+SA[1])
+ x = xq.float()
+ if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0)
+ if len(xq.shape) == 2 and len(SA.shape) == 3: SA = SA.squeeze(0)
+ if len(SB.shape) == 2:
+ x *= SB.t()/127
+ else:
+ x *= SB/127
+ x *= SA[1]/127
+ x += offset
+ return x.to(dtype)
+
+
+ @staticmethod
+ def get_8bit_linear(x, stochastic=False):
+ round_func = LinearFunction.round_stochastic if stochastic else torch.round
+ max1 = torch.abs(x).max()
+ x = x/max1*127
+ x = round_func(x)/127*max1
+ #x = torch.round(x)/128*max1
+ return x
+
+ @staticmethod
+ def get_8bit_vector_wise(x, dim, stochastic=False):
+ round_func = LinearFunction.round_stochastic if stochastic else torch.round
+ max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
+ max1[max1==0] = 1.0
+ x = (x*127)/max1
+ x = round_func(x)/127*max1
+ return x
+
+ @staticmethod
+ def round_stochastic(x):
+ sign = torch.sign(x)
+ absx = torch.abs(x)
+ decimal = absx-torch.floor(absx)
+ rdm = torch.rand_like(decimal)
+ return sign*(torch.floor(absx)+(rdm < decimal).to(x.dtype))
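+ # stochastic rounding: rounds up with probability equal to the fractional
+ # part, so the rounding is unbiased (E[round_stochastic(x)] = x) and small
+ # updates are not systematically lost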
+
+ @staticmethod
+ def fake_8bit_storage(w, exponent_bits):
+ code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device)
+ absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
+ out = bnb.functional.dequantize_blockwise(absmax, C, code)
+ out = out.half()
+ w.copy_(out)
+ return out
+
+ @staticmethod
+ def fake_8bit_storage_quantile(w, args):
+ code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
+ #C = bnb.functional.quantize_no_absmax(code, w)
+ #out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
+ #print(out)
+ #out = out.half()
+ code /= torch.max(torch.abs(code))
+ absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
+ out = bnb.functional.dequantize_blockwise(absmax, C, code)
+ out = out.half()
+ w.copy_(out)
+ return out
+
+ @staticmethod
+ def fake_8bit_storage_stochastic(w):
+ rand = torch.rand(1024, device=w.device)
+ absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand)
+ out = bnb.functional.dequantize_blockwise(absmax, C)
+ out = out.half()
+ w.copy_(out)
+ return out
+
+ @staticmethod
+ def fake_8bit_storage_with_max(w, topk=8):
+ blocked_w = einops.rearrange(w.flatten(), '(h b) -> h b', b=256)
+ max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
+ idx = idx[:, :topk]
+ max_val = max_val[:, :topk]
+
+ mask = torch.zeros_like(blocked_w)
+ mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val))
+ mask = mask.bool()
+
+ # 1. zero out max values
+ # 2. quantize + dequantize
+ # 3. write back max values
+ # 4. copy matrix back to weight
+
+ values = blocked_w[mask]
+ blocked_w[mask] = 0
+
+ code = bnb.functional.create_dynamic_map()
+ code = code.to(w.device)
+ absmax, C = bnb.functional.quantize_blockwise(blocked_w.data)
+ bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w)
+
+ blocked_w[mask] = values
+
+ unblocked_w = blocked_w.flatten().view(w.shape)
+
+ w.copy_(unblocked_w)
+ return unblocked_w
+
+
+ @staticmethod
+ def forward(ctx, x, weight, bias=None, args=None):
+ if args.use_8bit_training != 'off':
+ weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
+ x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
+ outputq = bnb.functional.igemm(x8, weight8.t())
+ output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
+ #if torch.rand(1) < 0.01:
+ #output32 = torch.matmul(x, weight.t())
+ #err = torch.abs(output-output32).float()
+ #relerr = err/(torch.abs(output32).float()+1e-8)
+ #print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
+ else:
+ #output = torch.matmul(x, weight.t())
+ output = torch.einsum('bsi,oi->bso', x, weight)
+
+ ctx.save_for_backward(x, weight, bias)
+ ctx.args = args
+
+ if bias is not None:
+ output += bias.unsqueeze(0).expand_as(output)
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ x, weight, bias = ctx.saved_tensors
+ args = ctx.args
+ stochastic = False
+ grad_input = grad_weight = grad_bias = None
+ if bias is not None and ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0)
+
+ # weight and x are already 8bit
+ # -> transform grad_output to 8-bit
+ if args.use_8bit_training == 'forward+wgrad':
+ grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
+ x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
+ grad_weight8 = bnb.functional.igemm(grad_output8, x8)
+ grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
+
+ #grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
+
+ grad_input = grad_output.matmul(weight)
+ elif args.use_8bit_training == 'full':
+ grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
+ x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
+ grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
+ bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
+ grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
+
+ grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
+ weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
+ grad_input8 = bnb.functional.igemm(grad_output8, weight8)
+ grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type)
+
+ else:
+ grad_input = grad_output.matmul(weight)
+ grad_weight = torch.einsum('bsi,bso->oi', x, grad_output)
-@pytest.mark.parametrize("embcls", [bnb.nn.Embedding, bnb.nn.StableEmbedding], ids=['Embedding', 'StableEmbedding'])
-def test_embeddings(embcls):
- bnb.optim.GlobalOptimManager.get_instance().initialize()
- emb1 = torch.nn.Embedding(100, 512).cuda()
- emb2 = embcls(100, 512).cuda()
+ return grad_input, grad_weight, grad_bias, None
- adam1 = bnb.optim.Adam8bit(emb1.parameters())
- adam2 = bnb.optim.Adam8bit(emb2.parameters())
+class Linear8bit(nn.Module):
+ def __init__(self, input_features, output_features, bias=True, args=None):
+ super(Linear8bit, self).__init__()
+ self.input_features = input_features
+ self.output_features = output_features
+ self.args = args
- batches = torch.randint(1, 100, size=(100, 4, 32)).cuda()
+ self.weight = nn.Parameter(torch.empty(output_features, input_features))
+ if bias:
+ self.bias = nn.Parameter(torch.empty(output_features))
+ else:
+ self.register_parameter('bias', None)
+ torch.nn.init.xavier_uniform_(self.weight)
+ if self.bias is not None:
+ torch.nn.init.zeros_(self.bias)
+
+ def forward(self, x):
+ self.args.training = self.training
+
+ return LinearFunction.apply(x, self.weight, self.bias, self.args)
+
+
+
+def test_linear8bit():
+ l0 = torch.nn.Linear(32, 64).cuda().half()
+ l1 = bnb.nn.Linear8bit(32,64, args=get_args()).cuda().half()
+ l2 = Linear8bit(32, 64, args=get_args()).cuda().half()
+ l3 = bnb.nn.Linear8bitLt(32,64).cuda().half()
+
+ l0.weight.data = l2.weight.data.clone()
+ l0.bias.data = l2.bias.data.clone()
+
+ l1.weight.data = l2.weight.data.clone()
+ l1.bias.data = l2.bias.data.clone()
+
+ l3.weight.data = l2.weight.data.clone()
+ l3.bias.data = l2.bias.data.clone()
+
+ for i in range(100):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ t = torch.randn(16, 8, 64, device='cuda').half()
+ b2 = b1.clone()
+ b3 = b1.clone()
+ b0 = b1.clone()
+
+ o0 = l0(b0)
+ o1 = l1(b1)
+ o2 = l2(b2)
+ o3 = l3(b3)
+
+ assert_all_approx_close(o1, o2, atol=0.013, rtol=0.05, count=1)
+ assert_all_approx_close(o3, o2, atol=0.013, rtol=0.05, count=1)
+
+ loss0 = torch.nn.functional.mse_loss(o0, t)
+ loss1 = torch.nn.functional.mse_loss(o1, t)
+ loss2 = torch.nn.functional.mse_loss(o2, t)
+ loss3 = torch.nn.functional.mse_loss(o3, t)
+
+ loss0.backward()
+ loss1.backward()
+ loss2.backward()
+ loss3.backward()
+
+ assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
+ assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
+ assert_all_approx_close(l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)
+ assert_all_approx_close(l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)
+
+ err1 = torch.abs(l0.weight.grad-l1.weight.grad).mean().item()
+ err2 = torch.abs(l0.weight.grad-l2.weight.grad).mean().item()
+ err3 = torch.abs(l0.weight.grad-l3.weight.grad).mean().item()
+
+ assert err1*0.8 < err2
+ assert err2*0.8 < err3
+ assert err3*0.8 < err1
+
+ l0.weight.grad = None
+ l1.weight.grad = None
+ l2.weight.grad = None
+ l3.weight.grad = None
+ l0.bias.grad = None
+ l1.bias.grad = None
+ l2.bias.grad = None
+ l3.bias.grad = None
+
+
+threshold = [0.0, 3.0]
+values = threshold
+names = ['threshold_{0}'.format(vals) for vals in values]
+@pytest.mark.parametrize("threshold", values, ids=names)
+def test_linear8bitlt_inference(threshold):
+ l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold).cuda().half()
+ assert l1.weight.device.type == 'cuda'
+ assert l1.weight.dtype == torch.float16
+
+ l1.eval()
for i in range(100):
- batch = batches[i]
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = l1(b1)
+ if i == 1:
+ assert l1.state.CxB is not None
+
+def test_linear8bitlt_accumulated_gradient():
+ l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32,32).cuda().half() for i in range(2)])
+ l2 = torch.nn.Sequential(*[torch.nn.Linear(32,32).cuda().half() for i in range(2)])
+ l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
+ l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
+ l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
+ l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
+ opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
+ opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)
+
+ acc_steps = 10
+
- embedded1 = emb1(batch)
- embedded2 = emb2(batch)
+ for i in range(10):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = l1(b1)
+ o2 = l2(b1)
+ loss1 = o1.mean()
+ loss2 = o2.mean()
+ loss1.backward()
+ loss2.backward()
+ if i == 2:
+ assert l1[0].state.CxB is not None
+ assert l1[1].state.CxB is not None
- l1 = embedded1.mean()
- l2 = embedded2.mean()
+ if i > 0 and i % acc_steps == 0:
+ opt1.step()
+ opt1.zero_grad(True)
+ opt2.step()
+ opt2.zero_grad(True)
+ assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2)
+ assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2)
+ # we do this copy because otherwise we have small divergences over time that add up
+ l1[0].weight.data.copy_(l2[0].weight.data)
+ l1[1].weight.data.copy_(l2[1].weight.data)
+ else:
+ torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad)
+ torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad)
- l1.backward()
- l2.backward()
- adam1.step()
- adam2.step()
+threshold = [0.0, 2.0]
+values = threshold
+names = ['threshold_{0}'.format(vals) for vals in values]
+@pytest.mark.parametrize("threshold", values, ids=names)
+def test_linear8bitlt_no_fp16_weights(threshold):
+ l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold, has_fp16_weights=False).cuda().half()
+ assert l1.weight.dtype == torch.int8
- adam1.zero_grad()
- adam2.zero_grad()
+ l1.eval()
+ for i in range(100):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = l1(b1)
+ assert o1.dtype == torch.float16
+
+ mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda()
+ assert mlp.fc1.weight.dtype == torch.int8
+ assert mlp.fc2.weight.dtype == torch.int8
- assert adam1.state[emb1.weight]['state1'].dtype == torch.uint8
- assert adam2.state[emb2.weight]['state1'].dtype == torch.float32
+ for i in range(100):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = mlp(b1)
+ assert o1.dtype == torch.float16
+ if threshold > 0: assert mlp.fc1.state.idx is not None
+ if threshold > 0: assert mlp.fc2.state.idx is not None
+ mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
+ assert mlp.fc1.weight.dtype == torch.int8
+ assert mlp.fc2.weight.dtype == torch.int8
+
+ for i in range(100):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = mlp(b1)
+ assert o1.dtype == torch.float16
+ if threshold > 0: assert mlp.fc1.state.idx is not None
+ if threshold > 0: assert mlp.fc2.state.idx is not None
+ mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
+
+ for i in range(100):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = mlp(b1)
+ assert o1.dtype == torch.float16
+ if threshold > 0: assert mlp.fc1.state.idx is not None
+ if threshold > 0: assert mlp.fc2.state.idx is not None
+ assert mlp.fc1.weight.dtype == torch.int8
+ assert mlp.fc2.weight.dtype == torch.int8
+
+
+ mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to('cuda')
+
+ for i in range(100):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = mlp(b1)
+ assert o1.dtype == torch.float16
+ if threshold > 0: assert mlp.fc1.state.idx is not None
+ if threshold > 0: assert mlp.fc2.state.idx is not None
+ assert mlp.fc1.weight.dtype == torch.int8
+ assert mlp.fc2.weight.dtype == torch.int8
+ assert mlp.fc1.weight.device.type == 'cuda'
+ assert mlp.fc2.weight.device.type == 'cuda'
+
+ mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(torch.float16).to('cuda')
+
+ for i in range(100):
+ b1 = torch.randn(16, 8, 32, device='cuda').half()
+ o1 = mlp(b1)
+ assert o1.dtype == torch.float16
+ if threshold > 0: assert mlp.fc1.state.idx is not None
+ if threshold > 0: assert mlp.fc2.state.idx is not None
+ assert mlp.fc1.weight.dtype == torch.int8
+ assert mlp.fc2.weight.dtype == torch.int8
+ assert mlp.fc1.weight.device.type == 'cuda'
+ assert mlp.fc2.weight.device.type == 'cuda'
diff --git a/tests/test_optim.py b/tests/test_optim.py
index c80fe51..b173eaa 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -1,12 +1,9 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
import os
import time
import shutil
import uuid
import pytest
+import ctypes
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F
@@ -14,7 +11,9 @@ import bitsandbytes.functional as F
from os.path import join
from itertools import product
-import apex
+#import apex
+
+k = 20
def get_temp_dir():
path = '/tmp/autoswap/{0}'.format(str(uuid.uuid4()))
@@ -26,55 +25,47 @@ def rm_path(path):
str2optimizers = {}
str2optimizers['adam_pytorch'] = (None, torch.optim.Adam, bnb.optim.Adam)
-str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
-str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
+#str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
+#str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers['momentum_pytorch'] = (None, lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), bnb.optim.Adam)
-str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
-str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
+#str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
+#str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam)
-str2optimizers['adamw'] = (torch.optim.AdamW, bnb.optim.AdamW)
-str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
+#str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9))
-str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
+#str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False))
-str2optimizers['adagrad'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad(pxx, 0.01, block_wise=False))
str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False))
-str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
+#str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9))
str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
-str2optimizers['adamw8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.AdamW8bit(pxx, block_wise=True))
str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True))
str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True))
-str2optimizers['adagrad8bit_blockwise'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad8bit(pxx, 0.01, block_wise=True))
str2statenames = {}
str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
-str2statenames['adamw'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
str2statenames['momentum'] = [('momentum_buffer', 'state1')]
str2statenames['lars'] = [('momentum_buffer', 'state1')]
str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
str2statenames['rmsprop'] = [('square_avg', 'state1')]
-str2statenames['adagrad'] = [('sum', 'state1')]
str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
-str2statenames['adamw8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')]
str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')]
str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')]
-str2statenames['adagrad8bit_blockwise'] = [('sum', 'state1', 'qmap1', 'absmax1')]
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
-optimizer_names = ['adam', 'adamw', 'momentum', 'rmsprop', 'lars', 'lamb', 'adagrad']
+optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb']
values = list(product(dim1,dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -89,12 +80,12 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
bnb_optimizer = str2optimizers[optim_name][1]([p2])
if gtype == torch.float32:
- atol, rtol = 2e-6, 1e-5
+ atol, rtol = 1e-6, 1e-5
else:
atol, rtol = 1e-4, 1e-3
- for i in range(50):
+ for i in range(k):
g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
p1.grad = g.clone().float()
p2.grad = g.clone()
@@ -107,7 +98,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
- if i % 10 == 0 and i > 0:
+ if i % (k//5) == 0 and i > 0:
path = get_temp_dir()
torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt'))
del bnb_optimizer
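The save/delete/recreate/load sequence above is the serialization check run every k//5 steps. As a standalone sketch (the helper name and the make_bnb_optimizer factory argument are illustrative):

    import torch
    from os.path import join

    def reload_optimizer(make_bnb_optimizer, bnb_optimizer, params, path):
        # checkpoint the optimizer state, rebuild from scratch, restore;
        # subsequent steps must still track the torch reference
        torch.save(bnb_optimizer.state_dict(), join(path, 'opt.pt'))
        del bnb_optimizer
        bnb_optimizer = make_bnb_optimizer(params)
        bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt')))
        return bnb_optimizer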
@@ -148,7 +139,6 @@ def test_global_config(dim1, dim2, gtype):
eps = 1e-8
bnb.optim.GlobalOptimManager.get_instance().initialize()
- bnb.optim.GlobalOptimManager.get_instance().override_config(p2, 'skip_zeros', True)
bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8)
bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
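The override mechanics exercised here, in standalone form. A minimal sketch with placeholder sizes and hyperparameters; the get_instance/override_config/register_parameters calls mirror the test, and, as in the test, parameters are registered before the optimizer is constructed:

    import torch
    import bitsandbytes as bnb

    p1 = torch.randn(1024, 32, device='cuda') * 0.1
    p2 = torch.randn(1024, 32, device='cuda') * 0.1

    mng = bnb.optim.GlobalOptimManager.get_instance()
    mng.initialize()
    mng.override_config(p2, 'optim_bits', 8)  # only p2 gets 8-bit state
    mng.register_parameters([p1, p2])

    adam = bnb.optim.Adam([p1, p2], lr=0.001, eps=1e-8, optim_bits=32)
    for p in (p1, p2):
        p.grad = torch.randn_like(p) * 0.01
    adam.step()

    # the override took effect: p2's state buffers are quantized to uint8
    assert adam.state[p2]['state1'].dtype == torch.uint8
    assert adam.state[p2]['state2'].dtype == torch.uint8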
@@ -163,8 +153,6 @@ def test_global_config(dim1, dim2, gtype):
else:
atol, rtol = 1e-4, 1e-3
- original_p2 = p2[mask].clone()
-
for i in range(50):
g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
@@ -173,38 +161,17 @@ def test_global_config(dim1, dim2, gtype):
p2.grad = g2
p3.grad = g3
- if i > 30 and i % 10 == 0:
- g1.data[mask] = 0.0
- g2.data[mask] = 0.0
- p1.grad = g1
- p2.grad = g2
- original_p1 = p1[mask].clone()
- original_p2 = p2[mask].clone()
- og_s1 = adam2.state[p2]['state1'][mask].clone()
- og_s2 = adam2.state[p2]['state2'][mask].clone()
- og_s11 = adam2.state[p1]['state1'][mask].clone()
- og_s21 = adam2.state[p1]['state2'][mask].clone()
-
adam2.step()
assert adam2.state[p3]['state1'].dtype == torch.uint8
assert adam2.state[p3]['state2'].dtype == torch.uint8
- if i > 30 and i % 10 == 0:
- torch.testing.assert_allclose(original_p2, p2[mask])
- torch.testing.assert_allclose(adam2.state[p2]['state1'][mask], og_s1)
- torch.testing.assert_allclose(adam2.state[p2]['state2'][mask], og_s2)
- assert ((p1[mask]- original_p1)==0.0).sum() < p1.numel()
- assert ((adam2.state[p1]['state1'][mask]- og_s11)==0.0).sum() == 0.0
- assert ((adam2.state[p1]['state2'][mask]- og_s21)==0.0).sum() == 0.0
-
-
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
-optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'adamw8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise', 'adagrad8bit_blockwise']
+optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise']
values = list(product(dim1,dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -370,13 +337,12 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
if dim1 == 1 and dim2 == 1: return
p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1
-
bnb_optimizer = str2optimizers[optim_name][1]([p1])
g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
p1.grad = g
- for i in range(5000):
- if i == 500:
+ for i in range(k):
+ if i == k//5:
- # 100 iterations for burn-in
+ # k//5 iterations of burn-in before timing starts
torch.cuda.synchronize()
t0 = time.time()
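The benchmark's burn-in pattern, as a standalone helper; a sketch assuming k is the module-level iteration count the tests now share (defined earlier in the file, not shown in this hunk):

    import time
    import torch

    def time_optimizer_steps(bnb_optimizer, k):
        # skip the first k//5 steps (kernel/allocator warm-up), then time
        # the remaining k - k//5 steps; assumes .grad is already populated
        t0 = None
        for i in range(k):
            if i == k // 5:
                torch.cuda.synchronize()
                t0 = time.time()
            bnb_optimizer.step()
        torch.cuda.synchronize()
        return time.time() - t0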
@@ -386,23 +352,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
torch.cuda.synchronize()
s = time.time()-t0
print('')
- params = 4500*4096*4096
+ params = (k-k//5)*dim1*dim2
print(optim_name, gtype, s/params)
#assert s < 3.9
-
-def test_str_betas():
- betas = (0.80, 0.95)
- strbetas = '(0.80, 0.95)'
-
- layer = torch.nn.Linear(10, 10)
-
- base = bnb.optim.Adam(layer.parameters(), betas=betas)
- strbase = bnb.optim.Adam(layer.parameters(), betas=strbetas)
- assert base.defaults['betas'][0] == 0.8
- assert base.defaults['betas'][1] == 0.95
- assert strbase.defaults['betas'][0] == 0.8
- assert strbase.defaults['betas'][1] == 0.95
-
-