diff options
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 54 | ||||
-rw-r--r-- | bitsandbytes/nn/modules.py | 9 |
2 files changed, 26 insertions, 37 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index be975f6..226cbb5 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -195,7 +195,6 @@ class MatmulLtState: self.CxBt = None self.SBt = None - self.CBt = None class MatMul8bitLt(torch.autograd.Function): @@ -225,6 +224,11 @@ class MatMul8bitLt(torch.autograd.Function): input_shape = A.shape if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance() + + # Cast A to fp16 + A_dtype = A.dtype + A = A.to(torch.float16) + assert ( A.dtype == torch.float16 ), f"The input data type needs to be fp16 but {A.dtype} was found!" @@ -279,12 +283,6 @@ class MatMul8bitLt(torch.autograd.Function): outlier_idx = torch.unique(coo_tensorA.colidx) state.idx = outlier_idx - # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) - # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: - # # do not use pool for 2nd FFN layer - # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) - # else: - # state.idx = outlier_idx outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) state.subB = ( (outliers * state.SCB.view(-1, 1) / 127.0) @@ -328,52 +326,44 @@ class MatMul8bitLt(torch.autograd.Function): ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) + # Cast fp16 output back to A.dtype + output = output.to(A_dtype) + clone_func = torch.clone if len(output_shape) == 3 else lambda x : x - #clone_func = torch.clone return clone_func(output.view(output_shape)) - @staticmethod def backward(ctx, grad_output): if ctx.is_empty: bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None + req_gradA, req_gradB, req_gradBias = ctx.req_grads - CAt, subA = ctx.tensors - SCAt, idx = ctx.tensor_states - formatB = ctx.formatB + assert not req_gradB, "TODO: support weight updates as well" state = ctx.state - assert ( - state.has_fp16_weights - ), "Backprop only supported for fp16 weights." + + # Cast grad_output to fp16 + grad_output_dtype = grad_output.dtype + grad_output = grad_output.to(torch.float16) if len(grad_output.shape) == 3: - grad_output = grad_output.view( + grad_output = grad_output.reshape( -1, grad_output.shape[-1] ).contiguous() grad_A = grad_B = grad_bias = None - Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output) - if req_gradB: - CxAt, SAt = F.transform(CAt, formatB, transpose=True) - C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) - gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) - grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) - if state.threshold > 0.0 and subA is not None: - grad_B[:, idx] += torch.matmul(grad_output.t(), subA) - if req_gradA: - C32grad, Sgrad = F.transform(Cgrad, "col32") - if state.CxBt is None: - state.CxBt, state.SBt = F.transform( - state.CBt, to_order=formatB, transpose=True - ) - gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) - grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) + CB = state.CB.half() + SCB = (state.SCB.unsqueeze(1) / 127.0).half() + CB *= SCB + grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape) if req_gradBias: grad_bias = grad_output.sum(0) + # Cast grad_A back to grad_output_dtype + grad_output.to(grad_output_dtype) + return grad_A, grad_B, None, grad_bias, None diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index b222f54..3e32c8e 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -255,11 +255,10 @@ class Linear8bitLt(nn.Linear): out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) - if not self.state.has_fp16_weights and self.state.CB is not None: - # we converted 8-bit row major to turing/ampere format in the first inference pass - # we no longer need the row-major weight - del self.state.CB - self.weight.data = self.state.CxB + if not self.state.has_fp16_weights and self.state.CxB is not None: + # In this version, we convert 8-bit row major to turing/ampere format at each inference pass + # Thus, we delete CxB from the state. TODO: do not store it in the state in the first place. + del self.state.CxB return out |