From bfa0e33294f2b1dc25e65a33be2397f989824298 Mon Sep 17 00:00:00 2001 From: Titus von Koeller Date: Mon, 1 Aug 2022 03:31:48 -0700 Subject: ran black and isort for coherent code formatting --- bitsandbytes/autograd/_functions.py | 134 +++++++++++++++++++++++------------- 1 file changed, 88 insertions(+), 46 deletions(-) (limited to 'bitsandbytes/autograd') diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index e641583..a08b560 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -1,21 +1,24 @@ +from dataclasses import dataclass + import torch + import bitsandbytes as bnb import bitsandbytes.functional as F -from dataclasses import dataclass - tensor = torch.Tensor -''' +""" This class pools outlier dimensions across layers. This is particularly important for small models where outlier features are less systematic and occur with low frequency. -''' +""" + + class GlobalOutlierPooler(object): _instance = None def __init__(self): - raise RuntimeError('Call get_instance() instead') + raise RuntimeError("Call get_instance() instead") def initialize(self): self.outliers = set() @@ -29,25 +32,29 @@ class GlobalOutlierPooler(object): return cls._instance def add_outliers(self, outlier_idx, feature_dim): - if self.model_dim is None: self.model_dim = feature_dim - if feature_dim != self.model_dim: return # we do not encode outliers for the 2nd FFN layer + if self.model_dim is None: + self.model_dim = feature_dim + if feature_dim != self.model_dim: + return # we do not encode outliers for the 2nd FFN layer self.outliers.update(outlier_idx.tolist()) def get_current_outlier_idx(self): return torch.Tensor(list(self.outliers)).to(torch.int64) -class MatMul8bit(torch.autograd.Function): +class MatMul8bit(torch.autograd.Function): @staticmethod - def forward(ctx, A, B, out=None, quant_type='vector', precision=[8, 8, 8]): + def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]): if precision[0] != 8: with torch.no_grad(): output = torch.matmul(A, B) else: - if len(B.shape) == 2: dim = 0 - else: dim = 1 + if len(B.shape) == 2: + dim = 0 + else: + dim = 1 qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type) qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type) iout = F.igemm(qA, qB) @@ -84,21 +91,41 @@ class MatMul8bit(torch.autograd.Function): else: if len(B.shape) == 2 and len(A.shape) == 3: grad_output = grad_output.contiguous() - if not grad_output.is_contiguous(): grad_output.contiguous() - qgrad_output, S1 = F.vectorwise_quant(grad_output.view(-1, grad_output.shape[2]), dim=0, quant_type=quant_type) - if not A.is_contiguous(): A = A.contiguous() - qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type) + if not grad_output.is_contiguous(): + grad_output.contiguous() + qgrad_output, S1 = F.vectorwise_quant( + grad_output.view(-1, grad_output.shape[2]), + dim=0, + quant_type=quant_type, + ) + if not A.is_contiguous(): + A = A.contiguous() + qA, S2 = F.vectorwise_quant( + A.view(-1, A.shape[2]), dim=0, quant_type=quant_type + ) igrad_B = F.igemm(qA.t(), qgrad_output) - grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type) + grad_B = F.vectorwise_mm_dequant( + igrad_B, S2.t(), S1, grad_output.dtype, quant_type + ) else: - qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type) + qgrad_output, S1 = F.vectorwise_quant( + grad_output, dim=dims, quant_type=quant_type + ) qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type) igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output) - grad_B = F.vectorwise_mm_dequant(igrad_B, S2.permute(permute_dim), S1, grad_output.dtype, quant_type) + grad_B = F.vectorwise_mm_dequant( + igrad_B, + S2.permute(permute_dim), + S1, + grad_output.dtype, + quant_type, + ) if A.requires_grad: - if len(grad_output.shape) == 3: dims = [2] - else: dims = [1] + if len(grad_output.shape) == 3: + dims = [2] + else: + dims = [1] if len(B.shape) == 3: # bio -> boi @@ -113,10 +140,14 @@ class MatMul8bit(torch.autograd.Function): with torch.no_grad(): grad_A = torch.matmul(grad_output, B.permute(permute_dim)) else: - qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type) + qgrad_output, S1 = F.vectorwise_quant( + grad_output, dim=dims, quant_type=quant_type + ) qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type) igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim)) - grad_A = F.vectorwise_mm_dequant(igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type) + grad_A = F.vectorwise_mm_dequant( + igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type + ) return grad_A, grad_B, None, None, None @@ -125,6 +156,7 @@ mm_cublas = MatMul8bit.apply bmm_cublas = MatMul8bit.apply matmul_cublas = MatMul8bit.apply + @dataclass class MatmulLtState: CB = None @@ -159,7 +191,6 @@ class MatmulLtState: class MatMul8bitLt(torch.autograd.Function): - @staticmethod def forward(ctx, A, B, out=None, state=MatmulLtState()): # 1. Quantize A @@ -171,11 +202,15 @@ class MatMul8bitLt(torch.autograd.Function): requires_gradB = B.requires_grad formatB = state.formatB input_shape = A.shape - if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance() - assert A.dtype == torch.float16, f'The input data type needs to be fp16 but {A.dtype} was found!' + if state.outlier_pool is None: + state.outlier_pool = GlobalOutlierPooler.get_instance() + assert ( + A.dtype == torch.float16 + ), f"The input data type needs to be fp16 but {A.dtype} was found!" # 1. Quantize A - if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous() + if len(A.shape) == 3: + A = A.view(-1, A.shape[-1]).contiguous() CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold) if state.threshold > 0.0 and coo_tensorA is not None: @@ -191,8 +226,8 @@ class MatMul8bitLt(torch.autograd.Function): # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions # we also need to convert it to the turing/ampere format state.CxB, state.SB = F.transform(state.CB, to_order=formatB) - #state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half() - #if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None: + # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half() + # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None: # # generate outlier index and subB # outlier_idx = torch.unique(coo_tensorA.colidx).long() # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) @@ -203,24 +238,24 @@ class MatMul8bitLt(torch.autograd.Function): # state.idx = outlier_idx # state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half() - #if state.idx is not None: + # if state.idx is not None: # # extract outliers # CA[:, state.idx] = 0 # CAt[:, state.idx] = 0 # subA = A[:, state.idx] - #else: + # else: # subA = None else: if not state.has_fp16_weights and state.CxB is None: state.CxB, state.SB = F.transform(state.CB, to_order=formatB) subA = None - # 2. Quantize B if state.has_fp16_weights: - has_grad = (True if (getattr(B, 'grad', None) is not None) else False) + has_grad = True if (getattr(B, "grad", None) is not None) else False is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1) - if is_transposed: B = B.contiguous() + if is_transposed: + B = B.contiguous() if (state.is_training and not has_grad) or state.CxB is None: state.reset_grads() @@ -234,14 +269,16 @@ class MatMul8bitLt(torch.autograd.Function): outlier_idx = torch.unique(coo_tensorA.colidx) state.idx = outlier_idx - #state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) - #if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: + # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) + # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: # # do not use pool for 2nd FFN layer # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) - #else: + # else: # state.idx = outlier_idx outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) - state.subB = (outliers*state.SCB.view(-1, 1)/127.0).t().contiguous().half() + state.subB = ( + (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half() + ) CA[:, state.idx.long()] = 0 CAt[:, state.idx.long()] = 0 subA = A[:, state.idx.long()] @@ -254,7 +291,7 @@ class MatMul8bitLt(torch.autograd.Function): output_shape = (input_shape[0], shapeB[0]) # 3. Matmul - C32A, SA = F.transform(CA, 'col32') + C32A, SA = F.transform(CA, "col32") out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) output = F.mm_dequant(out32, Sout32, SCA, state.SCB) @@ -277,7 +314,7 @@ class MatMul8bitLt(torch.autograd.Function): ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) - #clone_func = torch.clone if len(output_shape) == 3 else lambda x : x + # clone_func = torch.clone if len(output_shape) == 3 else lambda x : x clone_func = torch.clone return clone_func(output.view(output_shape)) @@ -288,7 +325,7 @@ class MatMul8bitLt(torch.autograd.Function): SCAt, idx = ctx.tensor_states formatB = ctx.formatB state = ctx.state - assert state.has_fp16_weights, 'Backprop only supported for fp16 weights.' + assert state.has_fp16_weights, "Backprop only supported for fp16 weights." if len(grad_output.shape) == 3: grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous() @@ -298,18 +335,22 @@ class MatMul8bitLt(torch.autograd.Function): Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output) if req_gradB: CxAt, SAt = F.transform(CAt, formatB, transpose=True) - C32grad, Sgrad = F.transform(Cgradt, 'col32', transpose=True) + C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) if state.threshold > 0.0 and subA is not None: grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradA: - C32grad, Sgrad = F.transform(Cgrad, 'col32') + C32grad, Sgrad = F.transform(Cgrad, "col32") if state.CxBt is None: - state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True) + state.CxBt, state.SBt = F.transform( + state.CBt, to_order=formatB, transpose=True + ) gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) - grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) + grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view( + ctx.grad_shape + ) return grad_A, grad_B, None, None, None, None, None @@ -317,9 +358,10 @@ class MatMul8bitLt(torch.autograd.Function): matmul = MatMul8bitLt.apply -def matmul(A : tensor, B : tensor, out : tensor=None, state : MatmulLtState = None, threshold=0.0): +def matmul( + A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0 +): state = state or MatmulLtState() if threshold > 0.0: state.threshold = threshold return MatMul8bitLt.apply(A, B, out, state) - -- cgit v1.2.3