From de354f7ded52bfa857089769225cdf1ee694bfd6 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 16 Aug 2022 12:00:54 -0700 Subject: Added fused bias to matmullt. --- bitsandbytes/autograd/_functions.py | 55 ++++++++++++++----------------------- bitsandbytes/nn/modules.py | 11 +++----- 2 files changed, 24 insertions(+), 42 deletions(-) (limited to 'bitsandbytes') diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 01e7073..4dbf129 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -201,13 +201,14 @@ class MatmulLtState: class MatMul8bitLt(torch.autograd.Function): @staticmethod - def forward(ctx, A, B, out=None, state=MatmulLtState()): + def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # default to pytorch behavior if inputs are empty ctx.is_empty = False if prod(A.shape) == 0: ctx.is_empty = True ctx.A = A ctx.B = B + ctx.bias = bias if A.shape[-1] == B.shape[0]: return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device) else: @@ -220,6 +221,7 @@ class MatMul8bitLt(torch.autograd.Function): # 5. Save state requires_gradA = A.requires_grad requires_gradB = B.requires_grad + requires_gradBias = bias is not None and bias.requires_grad formatB = state.formatB input_shape = A.shape if state.outlier_pool is None: @@ -247,28 +249,7 @@ class MatMul8bitLt(torch.autograd.Function): if state.CxB is None: # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions # we also need to convert it to the turing/ampere format - state.CxB, state.SB = F.transform( - state.CB, to_order=formatB - ) - # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half() - # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None: - # # generate outlier index and subB - # outlier_idx = torch.unique(coo_tensorA.colidx).long() - # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) - # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: - # # do not use pool for 2nd FFN layer - # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) - # else: - # state.idx = outlier_idx - # state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half() - - # if state.idx is not None: - # # extract outliers - # CA[:, state.idx] = 0 - # CAt[:, state.idx] = 0 - # subA = A[:, state.idx] - # else: - # subA = None + state.CxB, state.SB = F.transform(state.CB, to_order=formatB) else: if not state.has_fp16_weights and state.CxB is None: state.CxB, state.SB = F.transform(state.CB, to_order=formatB) @@ -326,7 +307,8 @@ class MatMul8bitLt(torch.autograd.Function): # 3. Matmul C32A, SA = F.transform(CA, "col32") out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) - output = F.mm_dequant(out32, Sout32, SCA, state.SCB) + # we apply the fused bias here + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) # 4. Mixed-precision decomposition matmul if coo_tensorA is not None and subA is not None: @@ -337,7 +319,7 @@ class MatMul8bitLt(torch.autograd.Function): ctx.formatB = formatB ctx.grad_shape = input_shape - ctx.req_grads = [requires_gradA, requires_gradB] + ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias] if requires_gradA or requires_gradB: ctx.tensors = (CAt, subA) @@ -347,15 +329,16 @@ class MatMul8bitLt(torch.autograd.Function): ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) - # clone_func = torch.clone if len(output_shape) == 3 else lambda x : x - clone_func = torch.clone + clone_func = torch.clone if len(output_shape) == 3 else lambda x : x + #clone_func = torch.clone return clone_func(output.view(output_shape)) @staticmethod def backward(ctx, grad_output): if ctx.is_empty: - return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None - req_gradA, req_gradB = ctx.req_grads + bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) + return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None + req_gradA, req_gradB, req_gradBias = ctx.req_grads CAt, subA = ctx.tensors SCAt, idx = ctx.tensor_states formatB = ctx.formatB @@ -369,7 +352,7 @@ class MatMul8bitLt(torch.autograd.Function): -1, grad_output.shape[-1] ).contiguous() - grad_A = grad_B = None + grad_A = grad_B = grad_bias = None Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output) if req_gradB: @@ -387,11 +370,12 @@ class MatMul8bitLt(torch.autograd.Function): state.CBt, to_order=formatB, transpose=True ) gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) - grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view( - ctx.grad_shape - ) + grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) + + if req_gradBias: + grad_bias = grad_output.sum(0) - return grad_A, grad_B, None, None + return grad_A, grad_B, None, grad_bias, None matmul = MatMul8bitLt.apply @@ -403,8 +387,9 @@ def matmul( out: tensor = None, state: MatmulLtState = None, threshold=0.0, + bias=None ): state = state or MatmulLtState() if threshold > 0.0: state.threshold = threshold - return MatMul8bitLt.apply(A, B, out, state) + return MatMul8bitLt.apply(A, B, out, bias, state) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 454dba5..24ecf39 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -235,9 +235,7 @@ class Linear8bitLt(nn.Linear): if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True - self.weight = Int8Params( - self.weight.data, has_fp16_weights=has_fp16_weights - ) + self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights) def init_8bit_state(self): self.state.CB = self.weight.CB @@ -250,13 +248,12 @@ class Linear8bitLt(nn.Linear): if self.weight.CB is not None: self.init_8bit_state() + if self.bias.dtype != torch.float16: + self.bias.data = self.bias.data.half() # assert not self.state.has_fp16_weights # if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None - out = bnb.matmul(x, self.weight, state=self.state) - - if self.bias is not None: - out += self.bias.unsqueeze(0).expand_as(out) + out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) if not self.state.has_fp16_weights and self.state.CB is not None: # we converted 8-bit row major to turing/ampere format in the first inference pass -- cgit v1.2.3