Diffstat (limited to 'bitsandbytes')
-rw-r--r--  bitsandbytes/__init__.py              3
-rw-r--r--  bitsandbytes/autograd/__init__.py     0
-rw-r--r--  bitsandbytes/autograd/_functions.py 307
-rw-r--r--  bitsandbytes/cextension.py            2
-rw-r--r--  bitsandbytes/functional.py          869
-rw-r--r--  bitsandbytes/nn/__init__.py           2
-rw-r--r--  bitsandbytes/nn/modules.py          124
7 files changed, 1301 insertions, 6 deletions
diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py
index 02ca804..3c3affa 100644
--- a/bitsandbytes/__init__.py
+++ b/bitsandbytes/__init__.py
@@ -4,12 +4,13 @@
# LICENSE file in the root directory of this source tree.
from .nn import modules
+from .autograd._functions import mm_cublas, bmm_cublas, matmul_cublas, matmul, MatmulLtState
from .cextension import COMPILED_WITH_CUDA
if COMPILED_WITH_CUDA:
from .optim import adam
-__pdoc__ = {'libBitsNBytes': False,
+__pdoc__ = {'libbitsandbytes': False,
'optim.optimizer.Optimizer8bit': False,
'optim.optimizer.MockArgs': False
}
diff --git a/bitsandbytes/autograd/__init__.py b/bitsandbytes/autograd/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/bitsandbytes/autograd/__init__.py
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
new file mode 100644
index 0000000..815a4f1
--- /dev/null
+++ b/bitsandbytes/autograd/_functions.py
@@ -0,0 +1,307 @@
+import torch
+import bitsandbytes as bnb
+import bitsandbytes.functional as F
+
+from dataclasses import dataclass
+
+tensor = torch.Tensor
+
+'''
+ This class pools outlier dimensions across layers.
+ This is particularly important for small models where outlier features
+ are less systematic and occur with low frequency.
+'''
+class GlobalOutlierPooler(object):
+ _instance = None
+
+ def __init__(self):
+ raise RuntimeError('Call get_instance() instead')
+
+ def initialize(self):
+ self.outliers = set()
+ self.model_dim = None
+
+ @classmethod
+ def get_instance(cls):
+ if cls._instance is None:
+ cls._instance = cls.__new__(cls)
+ cls._instance.initialize()
+ return cls._instance
+
+ def add_outliers(self, outlier_idx, feature_dim):
+ if self.model_dim is None: self.model_dim = feature_dim
+ if feature_dim != self.model_dim: return # we do not encode outliers for the 2nd FFN layer
+
+ self.outliers.update(outlier_idx.tolist())
+
+ def get_current_outlier_idx(self):
+ return torch.Tensor(list(self.outliers)).to(torch.int64)
+
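+# A minimal usage sketch for the pooler singleton (values below are illustrative):
+#
+#   pooler = GlobalOutlierPooler.get_instance()           # process-wide singleton
+#   pooler.add_outliers(torch.tensor([7, 2048]), feature_dim=4096)
+#   idx = pooler.get_current_outlier_idx()                # int64 tensor of pooled indices
+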
+class MatMul8bit(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, A, B, out=None, quant_type='vector', precision=[8, 8, 8]):
+
+ if precision[0] != 8:
+ with torch.no_grad():
+ output = torch.matmul(A, B)
+ else:
+ if len(B.shape) == 2: dim = 0
+ else: dim = 1
+ qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
+ qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
+ iout = F.igemm(qA, qB)
+ output = F.vectorwise_mm_dequant(iout, SA, SB, A.dtype, quant_type)
+
+ if A.requires_grad or B.requires_grad:
+ ctx.save_for_backward(A, B)
+
+ ctx.quant_type = quant_type
+ ctx.precision = precision
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ A, B = ctx.saved_tensors
+ quant_type = ctx.quant_type
+ precision = ctx.precision
+ grad_A = grad_B = None
+
+ if B.requires_grad:
+ if len(A.shape) == 3:
+ dims = [0, 1]
+ # bsi -> ibs
+ permute_dim = [0, 2, 1]
+ else:
+ dims = [0]
+ # bs -> sb
+ permute_dim = [1, 0]
+
+ if precision[1] != 8:
+ with torch.no_grad():
+ grad_B = torch.matmul(A.permute(permute_dim), grad_output)
+ else:
+ if len(B.shape) == 2 and len(A.shape) == 3:
+ if not grad_output.is_contiguous(): grad_output = grad_output.contiguous()
+ qgrad_output, S1 = F.vectorwise_quant(grad_output.view(-1, grad_output.shape[2]), dim=0, quant_type=quant_type)
+ if not A.is_contiguous(): A = A.contiguous()
+ qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
+ igrad_B = F.igemm(qA.t(), qgrad_output)
+ grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
+ else:
+ qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
+ qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
+ igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
+ grad_B = F.vectorwise_mm_dequant(igrad_B, S2.permute(permute_dim), S1, grad_output.dtype, quant_type)
+
+ if A.requires_grad:
+ if len(grad_output.shape) == 3: dims = [2]
+ else: dims = [1]
+
+ if len(B.shape) == 3:
+ # bio -> boi
+ permute_dim = [0, 2, 1]
+ dim_B = dims
+ else:
+ # io -> oi
+ permute_dim = [1, 0]
+ dim_B = [1]
+
+ if precision[2] != 8:
+ with torch.no_grad():
+ grad_A = torch.matmul(grad_output, B.permute(permute_dim))
+ else:
+ qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
+ qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
+ igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
+ grad_A = F.vectorwise_mm_dequant(igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type)
+
+ return grad_A, grad_B, None, None, None
+
+
+mm_cublas = MatMul8bit.apply
+bmm_cublas = MatMul8bit.apply
+matmul_cublas = MatMul8bit.apply
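+
+# Minimal usage sketch for the cuBLAS-based int8 matmul (assumes a CUDA device;
+# shapes are illustrative):
+#
+#   A = torch.randn(4, 8, device='cuda', requires_grad=True)
+#   B = torch.randn(8, 16, device='cuda', requires_grad=True)
+#   C = matmul_cublas(A, B)   # quantize -> igemm -> dequantize
+#   C.sum().backward()        # backward also runs as int8, per `precision`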
+
+@dataclass
+class MatmulLtState:
+ CB = None
+ CxB = None
+ SB = None
+ SCB = None
+
+ CxBt = None
+ SBt = None
+ CBt = None
+
+ subB = None
+
+ outlier_pool = None
+ has_accumulated_gradients = False
+ threshold = 0.0
+ idx = None
+ is_training = True
+ has_fp16_weights = True
+ use_pool = False
+ formatB = F.get_special_format_str()
+
+ def reset_grads(self):
+ self.CB = None
+ self.CxB = None
+ self.SB = None
+ self.SCB = None
+
+ self.CxBt = None
+ self.SBt = None
+ self.CBt = None
+
+
+class MatMul8bitLt(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, A, B, out=None, state=None):
+ # 1. Quantize A
+ # 2. Quantize B
+ # 3. Matmul
+ # 4. Mixed-precision decomposition matmul
+ # 5. Save state
+ state = state if state is not None else MatmulLtState()
+ requires_gradA = A.requires_grad
+ requires_gradB = B.requires_grad
+ formatB = state.formatB
+ input_shape = A.shape
+ if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance()
+ assert A.dtype == torch.float16, f'The input data type needs to be fp16 but {A.dtype} was found!'
+
+ # 1. Quantize A
+ if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous()
+ CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)
+
+ if state.threshold > 0.0 and coo_tensorA is not None:
+ if state.has_fp16_weights:
+ idx = torch.unique(coo_tensorA.colidx).long()
+ CA[:, idx] = 0
+ CAt[:, idx] = 0
+ subA = A[:, idx]
+ state.subB = B[:, idx].t().contiguous()
+ state.idx = idx
+ else:
+ if state.CxB is None:
+ # B is in 8-bit row-major; we can transform it back to 16-bit to extract outlier dimensions
+ # we also need to convert it to the turing/ampere format
+ state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+ if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
+ # generate outlier index and subB
+ outlier_idx = torch.unique(coo_tensorA.colidx).long()
+ state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
+ if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
+ # do not use pool for 2nd FFN layer
+ state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
+ else:
+ state.idx = outlier_idx
+ state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()
+
+ if state.idx is not None:
+ # extract outliers
+ CA[:, state.idx] = 0
+ CAt[:, state.idx] = 0
+ subA = A[:, state.idx]
+ else:
+ subA = None
+ else:
+ if not state.has_fp16_weights and state.CxB is None:
+ state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+ subA = None
+
+ C32A, SA = F.transform(CA, 'col32')
+
+ # 2. Quantize B
+ if state.has_fp16_weights:
+ has_grad = getattr(B, 'grad', None) is not None
+ is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
+ if is_transposed: B = B.contiguous()
+
+ if (state.is_training and not has_grad) or state.CxB is None:
+ state.reset_grads()
+ CB, state.CBt, state.SCB, state.SCBt, coo_tensorB = F.double_quant(B)
+ state.CxB, state.SB = F.transform(CB, to_order=formatB)
+ else:
+ has_grad = False
+
+ shapeB = state.SB[0]
+
+ if len(input_shape) == 3:
+ output_shape = (input_shape[0], input_shape[1], shapeB[0])
+ else:
+ output_shape = (input_shape[0], shapeB[0])
+
+ # 3. Matmul
+ out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
+ output = F.mm_dequant(out32, Sout32, SCA, state.SCB)
+
+ # 4. Mixed-precision decomposition matmul
+ if state.threshold > 0.0 and coo_tensorA is not None and subA is not None:
+ output += torch.matmul(subA, state.subB)
+
+ # 5. Save state
+ ctx.state = state
+
+ ctx.formatB = formatB
+ ctx.grad_shape = input_shape
+ ctx.req_grads = [requires_gradA, requires_gradB]
+
+ if requires_gradA or requires_gradB:
+ ctx.tensors = (CAt, subA)
+ ctx.tensor_states = (SCAt, state.idx)
+ else:
+ ctx.tensors = [None, None]
+ ctx.tensor_states = (None, None)
+ ctx.save_for_backward(None, None)
+
+ #clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
+ clone_func = torch.clone
+ return clone_func(output.view(output_shape))
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ req_gradA, req_gradB = ctx.req_grads
+ CAt, subA = ctx.tensors
+ SCAt, idx = ctx.tensor_states
+ formatB = ctx.formatB
+ state = ctx.state
+ assert state.has_fp16_weights, 'Backprop only supported for fp16 weights.'
+
+ if len(grad_output.shape) == 3:
+ grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()
+
+ grad_A = grad_B = None
+
+ Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
+ if req_gradB:
+ CxAt, SAt = F.transform(CAt, formatB, transpose=True)
+ C32grad, Sgrad = F.transform(Cgradt, 'col32', transpose=True)
+ gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
+ grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
+ if state.threshold > 0.0 and subA is not None:
+ grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
+
+ if req_gradA:
+ C32grad, Sgrad = F.transform(Cgrad, 'col32')
+ if state.CxBt is None:
+ state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
+ gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
+ grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
+
+ return grad_A, grad_B, None, None, None, None, None
+
+
+def matmul(A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0):
+ state = state or MatmulLtState()
+ if threshold > 0.0:
+ state.threshold = threshold
+ return MatMul8bitLt.apply(A, B, out, state)
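+
+# Minimal usage sketch (assumes fp16 CUDA tensors; igemmlt computes A @ B^T,
+# so B is stored as (out_features, in_features)):
+#
+#   A = torch.randn(4, 1024, dtype=torch.float16, device='cuda')
+#   B = torch.randn(2048, 1024, dtype=torch.float16, device='cuda')
+#   out = matmul(A, B, threshold=6.0)   # int8 matmul + fp16 outlier decomposition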
+
diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index 63d627e..2374c35 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -6,6 +6,8 @@ lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')
try:
lib.cadam32bit_g32
+ lib.get_context.restype = ct.c_void_p
+ lib.get_cusparse.restype = ct.c_void_p
COMPILED_WITH_CUDA = True
except AttributeError:
warn("The installed version of bitsandbytes was compiled without GPU support. "
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index ab4e565..806c254 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -36,9 +36,51 @@ if COMPILED_WITH_CUDA:
str2optimizer8bit_blockwise['rmsprop'] = (lib.crmsprop_8bit_blockwise_fp32, lib.crmsprop_8bit_blockwise_fp16)
str2optimizer8bit_blockwise['adagrad'] = (lib.cadagrad_8bit_blockwise_fp32, lib.cadagrad_8bit_blockwise_fp16)
-optimal_normal = [-0.9939730167388916, -0.8727636337280273, -0.8097418546676636, -0.7660024166107178, -0.7318882346153259, -0.6793879270553589, -0.657649040222168, -0.6385974884033203, -0.6211113333702087, -0.5901028513908386, -0.5762918591499329, -0.5630806684494019, -0.5509274005889893, -0.5394591689109802, -0.5283197164535522, -0.517780065536499, -0.5074946284294128, -0.4980469048023224, -0.48867011070251465, -0.48003149032592773, -0.47125306725502014, -0.4629971981048584, -0.4547359049320221, -0.446626216173172, -0.43902668356895447, -0.43158355355262756, -0.4244747757911682, -0.4173796474933624, -0.41038978099823, -0.4055633544921875, -0.4035947024822235, -0.39701032638549805, -0.39057496190071106, -0.38439232110977173, -0.3782760500907898, -0.3721940815448761, -0.3661896586418152, -0.3604033589363098, -0.354605108499527, -0.34892538189888, -0.34320303797721863, -0.3376772701740265, -0.3323028087615967, -0.3269782066345215, -0.32166096568107605, -0.316457599401474, -0.3112771809101105, -0.3061025142669678, -0.30106794834136963, -0.2961243987083435, -0.2912728488445282, -0.28644347190856934, -0.28165507316589355, -0.2769731283187866, -0.2722635865211487, -0.26779335737228394, -0.26314786076545715, -0.2586647868156433, -0.2541804611682892, -0.2496625930070877, -0.24527113139629364, -0.24097171425819397, -0.23659978806972504, -0.23218469321727753, -0.22799566388130188, -0.22380566596984863, -0.21965542435646057, -0.2154538631439209, -0.2113603949546814, -0.20735277235507965, -0.20334717631340027, -0.19932441413402557, -0.19530178606510162, -0.19136647880077362, -0.18736697733402252, -0.18337111175060272, -0.17951400578022003, -0.1757056713104248, -0.17182783782482147, -0.1680615097284317, -0.16431649029254913, -0.16053077578544617, -0.15685945749282837, -0.15298527479171753, -0.1493264138698578, -0.14566898345947266, -0.14188314974308014, -0.13819937407970428, -0.1344561129808426, -0.1306886374950409, -0.1271020770072937, -0.12346585839986801, -0.11981867253780365, -0.11614970862865448, -0.11256207525730133, -0.10889036953449249, -0.10525048524141312, -0.1016591489315033, -0.09824034571647644, -0.09469068050384521, -0.0911419615149498, -0.08773849159479141, -0.08416644483804703, -0.08071305602788925, -0.07720902562141418, -0.07371306419372559, -0.07019119709730148, -0.06673648208379745, -0.06329209357500076, -0.059800852090120316, -0.0564190037548542, -0.05296570807695389, -0.049522045999765396, -0.04609023034572601, -0.04262964054942131, -0.039246633648872375, -0.03577171266078949, -0.03236335143446922, -0.028855687007308006, -0.02542758360505104, -0.022069433704018593, -0.018754752352833748, -0.015386369079351425, -0.01194947212934494, -0.008439815603196621, -0.004995611496269703, -0.0016682245768606663, 0.0, 0.0015510577941313386, 0.005062474869191647, 0.008417150937020779, 0.011741090565919876, 0.015184164978563786, 0.018582714721560478, 0.02204744517803192, 0.025471193715929985, 0.02889077737927437, 0.0323684960603714, 0.03579240292310715, 0.039281025528907776, 0.0427563451230526, 0.04619763046503067, 0.04968220740556717, 0.05326594039797783, 0.05679265409708023, 0.060245808213949203, 0.06372645497322083, 0.06721872836351395, 0.0706876739859581, 0.0742349922657013, 0.07774098962545395, 0.08123527467250824, 0.08468879014253616, 0.08810535818338394, 0.09155989438295364, 0.09498448669910431, 0.0985206812620163, 0.10206405073404312, 0.10563778132200241, 0.10921968519687653, 0.11284469068050385, 0.11653254181146622, 0.12008969485759735, 0.12368203699588776, 0.1272617131471634, 
0.13089501857757568, 0.134552001953125, 0.1382799744606018, 0.14194637537002563, 0.14563234150409698, 0.14930322766304016, 0.15303383767604828, 0.1567956507205963, 0.16050070524215698, 0.16431072354316711, 0.16813558340072632, 0.17204202711582184, 0.1758781224489212, 0.17973239719867706, 0.1836014688014984, 0.18753431737422943, 0.19138391315937042, 0.19535475969314575, 0.19931404292583466, 0.20333819091320038, 0.20738255977630615, 0.21152682602405548, 0.21568812429904938, 0.21978361904621124, 0.22393859922885895, 0.22814159095287323, 0.23241068422794342, 0.23675410449504852, 0.24123944342136383, 0.24569889903068542, 0.2500703036785126, 0.25904011726379395, 0.26349544525146484, 0.2682226300239563, 0.272907555103302, 0.2774306833744049, 0.28220856189727783, 0.2869136929512024, 0.2916390895843506, 0.29649388790130615, 0.30142995715141296, 0.3065022826194763, 0.3114383816719055, 0.31648796796798706, 0.3216581642627716, 0.32700115442276, 0.3322487473487854, 0.33778008818626404, 0.3431521952152252, 0.3487405776977539, 0.3543166518211365, 0.3601346015930176, 0.36605337262153625, 0.37217751145362854, 0.378179669380188, 0.3843980133533478, 0.3906566798686981, 0.39714935421943665, 0.40357843041419983, 0.4104187488555908, 0.4171563684940338, 0.42418959736824036, 0.43136918544769287, 0.4389212429523468, 0.44673123955726624, 0.45457619428634644, 0.4627031683921814, 0.47130417823791504, 0.4798591434955597, 0.48897242546081543, 0.4979848861694336, 0.5, 0.5076631307601929, 0.5177803635597229, 0.5282770991325378, 0.5392990112304688, 0.5506287813186646, 0.5632893443107605, 0.5764452815055847, 0.5903191566467285, 0.6051878333091736, 0.6209936141967773, 0.6382884979248047, 0.6573970913887024, 0.6795773506164551, 0.7037051916122437, 0.7327037453651428, 0.7677436470985413, 0.8111193776130676, 0.875165581703186, 1.0]
-optimal_half_normal = [0.0025565922260284424, 0.005811259150505066, 0.00961565226316452, 0.010822802782058716, 0.013123787939548492, 0.014242202043533325, 0.0143156498670578, 0.016469404101371765, 0.017666727304458618, 0.01773911714553833, 0.0199756920337677, 0.0210941880941391, 0.021161124110221863, 0.02451971173286438, 0.024580076336860657, 0.02685210108757019, 0.028012827038764954, 0.030198264867067337, 0.0302925705909729, 0.03136435151100159, 0.03374280035495758, 0.03487399220466614, 0.035243816673755646, 0.037192340940237045, 0.03822284936904907, 0.04164902865886688, 0.04173608124256134, 0.04401407018303871, 0.04508155584335327, 0.047482021152973175, 0.04756556823849678, 0.050963032990694046, 0.05196474492549896, 0.055417388677597046, 0.05793146416544914, 0.05799369141459465, 0.05887940526008606, 0.05895659327507019, 0.062420234084129333, 0.06493274495005608, 0.06499008461833, 0.06935599446296692, 0.07197384163737297, 0.07201516255736351, 0.07276943325996399, 0.07283210754394531, 0.07550075277686119, 0.07975354790687561, 0.07980883121490479, 0.08257630094885826, 0.0867777168750763, 0.08682405948638916, 0.08967285975813866, 0.09323835000395775, 0.09386616945266724, 0.09735457599163055, 0.09739077091217041, 0.10092401504516602, 0.10444298386573792, 0.10447832942008972, 0.10770941898226738, 0.10803905129432678, 0.11161200702190399, 0.1151546835899353, 0.11520349979400635, 0.11875157058238983, 0.11879390478134155, 0.1222602017223835, 0.122351735830307, 0.12240418791770935, 0.12594850733876228, 0.12597402930259705, 0.12602100148797035, 0.12960633635520935, 0.1296597123146057, 0.12966342642903328, 0.13227657973766327, 0.13325360417366028, 0.1333133578300476, 0.13691483438014984, 0.1371927298605442, 0.14066261053085327, 0.14088113978505135, 0.1447291411459446, 0.14805573225021362, 0.148526418954134, 0.15170684456825256, 0.15178103744983673, 0.15225710347294807, 0.1554398238658905, 0.15609459951519966, 0.15618794038891792, 0.1592724472284317, 0.1629735231399536, 0.16382690146565437, 0.16676269471645355, 0.16873238794505596, 0.17066434025764465, 0.17068277299404144, 0.1717144437134266, 0.17558929696679115, 0.17827065289020538, 0.17835864424705505, 0.18222273886203766, 0.18353315070271492, 0.18604370951652527, 0.18611834943294525, 0.1876586265861988, 0.18996606767177582, 0.19170701876282692, 0.19398853182792664, 0.19786442816257477, 0.19795633852481842, 0.20195159316062927, 0.2058800607919693, 0.2099103182554245, 0.2122517265379429, 0.21410366892814636, 0.21819619834423065, 0.22221362590789795, 0.22233009338378906, 0.22500130906701088, 0.2251257635653019, 0.22638091444969177, 0.23067741096019745, 0.23368822410702705, 0.2348879873752594, 0.2382080741226673, 0.2390350103378296, 0.2391497790813446, 0.24253453686833382, 0.24265171959996223, 0.2470107562839985, 0.24764248728752136, 0.24777774512767792, 0.2516774423420429, 0.256104726344347, 0.2564055472612381, 0.2607169933617115, 0.265461727976799, 0.26985861361026764, 0.2701106257736683, 0.2702729292213917, 0.274574413895607, 0.2750340588390827, 0.27919672429561615, 0.283704474568367, 0.28386808931827545, 0.28953738883137703, 0.2896753139793873, 0.29320384562015533, 0.29451676085591316, 0.295327290892601, 0.29802779853343964, 0.29818175733089447, 0.29972871020436287, 0.30290623009204865, 0.30305664241313934, 0.30486901476979256, 0.31299956142902374, 0.31518544629216194, 0.31790371239185333, 0.3205283172428608, 0.3230419009923935, 0.32595496252179146, 0.32612212374806404, 0.3282426446676254, 0.3283906430006027, 0.33146094158291817, 
0.3316439874470234, 0.33365286886692047, 0.33723779395222664, 0.3390095978975296, 0.3427443392574787, 0.34853987768292427, 0.34869300201535225, 0.35457711294293404, 0.35537679493427277, 0.3604113645851612, 0.36124424636363983, 0.3665340431034565, 0.36667295172810555, 0.3727492541074753, 0.3729033060371876, 0.37888188660144806, 0.37907837703824043, 0.3792510814964771, 0.38557394221425056, 0.38573457673192024, 0.39108292758464813, 0.39911722019314766, 0.40589402988553047, 0.40604450181126595, 0.410498782992363, 0.4106704741716385, 0.4129834659397602, 0.4131447561085224, 0.4172855168581009, 0.4202354736626148, 0.4204071946442127, 0.43538858368992805, 0.4355536885559559, 0.4432900734245777, 0.44603554904460907, 0.4461968094110489, 0.451409537345171, 0.4598204083740711, 0.46002377942204475, 0.46178819239139557, 0.46868549659848213, 0.46995367109775543, 0.4868385046720505, 0.48702501133084297, 0.4958047419786453, 0.4960057884454727, 0.5051481872797012, 0.506847757846117, 0.5148334950208664, 0.5150565356016159, 0.5174009390175343, 0.5249751061201096, 0.5283288545906544, 0.5355450958013535, 0.539984006434679, 0.5467876642942429, 0.5522958822548389, 0.5584012717008591, 0.5706631988286972, 0.5836620181798935, 0.5836880058050156, 0.5942088551819324, 0.5975865572690964, 0.6102624125778675, 0.6124880760908127, 0.6286389082670212, 0.646102175116539, 0.6471664495766163, 0.665437325835228, 0.6687244363129139, 0.687017485499382, 0.6932839937508106, 0.7115348428487778, 0.7218200154602528, 0.7219699807465076, 0.7747527211904526, 0.7749756425619125, 0.8192005604505539, 0.8194110840559006, 0.8830635994672775, 0.9217727445065975, 0.9245667457580566, 0.947742685675621, 0.9674464613199234, 0.9890814647078514, 0.9891453236341476, 0.9925699159502983]
+class CUBLAS_Context(object):
+ _instance = None
+
+ def __init__(self):
+ raise RuntimeError('Call get_instance() instead')
+
+ def initialize(self):
+ self.context = {}
+ #prev_device = torch.cuda.current_device()
+ #for i in range(torch.cuda.device_count()):
+ # torch.cuda.set_device(torch.device('cuda', i))
+ # self.context.append(ct.c_void_p(lib.get_context()))
+ #torch.cuda.set_device(prev_device)
+
+ @classmethod
+ def get_instance(cls):
+ if cls._instance is None:
+ cls._instance = cls.__new__(cls)
+ cls._instance.initialize()
+ return cls._instance
+
+ def get_context(self, device):
+ if device.index not in self.context:
+ prev_device = torch.cuda.current_device()
+ torch.cuda.set_device(device)
+ self.context[device.index] = ct.c_void_p(lib.get_context())
+ torch.cuda.set_device(prev_device)
+ return self.context[device.index]
+
+class Cusparse_Context(object):
+ _instance = None
+
+ def __init__(self):
+ raise RuntimeError('Call get_instance() instead')
+
+ def initialize(self):
+ self.context = ct.c_void_p(lib.get_cusparse())
+
+ @classmethod
+ def get_instance(cls):
+ if cls._instance is None:
+ cls._instance = cls.__new__(cls)
+ cls._instance.initialize()
+ return cls._instance
def create_linear_map(signed=True):
if signed:
@@ -89,6 +131,16 @@ def create_dynamic_map(signed=True, n=7):
data.sort()
return Tensor(data)
+def get_special_format_str():
+ major, minor = torch.cuda.get_device_capability()
+ if major < 7:
+ print(f'Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!')
+ assert major >= 7
+
+ if major == 7: return 'col_turing'
+ elif major == 8: return 'col_ampere'
+ else: return 'col_turing'
+
def get_ptr(A: Tensor) -> ct.c_void_p:
'''
Get the ctypes pointer from a PyTorch Tensor.
@@ -105,6 +157,105 @@ def get_ptr(A: Tensor) -> ct.c_void_p:
if A is None: return None
else: return ct.c_void_p(A.data.storage().data_ptr())
+def pre_call(device):
+ prev_device = torch.cuda.current_device()
+ torch.cuda.set_device(device)
+ return prev_device
+
+def post_call(prev_device):
+ torch.cuda.set_device(prev_device)
+
+def get_transform_func(dtype, orderA, orderOut, transpose=False):
+ name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}'
+ if not hasattr(lib, name):
+ print(name)
+ raise ValueError(f'Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}')
+ else:
+ return getattr(lib, name)
+
+class GlobalData(object):
+ _instance = None
+
+ def __init__(self):
+ raise RuntimeError('Call get_instance() instead')
+
+ def initialize(self):
+ self.data = {}
+
+ @classmethod
+ def get_instance(cls):
+ if cls._instance is None:
+ cls._instance = cls.__new__(cls)
+ cls._instance.initialize()
+ return cls._instance
+
+
+def get_transform_buffer(shape, dtype, device, to_order, from_order='row', transpose=False):
+ #init_func = torch.empty
+ init_func = torch.zeros
+ dims = len(shape)
+
+ if dims == 2:
+ rows = shape[0]
+ elif dims == 3:
+ rows = shape[0]*shape[1]
+ cols = shape[-1]
+
+ state = (shape, to_order)
+ if transpose:
+ # swap dims
+ tmp = rows
+ rows = cols
+ cols = tmp
+ state = (shape[::-1], to_order)
+
+ if to_order == 'row' or to_order == 'col':
+ return init_func(shape, dtype=dtype, device=device), state
+ elif to_order == 'col32':
+ # blocks of 32 columns (padded)
+ cols = 32*((cols+31)//32)
+ return init_func((rows, cols), dtype=dtype, device=device), state
+ elif to_order == 'col_turing':
+ # blocks of 32 columns and 8 rows
+ cols = 32*((cols+31)//32)
+ rows = 8*((rows+7)//8)
+ return init_func((rows, cols), dtype=dtype, device=device), state
+ elif to_order == 'col_ampere':
+ # blocks of 32 columns and 32 rows
+ cols = 32*((cols+31)//32)
+ rows = 32*((rows+31)//32)
+ return init_func((rows, cols), dtype=dtype, device=device), state
+ else:
+ raise NotImplementedError(f'To_order not supported: {to_order}')
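+
+# Worked example (illustrative): for a (3, 100) int8 tensor, 'col32' pads
+# columns to a multiple of 32, i.e. 32*((100+31)//32) = 128, giving a (3, 128)
+# buffer; 'col_turing' additionally pads rows to a multiple of 8, giving (8, 128).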
+
+def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
+ if state is None: state = (A.shape, from_order)
+ else: from_order = state[1]
+ if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1])
+ else: new_state = (state[0], to_order) # (shape, order)
+ func = get_transform_func(A.dtype, from_order, to_order, transpose)
+
+ shape = state[0]
+ if len(shape) == 2:
+ dim1 = ct.c_int32(shape[0])
+ dim2 = ct.c_int32(shape[1])
+ elif ld is not None:
+ n = math.prod(shape)
+ dim1 = math.prod([shape[i] for i in ld])
+ dim2 = ct.c_int32(n//dim1)
+ dim1 = ct.c_int32(dim1)
+ else:
+ dim1 = ct.c_int32(shape[0]*shape[1])
+ dim2 = ct.c_int32(shape[2])
+
+ ptr = CUBLAS_Context.get_instance().get_context(A.device)
+ ptrA = get_ptr(A)
+ ptrOut = get_ptr(out)
+ func(ptr, ptrA, ptrOut, dim1, dim2)
+
+ return out, new_state
+
def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tensor:
'''
Estimates 256 equidistant quantiles on the input tensor eCDF.
@@ -544,3 +695,717 @@ def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor,
maxdim1 = ct.c_int32(histogram.shape[0])
n = ct.c_int32(index1.numel())
lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)
+
+def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
+ if not torch.cuda.is_initialized(): torch.cuda.init()
+ if A.dtype != expected_type or B.dtype != expected_type:
+ raise TypeError(f'Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}')
+
+ sA = A.shape
+ sB = B.shape
+ tA = transposed_A
+ tB = transposed_B
+
+ correct = True
+
+ if len(sA) == 2 and len(sB) == 2:
+ if not tA and not tB and A.shape[1] != B.shape[0]: correct = False
+ elif tA and not tB and A.shape[0] != B.shape[0]: correct = False
+ elif tA and tB and A.shape[0] != B.shape[1]: correct = False
+ elif not tA and tB and A.shape[1] != B.shape[1]: correct = False
+ elif len(sA) == 3 and len(sB) == 2:
+ if not tA and not tB and A.shape[2] != B.shape[0]: correct = False
+ elif tA and not tB and A.shape[1] != B.shape[0]: correct = False
+ elif tA and tB and A.shape[1] != B.shape[1]: correct = False
+ elif not tA and tB and A.shape[2] != B.shape[1]: correct = False
+ elif len(sA) == 3 and len(sB) == 3:
+ if not tA and not tB and A.shape[2] != B.shape[1]: correct = False
+ elif tA and not tB and A.shape[1] != B.shape[1]: correct = False
+ elif tA and tB and A.shape[1] != B.shape[2]: correct = False
+ elif not tA and tB and A.shape[2] != B.shape[2]: correct = False
+
+ if out is not None:
+ sout = out.shape
+ # special case common in backprop
+ if not correct and len(sA) == 3 and len(sB) == 3:
+ if (sout[0] == sA[2] and sout[1] == sB[2] and
+ sA[0] == sB[0] and sA[1] == sB[1]):
+ correct = True
+ else:
+ if len(sA) == 2 and len(sB) == 2:
+ if not tA and not tB: sout = (sA[0], sB[1])
+ elif tA and tB: sout = (sA[1], sB[0])
+ elif tA and not tB: sout = (sA[1], sB[1])
+ elif not tA and tB: sout = (sA[0], sB[0])
+ elif len(sA) == 3 and len(sB) == 2:
+ if not tA and not tB: sout = (sA[0], sA[1], sB[1])
+ elif tA and tB: sout = (sA[0], sA[2], sB[0])
+ elif tA and not tB: sout = (sA[0], sA[2], sB[1])
+ elif not tA and tB: sout = (sA[0], sA[1], sB[0])
+ elif len(sA) == 3 and len(sB) == 3:
+ if not tA and not tB: sout = (sA[0], sA[1], sB[2])
+ elif tA and tB: sout = (sA[0], sA[2], sB[1])
+ elif tA and not tB: sout = (sA[0], sA[2], sB[2])
+ elif not tA and tB: sout = (sA[0], sA[1], sB[1])
+
+
+ if not correct:
+ raise ValueError(f'Tensor dimensions incorrect for matrix multiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}.')
+
+ return sout
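+
+# Example (illustrative): check_matmul(A, B, None, False, False) with A of shape
+# (4, 8) and B of shape (8, 16) returns sout == (4, 16); with transposed_B=True,
+# B must be (16, 8) and sout is again (4, 16).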
+
+def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed_B=False):
+ sout = check_matmul(A, B, out, transposed_A, transposed_B)
+ if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
+ if len(A.shape) == 3 and len(B.shape) == 3:
+ if A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1]:
+ return batched_igemm(A, B, out)
+
+ sA = A.shape
+ sB = B.shape
+ if transposed_A and len(sA) == 2: sA = (sA[1], sA[0])
+ elif transposed_A and len(sA) == 3: sA = (sA[0], sA[2], sA[1])
+ if transposed_B and len(sB) == 2: sB = (sB[1], sB[0])
+ elif transposed_B and len(sB) == 3: sB = (sB[0], sB[2], sB[1])
+ # this is a mess: cuBLAS expects column-major, but PyTorch is row-major.
+ # So to perform the matrix multiplication, we treat A, B, and C as their
+ # transposes (the transpose of a row-major matrix is its column-major view).
+ # This means we compute B^T A^T = C^T and explicitly swap the dimensions of
+ # these matrices in the arguments we pass to cuBLAS.
+ # column major: A @ B = C: [m, k] @ [k, n] = [m, n]
+ # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
+ # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
+ if len(sB) == 2:
+ if B.stride()[0] == B.shape[1]: transposed_B = False
+ elif B.stride()[1] == B.shape[0]: transposed_B = True
+ if len(A.shape) == 2:
+ if A.stride()[0] == A.shape[1]: transposed_A = False
+ elif A.stride()[1] == A.shape[0]: transposed_A = True
+ else:
+ if A.stride()[1] == A.shape[2]: transposed_A = False
+ elif A.stride()[2] == A.shape[1]: transposed_A = True
+
+ if len(sA) == 2:
+ n = sA[0]
+ ldb = A.stride()[1 if transposed_A else 0]
+ elif len(sA) == 3 and len(sB) == 2:
+ n = sA[0]*sA[1]
+ ldb = sA[2]
+
+
+ m = sB[1]
+ k = sB[0]
+ lda = B.stride()[(1 if transposed_B else 0)]
+ ldc = sB[1]
+ elif len(sB) == 3:
+ # special case
+ assert len(sA) == 3
+ if not (sA[0] == sB[0] and sA[1] == sB[1]):
+ raise ValueError(f'Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}')
+
+ transposed_A = True
+ transposed_B = False
+
+ m = sB[2]
+ n = sA[2]
+ k = sB[0]*sB[1]
+
+ lda = m
+ ldb = sA[2]
+ ldc = m
+
+
+ ptr = CUBLAS_Context.get_instance().get_context(A.device)
+
+ # B^T @ A^T = C^T
+ # [km, nk -> mn]
+ lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
+ get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc))
+ return out
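+
+# Minimal usage sketch (assumes int8 CUDA tensors):
+#
+#   A = torch.randint(-128, 127, (4, 8), dtype=torch.int8, device='cuda')
+#   B = torch.randint(-128, 127, (8, 16), dtype=torch.int8, device='cuda')
+#   C = igemm(A, B)   # int32 accumulator output of shape (4, 16)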
+
+
+def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed_B=False):
+ if not len(A.shape) == 3 or not len(B.shape) == 3:
+ raise ValueError(f'Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}')
+ sout = check_matmul(A, B, out, transposed_A, transposed_B)
+ if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
+
+ if B.is_contiguous():
+ lda = B.stride()[1]
+ transposed_A = False
+ else:
+ s = B.stride()
+ if s[0] != B.shape[0]:
+ B = B.contiguous()
+ lda = B.stride()[1]
+ elif s[2] == B.shape[1]:
+ transposed_A = True
+ lda = B.stride()[2]
+ else:
+ if s[2] == 1:
+ B = B.contiguous()
+ lda = B.stride()[1]
+ elif s[1] == 1:
+ B = B.contiguous()
+ lda = B.stride()[1]
+ else:
+ B = B.contiguous()
+ lda = B.stride()[1]
+
+ if A.is_contiguous():
+ ldb = A.stride()[1]
+ transposed_B = False
+ else:
+ s = A.stride()
+ if s[0] != A.shape[0]:
+ A = A.contiguous()
+ ldb = A.stride()[1]
+ transposed_B = False
+ elif s[2] == A.shape[1]:
+ ldb = A.stride()[2]
+ transposed_B = True
+ else:
+ A = A.contiguous()
+ ldb = A.stride()[1]
+ transposed_B = False
+
+ # this is a mess: cuBLAS expects column-major, but PyTorch is row-major.
+ # So to perform the matrix multiplication, we treat A, B, and C as their
+ # transposes (the transpose of a row-major matrix is its column-major view).
+ # This means we compute B^T A^T = C^T and explicitly swap the dimensions of
+ # these matrices in the arguments we pass to cuBLAS.
+
+ # column major: A @ B = C: [batch, m, k] @ [batch, k, n] = [batch, m, n]
+ # row major: B^T @ A^T = C^T: [batch, m, k] @ [batch, k, n] = [batch, m, n]
+ # column major with row major layout: B^T @ A^T = C^T: [batch, k, m] @ [batch, n, k] = [batch, n, m]
+ num_batch = A.shape[0]
+ n = A.shape[1]
+ m = B.shape[2]
+ k = B.shape[1]
+
+ ldc = m
+
+ strideA = B.shape[1]*B.shape[2]
+ strideB = A.shape[1]*A.shape[2]
+ strideC = A.shape[1]*B.shape[2]
+
+ ptr = CUBLAS_Context.get_instance().get_context(A.device)
+
+ lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
+ get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc),
+ ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
+ return out
+
+def igemmlt(A, B, SA, SB, out=None, Sout=None, row_scale=None, dtype=torch.int32):
+ shapeA = SA[0]
+ shapeB = SB[0]
+ dimsA = len(shapeA)
+ dimsB = len(shapeB)
+ if dimsA == 2:
+ m = shapeA[0]
+ elif dimsA == 3:
+ m = shapeA[0]*shapeA[1]
+
+ if dimsB == 2:
+ rows = n = shapeB[0]
+ elif dimsB == 3:
+ rows = n = shapeB[0]*shapeB[1]
+
+ if dimsA == 2 and out is None:
+ out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, 'col32', 'row')
+ elif dimsA == 3 and out is None:
+ out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, 'col32', 'row')
+
+ if row_scale is not None: assert row_scale.numel() == out.shape[0]
+ assert dimsB != 3, 'len(B.shape)==3 not supported'
+ assert A.device.type == 'cuda'
+ assert B.device.type == 'cuda'
+ assert A.dtype == torch.int8
+ assert B.dtype == torch.int8
+ assert out.dtype == dtype
+ assert SA[1] == 'col32'
+ assert SB[1] in ['col_turing', 'col_ampere']
+ assert Sout[1] == 'col32'
+ assert shapeA[-1] == shapeB[-1], f'Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}'
+ formatB = SB[1]
+ prev_device = A.device
+ torch.cuda.set_device(A.device)
+
+ ptr = CUBLAS_Context.get_instance().get_context(A.device)
+ ptrA = get_ptr(A)
+ ptrB = get_ptr(B)
+ ptrC = get_ptr(out)
+ ptrRowScale = get_ptr(row_scale)
+
+ k = shapeA[-1]
+ lda = ct.c_int32(m*32)
+ if formatB == 'col_turing':
+ # turing: tiles with rows filled up to multiple of 8 rows by 32 columns
+ # n = rows
+ ldb = ct.c_int32(((rows+7)//8)*8*32)
+ else:
+ # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
+ # n = rows
+ ldb = ct.c_int32(((rows+31)//32)*32*32)
+
+ ldc = ct.c_int32(m*32)
+ m = ct.c_int32(m)
+ n = ct.c_int32(n)
+ k = ct.c_int32(k)
+
+ has_error = 0
+ if formatB == 'col_turing':
+ if dtype == torch.int32:
+ has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+ elif row_scale is None:
+ has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+ else:
+ has_error = lib.cigemmlt_turing_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+ elif formatB == 'col_ampere':
+ if dtype == torch.int32:
+ has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+ elif row_scale is None:
+ has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+ else:
+ has_error = lib.cigemmlt_ampere_8_rowscale(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
+
+ if has_error == 1:
+ raise Exception('cublasLt ran into an error!')
+
+ torch.cuda.set_device(prev_device)
+
+
+ return out, Sout
+
+
+def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None):
+ assert A.dtype == torch.int32
+ out_shape = quant_state[0]
+ if len(out_shape) == 3: out_shape = (out_shape[0]*out_shape[1], out_shape[2])
+
+ if out is None: out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
+ if new_row_stats is None: new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
+ if new_col_stats is None: new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
+ assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}"
+ assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}"
+
+ ptrA = get_ptr(A)
+ ptrOut = get_ptr(out)
+ ptrRowStats = get_ptr(row_stats)
+ ptrColStats = get_ptr(col_stats)
+ ptrNewRowStats = get_ptr(new_row_stats)
+ ptrNewColStats = get_ptr(new_col_stats)
+ numRows = ct.c_int32(out_shape[0])
+ numCols = ct.c_int32(out_shape[1])
+
+ lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols)
+
+ return out
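+
+# End-to-end sketch of the int8 pipeline these functions implement (mirrors
+# MatMul8bitLt.forward; assumes fp16 CUDA tensors A and B on a tensor-core GPU):
+#
+#   CA, CAt, SCA, SCAt, _ = double_quant(A)              # int8 + absmax stats
+#   CB, CBt, SCB, SCBt, _ = double_quant(B)
+#   C32A, SA = transform(CA, 'col32')
+#   CxB, SB = transform(CB, to_order=get_special_format_str())
+#   out32, Sout32 = igemmlt(C32A, CxB, SA, SB)           # int8 A @ B^T -> int32
+#   out = mm_dequant(out32, Sout32, SCA, SCB)            # back to fp16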
+
+
+def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0):
+ assert A.dtype == torch.float16
+ device = A.device
+
+ cols = A.shape[-1]
+ if len(A.shape) == 3:
+ rows = A.shape[0]*A.shape[1]
+ else:
+ rows = A.shape[0]
+
+ col_tiles = (cols+255)//256
+ tiled_rows = ((rows+15)//16)*16
+ if row_stats is None: row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0)
+ if col_stats is None: col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0)
+
+ if nnz_block_ptr is None and threshold > 0.0: nnz_block_ptr = torch.zeros(((tiled_rows*col_tiles)+1,), dtype=torch.int32, device=device)
+
+ ptrA = get_ptr(A)
+ ptrRowStats = get_ptr(row_stats)
+ ptrColStats = get_ptr(col_stats)
+ ptrNnzrows = get_ptr(nnz_block_ptr)
+ rows = ct.c_int32(rows)
+ cols = ct.c_int32(cols)
+
+ prev_device = pre_call(A.device)
+ lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols)
+ post_call(prev_device)
+
+
+ if threshold > 0.0:
+ nnz_block_ptr.cumsum_(0)
+
+
+ return row_stats, col_stats, nnz_block_ptr
+
+class COOSparseTensor(object):
+ def __init__(self, rows, cols, nnz, rowidx, colidx, values):
+ assert rowidx.dtype == torch.int32
+ assert colidx.dtype == torch.int32
+ assert values.dtype == torch.float16
+ assert values.numel() == nnz
+ assert rowidx.numel() == nnz
+ assert colidx.numel() == nnz
+
+ self.rows = rows
+ self.cols = cols
+ self.nnz = nnz
+ self.rowidx = rowidx
+ self.colidx = colidx
+ self.values = values
+
+class CSRSparseTensor(object):
+ def __init__(self, rows, cols, nnz, rowptr, colidx, values):
+ assert rowptr.dtype == torch.int32
+ assert colidx.dtype == torch.int32
+ assert values.dtype == torch.float16
+ assert values.numel() == nnz
+ assert colidx.numel() == nnz
+ assert rowptr.numel() == rows+1
+
+ self.rows = rows
+ self.cols = cols
+ self.nnz = nnz
+ self.rowptr = rowptr
+ self.colidx = colidx
+ self.values = values
+
+class CSCSparseTensor(object):
+ def __init__(self, rows, cols, nnz, colptr, rowidx, values):
+ assert colptr.dtype == torch.int32
+ assert rowidx.dtype == torch.int32
+ assert values.dtype == torch.float16
+ assert values.numel() == nnz
+ assert rowidx.numel() == nnz
+ assert colptr.numel() == cols+1
+
+ self.rows = rows
+ self.cols = cols
+ self.nnz = nnz
+ self.colptr = colptr
+ self.rowidx = rowidx
+ self.values = values
+
+def coo2csr(cooA):
+ values, counts = torch.unique(cooA.rowidx, return_counts=True)
+ values.add_(1)
+ rowptr = torch.zeros((cooA.rows+1, ), dtype=torch.int32, device=cooA.rowidx.device)
+ rowptr.scatter_(index=values.long(), src=counts.int(), dim=0)
+ rowptr.cumsum_(0)
+ return CSRSparseTensor(cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values)
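+
+# Worked example (illustrative): rowidx=[0, 0, 2] over 3 rows gives
+# counts=[2, 1] at rows [0, 2]; after the shifted scatter and cumsum,
+# rowptr == [0, 2, 2, 3].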
+
+def coo2csc(cooA):
+ val, col2rowidx = torch.sort(cooA.colidx)
+ rowidx = cooA.rowidx[col2rowidx]
+ values = cooA.values[col2rowidx]
+ colvalues, counts = torch.unique(val, return_counts=True)
+ colvalues.add_(1)
+ colptr = torch.zeros((cooA.cols+1, ), dtype=torch.int32, device=cooA.colidx.device)
+ colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0)
+ colptr.cumsum_(0)
+ return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values)
+
+def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
+ rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
+ colidx = torch.zeros((nnz,), dtype=torch.int32, device=device)
+ values = torch.zeros((nnz,), dtype=dtype, device=device)
+ return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)
+
+
+def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
+ device = A.device
+ assert A.dtype == torch.half
+ assert device.type == 'cuda'
+ prev_device = pre_call(A.device)
+
+ cols = A.shape[-1]
+ if len(A.shape) == 3:
+ rows = A.shape[0]*A.shape[1]
+ else:
+ rows = A.shape[0]
+
+ if row_stats is None or col_stats is None:
+ row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)
+
+ if out_col is None: out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
+ if out_row is None: out_row = torch.zeros(A.shape, device=device, dtype=torch.int8)
+
+ coo_tensor = None
+ ptrA = get_ptr(A)
+ ptrColStats = get_ptr(col_stats)
+ ptrRowStats = get_ptr(row_stats)
+ ptrOutCol = get_ptr(out_col)
+ ptrOutRow = get_ptr(out_row)
+
+ if threshold > 0.0:
+ nnz = nnz_row_ptr[-1].item()
+ if nnz > 0:
+ coo_tensor = coo_zeros(rows, cols, nnz, device)
+ ptrRowIdx = get_ptr(coo_tensor.rowidx)
+ ptrColIdx = get_ptr(coo_tensor.colidx)
+ ptrVal = get_ptr(coo_tensor.values)
+ ptrRowPtr = get_ptr(nnz_row_ptr)
+
+ lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, ptrRowIdx, ptrColIdx, ptrVal, ptrRowPtr, ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols))
+ val, idx = torch.sort(coo_tensor.rowidx)
+ coo_tensor.rowidx = val
+ coo_tensor.colidx = coo_tensor.colidx[idx]
+ coo_tensor.values = coo_tensor.values[idx]
+ else:
+ lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, None, None, None, None, ct.c_float(0.0), ct.c_int32(rows), ct.c_int32(cols))
+ else:
+ lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, None, None, None, None, ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols))
+ post_call(prev_device)
+
+ return out_row, out_col, row_stats, col_stats, coo_tensor
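+
+# Minimal usage sketch (assumes an fp16 CUDA tensor):
+#
+#   A = torch.randn(4, 1024, dtype=torch.float16, device='cuda')
+#   CA, CAt, SCA, SCAt, coo = double_quant(A, threshold=6.0)
+#   # CA/CAt: row-/column-wise int8 quantizations; SCA/SCAt: absmax statistics;
+#   # coo holds entries with |value| > threshold (None when there are none)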
+
+
+def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
+ if state is None: state = (A.shape, from_order)
+ else: from_order = state[1]
+ if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
+ else: new_state = (state[0], to_order) # (shape, order)
+
+ shape = state[0]
+ if len(shape) == 2:
+ dim1 = ct.c_int32(shape[0])
+ dim2 = ct.c_int32(shape[1])
+ else:
+ dim1 = ct.c_int32(shape[0]*shape[1])
+ dim2 = ct.c_int32(shape[2])
+
+ ptrA = get_ptr(A)
+ ptrOut = get_ptr(out)
+ if to_order == 'col32':
+ if transpose:
+ lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
+ else:
+ lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2)
+ elif to_order == 'col_turing':
+ if transpose:
+ lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2)
+ else:
+ lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2)
+ elif to_order == 'col_ampere':
+ if transpose:
+ lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2)
+ else:
+ lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2)
+ elif to_order == 'row':
+ if from_order == 'col_turing':
+ lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2)
+ elif from_order == 'col_ampere':
+ lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
+ else:
+ raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')
+
+ return out, new_state
+
+def spmm_coo(cooA, B, out=None):
+ if out is None: out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype)
+ nnz = cooA.nnz
+ assert cooA.rowidx.numel() == nnz
+ assert cooA.colidx.numel() == nnz
+ assert cooA.values.numel() == nnz
+ assert cooA.cols == B.shape[0]
+
+ transposed_B = not B.is_contiguous()
+
+ ldb = B.stride()[(1 if transposed_B else 0)]
+ ldc = B.shape[1]
+
+ ptr = Cusparse_Context.get_instance().context
+
+ ptrRowidx = get_ptr(cooA.rowidx)
+ ptrColidx = get_ptr(cooA.colidx)
+ ptrValues = get_ptr(cooA.values)
+ ptrB = get_ptr(B)
+ ptrC = get_ptr(out)
+ cnnz = ct.c_int32(cooA.nnz)
+ crowsA = ct.c_int32(cooA.rows)
+ ccolsA = ct.c_int32(cooA.cols)
+ ccolsB = ct.c_int32(B.shape[1])
+ cldb = ct.c_int32(ldb)
+ cldc = ct.c_int32(ldc)
+
+ lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B))
+
+ return out
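+
+# Minimal usage sketch (assumes cooA came from double_quant's outlier
+# extraction and B is an fp16 CUDA tensor with cooA.cols == B.shape[0]):
+#
+#   out = spmm_coo(cooA, B)   # dense (cooA.rows, B.shape[1]) result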
+
+def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
+ if out is None: out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype)
+ nnz = cooA.nnz
+ assert cooA.rowidx.numel() == nnz
+ assert cooA.colidx.numel() == nnz
+ assert cooA.values.numel() == nnz
+ assert cooA.cols == B.shape[0], f'{cooA.cols} vs {B.shape}'
+
+ transposed_B = not B.is_contiguous()
+
+ ldb = B.stride()[(1 if transposed_B else 0)]
+ ldc = B.shape[1]
+
+ values, counts = torch.unique(cooA.rowidx, return_counts=True)
+ offset = counts.cumsum(0).int()
+ max_count, max_idx = torch.sort(counts, descending=True)
+ max_idx = max_idx.int()
+ max_count = max_count.int()
+ assert max_count[0] <= 32, f'Current max count per row is 32 but found {max_count[0]}.'
+ assert B.dtype in [torch.float16, torch.int8]
+ ptrOffset = get_ptr(offset)
+ ptrMaxCount = get_ptr(max_count)
+ ptrMaxIdx = get_ptr(max_idx)
+
+ ptrRowidx = get_ptr(cooA.rowidx)
+ ptrColidx = get_ptr(cooA.colidx)
+ ptrValues = get_ptr(cooA.values)
+ ptrB = get_ptr(B)
+ ptrC = get_ptr(out)
+ ptrDequantStats = get_ptr(dequant_stats)
+ cnnz_rows = ct.c_int32(counts.numel())
+ cnnz = ct.c_int32(cooA.nnz)
+ crowsA = ct.c_int32(cooA.rows)
+ ccolsA = ct.c_int32(cooA.cols)
+ crowsB = ct.c_int32(B.shape[1])
+ ccolsB = ct.c_int32(B.shape[1])
+ cldb = ct.c_int32(ldb)
+ cldc = ct.c_int32(ldc)
+ #print(cooA.rowidx[:64])
+ #print(cooA.colidx[:64].sort()[0])
+
+ if B.dtype == torch.float16:
+ lib.cspmm_coo_very_sparse_naive_fp16(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB)
+ elif B.dtype == torch.int8:
+ lib.cspmm_coo_very_sparse_naive_int8(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB)
+ #else: assertion error
+
+ return out
+
+
+C = 127.0
+
+def vectorwise_quant(x, dim=1, quant_type='vector'):
+ if quant_type == 'linear':
+ max1 = torch.abs(x).max().float()
+ xq = torch.round(x/max1*127).to(torch.int8)
+ return xq, max1
+ elif quant_type in ['vector', 'row']:
+ max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
+ xq = torch.round(x*(C/max1)).to(torch.int8)
+ return xq, max1
+ elif quant_type == 'zeropoint':
+ dtype = x.dtype
+ x = x.float()
+ dyna = x.max() - x.min()
+ if dyna == 0: dyna = 1
+ qx = 255./dyna
+ minx = x.min()
+ zpx = torch.round(minx* qx)
+ x = torch.round(qx*x - zpx) + zpx
+ return x, qx
+ elif quant_type in ['vector-zeropoint', 'row-zeropoint']:
+ dtype = x.dtype
+ x = x.float()
+ dyna = (torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True))
+ dyna[dyna==0] = 1
+ qx = 255./dyna
+ minx = torch.amin(x, dim=dim, keepdim=True)
+ zpx = torch.round(minx* qx)
+ x = torch.round(qx*x - zpx) + zpx
+ return x, qx
+ elif quant_type == 'truncated-vector':
+ with torch.no_grad():
+ absx = torch.abs(x)
+ max1 = torch.amax(absx, dim=dim, keepdim=True)
+ max1 = max1*0.7
+ idx = (absx > max1.expand_as(absx))
+ sign = torch.sign(x[idx])
+ x[idx] = max1.expand_as(absx)[idx]*sign
+ xq = torch.round(x/max1*C).to(torch.int8)
+ return xq, max1
+ else: return None
+
+def vectorwise_dequant(xq, max1, quant_type='vector'):
+ if quant_type == 'vector':
+ x = (xq/C*max1).to(torch.float32)
+ return x
+ else: return None
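+
+# Round-trip sketch for the 'vector' scheme (illustrative):
+#
+#   xq, max1 = vectorwise_quant(x, dim=1, quant_type='vector')  # int8, per-row absmax
+#   x_approx = vectorwise_dequant(xq, max1)                     # fp32, x_approx ~ x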
+
+def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type='vector'):
+ if quant_type == 'linear':
+ norm = S1*S2/(C*C)
+ # double cast needed to prevent overflows
+ return (xq.float()*norm).to(dtype)
+ elif quant_type == 'zeropoint':
+ norm = 1.0/(S1*S2)
+ return (xq.float()*norm).to(dtype)
+ elif quant_type == 'row-zeropoint':
+ norm = 1.0/(S1*S2)
+ x = xq.float()
+ if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
+ if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
+ x *= norm
+ return x.to(dtype)
+ elif quant_type == 'vector-zeropoint':
+ x = xq.float()
+ if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
+ if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
+ x *= 1.0/S1
+ x *= 1.0/S2.t()
+ return x.to(dtype)
+ elif quant_type == 'row':
+ x = xq.float()
+ if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
+ if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
+ x *= S1*S2/(C*C)
+ return x.to(dtype)
+ elif quant_type in ['truncated-vector', 'vector']:
+ x = xq.float()
+ if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0)
+ if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0)
+ x *= S1/C
+ x *= S2/C
+ return x.to(dtype)
+ else: return None
+
+
+def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
+ offset = B.float().t().sum(0)*(SA[0]+SA[1])
+ x = xq.float()
+ if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0)
+ if len(SB.shape) == 2:
+ x *= SB.t()/127
+ else:
+ x *= SB/127
+ x *= SA[1]/127
+ x += offset
+ return x.to(dtype)
diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py
index 27ad6ca..03b4655 100644
--- a/bitsandbytes/nn/__init__.py
+++ b/bitsandbytes/nn/__init__.py
@@ -2,4 +2,4 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from .modules import StableEmbedding, Embedding
+from .modules import StableEmbedding, Linear8bit, Linear8bitLt, Int8Params
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index c5460fb..5013d0b 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -3,14 +3,19 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
+import bitsandbytes as bnb
-from typing import Optional
+from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict
-from torch import Tensor
+from torch import Tensor, device, dtype
+from torch import nn
+from torch.nn.parameter import Parameter
import torch.nn.functional as F
+import torch.distributed as dist
from bitsandbytes.optim import GlobalOptimManager
+T = TypeVar('T', bound='torch.nn.Module')
+
class StableEmbedding(torch.nn.Embedding):
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
@@ -70,3 +75,118 @@ class Embedding(torch.nn.Embedding):
self.norm_type, self.scale_grad_by_freq, self.sparse)
return emb
+
+class Int8Params(torch.nn.Parameter):
+ def __new__(cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None):
+ if data is None:
+ data = torch.empty(0)
+ obj = torch.Tensor._make_subclass(cls, data, requires_grad)
+ # keep the quantization state on the instance, not the class,
+ # so separate parameters do not share CB/SCB
+ obj.has_fp16_weights = has_fp16_weights
+ obj.CB = CB
+ obj.SCB = SCB
+ return obj
+
+ def cuda(self, device):
+ if self.has_fp16_weights:
+ return super().cuda(device)
+ else:
+ # we store the 8-bit row-major weight
+ # we convert this weight to the turing/ampere format during the first inference pass
+ B = self.data.contiguous().half().cuda(device)
+ CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
+ del CBt
+ del SCBt
+ self.data = CB
+ self.CB = CB
+ self.SCB = SCB
+
+ return self
+
+ @overload
+ def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ...,
+ non_blocking: bool = ...) -> T:
+ ...
+
+ @overload
+ def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
+ ...
+
+ @overload
+ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
+ ...
+
+ def to(self, *args, **kwargs):
+ device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
+
+ if device is not None and device.type == 'cuda' and self.data.device.type == 'cpu': return self.cuda(device)
+ else:
+ new_param = Int8Params(super().to(device=device, dtype=dtype, non_blocking=non_blocking), requires_grad=self.requires_grad, has_fp16_weights=self.has_fp16_weights)
+ new_param.CB = self.CB
+ new_param.SCB = self.SCB
+
+ return new_param
+
+
+
+class Linear8bitLt(nn.Linear):
+ def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, threshold=0.0, index=None):
+ super(Linear8bitLt, self).__init__(input_features, output_features, bias)
+ self.state = bnb.MatmulLtState()
+ self.index = index
+
+ self.state.threshold = threshold
+ self.state.has_fp16_weights = has_fp16_weights
+ if threshold > 0.0 and not has_fp16_weights:
+ self.state.use_pool = True
+
+ self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
+
+ def init_8bit_state(self):
+ self.state.CB = self.weight.CB
+ self.state.SCB = self.weight.SCB
+ self.weight.CB = None
+ self.weight.SCB = None
+
+ def forward(self, x):
+ self.state.is_training = self.training
+
+ if self.weight.CB is not None: self.init_8bit_state()
+ #assert not self.state.has_fp16_weights
+ #if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None
+
+ out = bnb.matmul(x, self.weight, state=self.state)
+
+ if self.bias is not None:
+ out += self.bias.unsqueeze(0).expand_as(out)
+
+ if not self.state.has_fp16_weights and self.state.CB is not None:
+ # we converted 8-bit row major to turing/ampere format in the first inference pass
+ # we no longer need the row-major weight
+ del self.state.CB
+ self.weight.data = self.state.CxB
+
+ return out
+
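+# Minimal usage sketch (assumes a CUDA device with tensor cores and fp16 inputs;
+# the module is cast to half before moving to CUDA so the bias matches the output dtype):
+#
+#   layer = Linear8bitLt(1024, 2048, has_fp16_weights=False, threshold=6.0).half().cuda()
+#   x = torch.randn(4, 1024, dtype=torch.float16, device='cuda')
+#   y = layer(x)   # int8 matmul with fp16 mixed-precision decomposition
+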
+class Linear8bit(nn.Linear):
+ def __init__(self, input_features, output_features, bias=True, quant_type='vector', index=None, args=None, sparse_decomp=False):
+ super(Linear8bit, self).__init__(input_features, output_features, bias)
+ self.quant_type = quant_type
+ self.index = index
+ self.args = args
+ self.iter = 0
+
+ def forward(self, x):
+ self.iter += 1
+ if self.args is not None and self.iter % self.args.clip_freq == 0:
+ with torch.no_grad():
+ maxval, maxidx = torch.topk(torch.abs(self.weight.flatten()), k=self.args.clip_idx)
+ if not dist.is_initialized() or dist.get_rank() == 0:
+ print('clip', maxval[-1].item())
+ self.weight.clip_(-maxval[-1], maxval[-1])
+
+
+ if self.args is not None:
+ out = bnb.nn.functional.sparse_decomposed_linear8bit(x, self.weight, self.bias, qval=self.args.sparse_decomp_val, quant_type=self.args.quant_type)
+ else:
+ out = bnb.nn.functional.linear8bit(x, self.weight, self.bias, quant_type=self.quant_type)
+
+ return out