diff options
author | Tim Dettmers <TimDettmers@users.noreply.github.com> | 2022-09-19 21:09:25 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-19 21:09:25 -0700 |
commit | 439f2b0c10abd3e9aade386d92810b074c69e9ec (patch) | |
tree | 75454081c86ba1c96c07e83defc9fc5f4de840cf /bitsandbytes/autograd | |
parent | 9b5f2eda8fbd3f042c4af7ed1b870525d4668f2a (diff) | |
parent | 76ce9aa6da7d68d2463f0f3e99532ab5b6db58a8 (diff) |
Merge pull request #33 from dbaranchuk/memory-efficient-backward
Memory efficient backward
Diffstat (limited to 'bitsandbytes/autograd')
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 76 |
1 files changed, 44 insertions, 32 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index be975f6..2ddb406 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -1,4 +1,6 @@ import operator +import warnings + import torch import bitsandbytes.functional as F @@ -184,6 +186,7 @@ class MatmulLtState: idx = None is_training = True has_fp16_weights = True + memory_efficient_backward = False use_pool = False formatB = F.get_special_format_str() @@ -209,31 +212,29 @@ class MatMul8bitLt(torch.autograd.Function): ctx.B = B ctx.bias = bias if A.shape[-1] == B.shape[0]: - return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device) + return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device) else: - return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device) + return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device) # 1. Quantize A # 2. Quantize B # 3. Matmul # 4. Mixed-precision decomposition matmul # 5. Save state - requires_gradA = A.requires_grad - requires_gradB = B.requires_grad - requires_gradBias = bias is not None and bias.requires_grad formatB = state.formatB input_shape = A.shape if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance() - assert ( - A.dtype == torch.float16 - ), f"The input data type needs to be fp16 but {A.dtype} was found!" + + # Cast A to fp16 + if A.dtype != torch.float16: + warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization") # 1. Quantize A if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous() CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant( - A, threshold=state.threshold + A.to(torch.float16), threshold=state.threshold ) if state.threshold > 0.0 and coo_tensorA is not None: @@ -269,7 +270,7 @@ class MatMul8bitLt(torch.autograd.Function): state.SCB, state.SCBt, coo_tensorB, - ) = F.double_quant(B) + ) = F.double_quant(B.to(torch.float16)) state.CxB, state.SB = F.transform(CB, to_order=formatB) else: has_grad = False @@ -290,7 +291,7 @@ class MatMul8bitLt(torch.autograd.Function): (outliers * state.SCB.view(-1, 1) / 127.0) .t() .contiguous() - .half() + .to(A.dtype) ) CA[:, state.idx.long()] = 0 CAt[:, state.idx.long()] = 0 @@ -307,7 +308,13 @@ class MatMul8bitLt(torch.autograd.Function): C32A, SA = F.transform(CA, "col32") out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) # we apply the fused bias here - output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) + + if bias is None or bias.dtype == torch.float16: + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) + output = output.to(A.dtype) + else: # apply bias separately + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) + output = output.to(A.dtype).add_(bias) # 4. Mixed-precision decomposition matmul if coo_tensorA is not None and subA is not None: @@ -318,9 +325,9 @@ class MatMul8bitLt(torch.autograd.Function): ctx.formatB = formatB ctx.grad_shape = input_shape - ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias] + ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype - if requires_gradA or requires_gradB: + if any(ctx.needs_input_grad[:2]): ctx.tensors = (CAt, subA) ctx.tensor_states = (SCAt, state.idx) else: @@ -328,8 +335,8 @@ class MatMul8bitLt(torch.autograd.Function): ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) + clone_func = torch.clone if len(output_shape) == 3 else lambda x : x - #clone_func = torch.clone return clone_func(output.view(output_shape)) @staticmethod @@ -337,23 +344,24 @@ class MatMul8bitLt(torch.autograd.Function): if ctx.is_empty: bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None - req_gradA, req_gradB, req_gradBias = ctx.req_grads + req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad CAt, subA = ctx.tensors SCAt, idx = ctx.tensor_states formatB = ctx.formatB state = ctx.state - assert ( - state.has_fp16_weights - ), "Backprop only supported for fp16 weights." + grad_A = grad_B = grad_bias = None + + if req_gradBias: + # compute grad_bias first before changing grad_output dtype + grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias) + # Cast grad_output to fp16 if len(grad_output.shape) == 3: - grad_output = grad_output.view( + grad_output = grad_output.reshape( -1, grad_output.shape[-1] ).contiguous() - grad_A = grad_B = grad_bias = None - - Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output) + Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) if req_gradB: CxAt, SAt = F.transform(CAt, formatB, transpose=True) C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) @@ -363,16 +371,20 @@ class MatMul8bitLt(torch.autograd.Function): grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradA: - C32grad, Sgrad = F.transform(Cgrad, "col32") - if state.CxBt is None: - state.CxBt, state.SBt = F.transform( - state.CBt, to_order=formatB, transpose=True - ) - gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) - grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) + if state.CBt is not None: + C32grad, Sgrad = F.transform(Cgrad, "col32") + if state.CxBt is None: + state.CxBt, state.SBt = F.transform( + state.CBt, to_order=formatB, transpose=True + ) + gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) + grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) - if req_gradBias: - grad_bias = grad_output.sum(0) + elif state.CB is not None: + CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0)) + grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) + else: + raise Exception('State must contain either CBt or CB matrix for backward') return grad_A, grad_B, None, grad_bias, None |