From 42b5fc9acc4b59a6d90c662eb26099ac25907c7f Mon Sep 17 00:00:00 2001 From: dbaranchuk Date: Sun, 11 Sep 2022 05:51:29 +0300 Subject: add memory effcient backward option --- bitsandbytes/autograd/_functions.py | 46 ++++++++++++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 6 deletions(-) (limited to 'bitsandbytes/autograd/_functions.py') diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 226cbb5..271c690 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -1,5 +1,6 @@ import operator import torch +import bitsandbytes as bnb import bitsandbytes.functional as F from dataclasses import dataclass @@ -187,6 +188,8 @@ class MatmulLtState: use_pool = False formatB = F.get_special_format_str() + memory_efficient_backward = False + def reset_grads(self): self.CB = None self.CxB = None @@ -283,6 +286,12 @@ class MatMul8bitLt(torch.autograd.Function): outlier_idx = torch.unique(coo_tensorA.colidx) state.idx = outlier_idx + # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) + # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: + # # do not use pool for 2nd FFN layer + # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) + # else: + # state.idx = outlier_idx outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) state.subB = ( (outliers * state.SCB.view(-1, 1) / 127.0) @@ -332,13 +341,15 @@ class MatMul8bitLt(torch.autograd.Function): clone_func = torch.clone if len(output_shape) == 3 else lambda x : x return clone_func(output.view(output_shape)) + @staticmethod def backward(ctx, grad_output): if ctx.is_empty: bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None - req_gradA, req_gradB, req_gradBias = ctx.req_grads - assert not req_gradB, "TODO: support weight updates as well" + CAt, subA = ctx.tensors + SCAt, idx = ctx.tensor_states + formatB = ctx.formatB state = ctx.state # Cast grad_output to fp16 @@ -352,11 +363,31 @@ class MatMul8bitLt(torch.autograd.Function): grad_A = grad_B = grad_bias = None + Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output) + if req_gradB: + CxAt, SAt = F.transform(CAt, formatB, transpose=True) + C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) + gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) + grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) + if state.threshold > 0.0 and subA is not None: + grad_B[:, idx] += torch.matmul(grad_output.t(), subA) + if req_gradA: - CB = state.CB.half() - SCB = (state.SCB.unsqueeze(1) / 127.0).half() - CB *= SCB - grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape) + if state.CBt: + C32grad, Sgrad = F.transform(Cgrad, "col32") + if state.CxBt is None: + state.CxBt, state.SBt = F.transform( + state.CBt, to_order=formatB, transpose=True + ) + gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) + grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) + elif state.CB: + CB = state.CB.half() + SCB = (state.SCB.unsqueeze(1) / 127.0).half() + CB *= SCB + grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape) + else: + raise Exception('State must contain either CBt or CB matrix') if req_gradBias: grad_bias = grad_output.sum(0) @@ -367,6 +398,9 @@ class MatMul8bitLt(torch.autograd.Function): return grad_A, grad_B, None, grad_bias, None +matmul = MatMul8bitLt.apply + + def matmul( A: tensor, B: tensor, -- cgit v1.2.3