diff options
author | dbaranchuk <dmitrybaranchuk@gmail.com> | 2022-09-11 05:51:29 +0300 |
---|---|---|
committer | dbaranchuk <dmitrybaranchuk@gmail.com> | 2022-09-11 05:51:29 +0300 |
commit | 42b5fc9acc4b59a6d90c662eb26099ac25907c7f (patch) | |
tree | df0f65f65e2f1aae25462da1be9c65ca3fe45580 | |
parent | 843ad0631c65eabc7f64e80906ecf5482cc1a036 (diff) |
add memory effcient backward option
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 46 | ||||
-rw-r--r-- | bitsandbytes/nn/modules.py | 16 |
2 files changed, 52 insertions, 10 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 226cbb5..271c690 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -1,5 +1,6 @@ import operator import torch +import bitsandbytes as bnb import bitsandbytes.functional as F from dataclasses import dataclass @@ -187,6 +188,8 @@ class MatmulLtState: use_pool = False formatB = F.get_special_format_str() + memory_efficient_backward = False + def reset_grads(self): self.CB = None self.CxB = None @@ -283,6 +286,12 @@ class MatMul8bitLt(torch.autograd.Function): outlier_idx = torch.unique(coo_tensorA.colidx) state.idx = outlier_idx + # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) + # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: + # # do not use pool for 2nd FFN layer + # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) + # else: + # state.idx = outlier_idx outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) state.subB = ( (outliers * state.SCB.view(-1, 1) / 127.0) @@ -332,13 +341,15 @@ class MatMul8bitLt(torch.autograd.Function): clone_func = torch.clone if len(output_shape) == 3 else lambda x : x return clone_func(output.view(output_shape)) + @staticmethod def backward(ctx, grad_output): if ctx.is_empty: bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None - req_gradA, req_gradB, req_gradBias = ctx.req_grads - assert not req_gradB, "TODO: support weight updates as well" + CAt, subA = ctx.tensors + SCAt, idx = ctx.tensor_states + formatB = ctx.formatB state = ctx.state # Cast grad_output to fp16 @@ -352,11 +363,31 @@ class MatMul8bitLt(torch.autograd.Function): grad_A = grad_B = grad_bias = None + Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output) + if req_gradB: + CxAt, SAt = F.transform(CAt, formatB, transpose=True) + C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) + gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) + grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) + if state.threshold > 0.0 and subA is not None: + grad_B[:, idx] += torch.matmul(grad_output.t(), subA) + if req_gradA: - CB = state.CB.half() - SCB = (state.SCB.unsqueeze(1) / 127.0).half() - CB *= SCB - grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape) + if state.CBt: + C32grad, Sgrad = F.transform(Cgrad, "col32") + if state.CxBt is None: + state.CxBt, state.SBt = F.transform( + state.CBt, to_order=formatB, transpose=True + ) + gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) + grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) + elif state.CB: + CB = state.CB.half() + SCB = (state.SCB.unsqueeze(1) / 127.0).half() + CB *= SCB + grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape) + else: + raise Exception('State must contain either CBt or CB matrix') if req_gradBias: grad_bias = grad_output.sum(0) @@ -367,6 +398,9 @@ class MatMul8bitLt(torch.autograd.Function): return grad_A, grad_B, None, grad_bias, None +matmul = MatMul8bitLt.apply + + def matmul( A: tensor, B: tensor, diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 3e32c8e..00d0c61 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -223,6 +223,7 @@ class Linear8bitLt(nn.Linear): has_fp16_weights=True, threshold=0.0, index=None, + memory_efficient_backward=False ): super(Linear8bitLt, self).__init__( input_features, output_features, bias @@ -232,6 +233,7 @@ class Linear8bitLt(nn.Linear): self.state.threshold = threshold self.state.has_fp16_weights = has_fp16_weights + self.state.memory_efficient_backward = memory_efficient_backward if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True @@ -255,10 +257,16 @@ class Linear8bitLt(nn.Linear): out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) - if not self.state.has_fp16_weights and self.state.CxB is not None: - # In this version, we convert 8-bit row major to turing/ampere format at each inference pass - # Thus, we delete CxB from the state. TODO: do not store it in the state in the first place. - del self.state.CxB + if not self.state.has_fp16_weights: + if not self.state.memory_efficient_backward and self.state.CB is not None: + # we converted 8-bit row major to turing/ampere format in the first inference pass + # we no longer need the row-major weight + del self.state.CB + self.weight.data = self.state.CxB + elif self.state.memory_efficient_backward and self.state.CxB is not None: + # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass. + # Thus, we delete CxB from the state. + del self.state.CxB return out |