diff options
Diffstat (limited to 'bitsandbytes')
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 46 | ||||
-rw-r--r-- | bitsandbytes/nn/modules.py | 17 |
2 files changed, 43 insertions, 20 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index be975f6..bdcbec5 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -184,6 +184,7 @@ class MatmulLtState: idx = None is_training = True has_fp16_weights = True + memory_efficient_backward = False use_pool = False formatB = F.get_special_format_str() @@ -225,9 +226,10 @@ class MatMul8bitLt(torch.autograd.Function): input_shape = A.shape if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance() - assert ( - A.dtype == torch.float16 - ), f"The input data type needs to be fp16 but {A.dtype} was found!" + + # Cast A to fp16 + A_dtype = A.dtype + A = A.to(torch.float16) # 1. Quantize A if len(A.shape) == 3: @@ -328,8 +330,10 @@ class MatMul8bitLt(torch.autograd.Function): ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) + # Cast fp16 output back to A.dtype + output = output.to(A_dtype) + clone_func = torch.clone if len(output_shape) == 3 else lambda x : x - #clone_func = torch.clone return clone_func(output.view(output_shape)) @staticmethod @@ -342,12 +346,13 @@ class MatMul8bitLt(torch.autograd.Function): SCAt, idx = ctx.tensor_states formatB = ctx.formatB state = ctx.state - assert ( - state.has_fp16_weights - ), "Backprop only supported for fp16 weights." + + # Cast grad_output to fp16 + grad_output_dtype = grad_output.dtype + grad_output = grad_output.to(torch.float16) if len(grad_output.shape) == 3: - grad_output = grad_output.view( + grad_output = grad_output.reshape( -1, grad_output.shape[-1] ).contiguous() @@ -363,17 +368,28 @@ class MatMul8bitLt(torch.autograd.Function): grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradA: - C32grad, Sgrad = F.transform(Cgrad, "col32") - if state.CxBt is None: - state.CxBt, state.SBt = F.transform( - state.CBt, to_order=formatB, transpose=True - ) - gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) - grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) + if state.CBt is not None: + C32grad, Sgrad = F.transform(Cgrad, "col32") + if state.CxBt is None: + state.CxBt, state.SBt = F.transform( + state.CBt, to_order=formatB, transpose=True + ) + gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) + grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) + elif state.CB is not None: + CB = state.CB.half() + SCB = (state.SCB.unsqueeze(1) / 127.0).half() + CB *= SCB + grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape) + else: + raise Exception('State must contain either CBt or CB matrix for backward') if req_gradBias: grad_bias = grad_output.sum(0) + # Cast grad_A back to grad_output_dtype + grad_output = grad_output.to(grad_output_dtype) + return grad_A, grad_B, None, grad_bias, None diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index b222f54..e7e759d 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -221,6 +221,7 @@ class Linear8bitLt(nn.Linear): output_features, bias=True, has_fp16_weights=True, + memory_efficient_backward=False, threshold=0.0, index=None, ): @@ -232,6 +233,7 @@ class Linear8bitLt(nn.Linear): self.state.threshold = threshold self.state.has_fp16_weights = has_fp16_weights + self.state.memory_efficient_backward = memory_efficient_backward if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True @@ -255,11 +257,16 @@ class Linear8bitLt(nn.Linear): out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) - if not self.state.has_fp16_weights and self.state.CB is not None: - # we converted 8-bit row major to turing/ampere format in the first inference pass - # we no longer need the row-major weight - del self.state.CB - self.weight.data = self.state.CxB + if not self.state.has_fp16_weights: + if not self.state.memory_efficient_backward and self.state.CB is not None: + # we converted 8-bit row major to turing/ampere format in the first inference pass + # we no longer need the row-major weight + del self.state.CB + self.weight.data = self.state.CxB + elif self.state.memory_efficient_backward and self.state.CxB is not None: + # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass. + # Thus, we delete CxB from the state. + del self.state.CxB return out |