diff options
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 55 | ||||
-rw-r--r-- | bitsandbytes/nn/modules.py | 11 | ||||
-rw-r--r-- | tests/test_autograd.py | 49 |
3 files changed, 62 insertions, 53 deletions
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 01e7073..4dbf129 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -201,13 +201,14 @@ class MatmulLtState: class MatMul8bitLt(torch.autograd.Function): @staticmethod - def forward(ctx, A, B, out=None, state=MatmulLtState()): + def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # default to pytorch behavior if inputs are empty ctx.is_empty = False if prod(A.shape) == 0: ctx.is_empty = True ctx.A = A ctx.B = B + ctx.bias = bias if A.shape[-1] == B.shape[0]: return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device) else: @@ -220,6 +221,7 @@ class MatMul8bitLt(torch.autograd.Function): # 5. Save state requires_gradA = A.requires_grad requires_gradB = B.requires_grad + requires_gradBias = bias is not None and bias.requires_grad formatB = state.formatB input_shape = A.shape if state.outlier_pool is None: @@ -247,28 +249,7 @@ class MatMul8bitLt(torch.autograd.Function): if state.CxB is None: # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions # we also need to convert it to the turing/ampere format - state.CxB, state.SB = F.transform( - state.CB, to_order=formatB - ) - # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half() - # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None: - # # generate outlier index and subB - # outlier_idx = torch.unique(coo_tensorA.colidx).long() - # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) - # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: - # # do not use pool for 2nd FFN layer - # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) - # else: - # state.idx = outlier_idx - # state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half() - - # if state.idx is not None: - # # extract outliers - # CA[:, state.idx] = 0 - # CAt[:, state.idx] = 0 - # subA = A[:, state.idx] - # else: - # subA = None + state.CxB, state.SB = F.transform(state.CB, to_order=formatB) else: if not state.has_fp16_weights and state.CxB is None: state.CxB, state.SB = F.transform(state.CB, to_order=formatB) @@ -326,7 +307,8 @@ class MatMul8bitLt(torch.autograd.Function): # 3. Matmul C32A, SA = F.transform(CA, "col32") out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) - output = F.mm_dequant(out32, Sout32, SCA, state.SCB) + # we apply the fused bias here + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) # 4. Mixed-precision decomposition matmul if coo_tensorA is not None and subA is not None: @@ -337,7 +319,7 @@ class MatMul8bitLt(torch.autograd.Function): ctx.formatB = formatB ctx.grad_shape = input_shape - ctx.req_grads = [requires_gradA, requires_gradB] + ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias] if requires_gradA or requires_gradB: ctx.tensors = (CAt, subA) @@ -347,15 +329,16 @@ class MatMul8bitLt(torch.autograd.Function): ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) - # clone_func = torch.clone if len(output_shape) == 3 else lambda x : x - clone_func = torch.clone + clone_func = torch.clone if len(output_shape) == 3 else lambda x : x + #clone_func = torch.clone return clone_func(output.view(output_shape)) @staticmethod def backward(ctx, grad_output): if ctx.is_empty: - return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None - req_gradA, req_gradB = ctx.req_grads + bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) + return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None + req_gradA, req_gradB, req_gradBias = ctx.req_grads CAt, subA = ctx.tensors SCAt, idx = ctx.tensor_states formatB = ctx.formatB @@ -369,7 +352,7 @@ class MatMul8bitLt(torch.autograd.Function): -1, grad_output.shape[-1] ).contiguous() - grad_A = grad_B = None + grad_A = grad_B = grad_bias = None Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output) if req_gradB: @@ -387,11 +370,12 @@ class MatMul8bitLt(torch.autograd.Function): state.CBt, to_order=formatB, transpose=True ) gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) - grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view( - ctx.grad_shape - ) + grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) + + if req_gradBias: + grad_bias = grad_output.sum(0) - return grad_A, grad_B, None, None + return grad_A, grad_B, None, grad_bias, None matmul = MatMul8bitLt.apply @@ -403,8 +387,9 @@ def matmul( out: tensor = None, state: MatmulLtState = None, threshold=0.0, + bias=None ): state = state or MatmulLtState() if threshold > 0.0: state.threshold = threshold - return MatMul8bitLt.apply(A, B, out, state) + return MatMul8bitLt.apply(A, B, out, bias, state) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 454dba5..24ecf39 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -235,9 +235,7 @@ class Linear8bitLt(nn.Linear): if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True - self.weight = Int8Params( - self.weight.data, has_fp16_weights=has_fp16_weights - ) + self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights) def init_8bit_state(self): self.state.CB = self.weight.CB @@ -250,13 +248,12 @@ class Linear8bitLt(nn.Linear): if self.weight.CB is not None: self.init_8bit_state() + if self.bias.dtype != torch.float16: + self.bias.data = self.bias.data.half() # assert not self.state.has_fp16_weights # if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None - out = bnb.matmul(x, self.weight, state=self.state) - - if self.bias is not None: - out += self.bias.unsqueeze(0).expand_as(out) + out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) if not self.state.has_fp16_weights and self.state.CB is not None: # we converted 8-bit row major to turing/ampere format in the first inference pass diff --git a/tests/test_autograd.py b/tests/test_autograd.py index f1a15f5..0cd17c9 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -1,4 +1,4 @@ -from itertools import product +from itertools import product, permutations import pytest import torch @@ -241,11 +241,20 @@ decomp = [0.0, 6.0] funcs = [(torch.matmul, bnb.matmul)] str_funcs = ["matmul"] req_grad = [(False, False), (True, False), (True, True), (False, True)] -req_grad_str = ["FF", "TF", "TT", "FT"] +req_grad = list(product([True, False], repeat=3)) +req_grad_str = [] +for c in req_grad: + strval = '' + for v in c: + if v == True: strval += 'T' + else: strval += 'F' + req_grad_str.append(strval) + transpose = [(False, True), (False, False)] str_transpose = ["NT", "NN"] dtype = [torch.float16] has_fp16_weights = [True, False] +has_bias = [True, False] values = list( product( dim1, @@ -258,6 +267,7 @@ values = list( transpose, decomp, has_fp16_weights, + has_bias ) ) str_values = list( @@ -272,18 +282,14 @@ str_values = list( str_transpose, decomp, has_fp16_weights, + has_bias ) ) -names = [ - "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}".format( - *vals - ) - for vals in str_values -] +names = ["dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}_has_bias_{10}".format(*vals) for vals in str_values] @pytest.mark.parametrize( - "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights", + "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias", values, ids=names, ) @@ -298,10 +304,14 @@ def test_matmullt( transpose, decomp, has_fp16_weights, + has_bias ): dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda") + if has_bias == False: + req_grad = list(req_grad) + req_grad[2] = False for i in range(k): @@ -322,6 +332,11 @@ def test_matmullt( requires_grad=req_grad[1], dtype=dtype, ) + bias = None + bias2 = None + if has_bias: + bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2]) + bias2 = bias.clone() torch.nn.init.xavier_uniform_(B) B2 = B.clone() @@ -342,10 +357,13 @@ def test_matmullt( if not transpose[0] and transpose[1]: out_torch = funcs[0](A, B.t()) - out_bnb = funcs[1](A, B2, state=state) + out_bnb = funcs[1](A, B2, state=state, bias=bias2) elif not transpose[0] and not transpose[1]: out_torch = funcs[0](A, B) - out_bnb = funcs[1](A, B2.t(), state=state) + out_bnb = funcs[1](A, B2.t(), state=state, bias=bias2) + + if has_bias: + out_torch += bias n = out_bnb.numel() err = torch.abs(out_bnb - out_torch).mean().item() @@ -367,6 +385,9 @@ def test_matmullt( gradB1 = B.grad A.grad = None B.grad = None + if has_bias: + gradBias1 = bias.grad + bias.grad = None loss_torch = torch.nn.functional.mse_loss( out_torch, target @@ -376,6 +397,9 @@ def test_matmullt( gradB2 = B.grad A.grad = None B.grad = None + if has_bias: + gradBias2 = bias.grad + bias.grad = None if req_grad[0]: torch.testing.assert_allclose( @@ -397,3 +421,6 @@ def test_matmullt( torch.testing.assert_allclose( gradB1, gradB2, atol=0.18, rtol=0.3 ) + + if req_grad[2]: + torch.testing.assert_allclose(gradBias1, gradBias2) |