-rw-r--r--   CHANGELOG.md                         | 11
-rw-r--r--   bitsandbytes/autograd/_functions.py  | 76
-rw-r--r--   bitsandbytes/cuda_setup/main.py      |  2
-rw-r--r--   bitsandbytes/nn/modules.py           | 21
-rw-r--r--   setup.py                             |  4
-rw-r--r--   tests/test_autograd.py               |  9
-rw-r--r--   tests/test_modules.py                | 46
7 files changed, 114 insertions, 55 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 40467dc..a26a0e7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -106,3 +106,14 @@ Bug fixes:
 - fixed an import of bnb.utils 2e630b55f51d454f3bd723dffda68a07ef93190c

 We thank @mryab, @mbrukman, @chessgecko, @dbaranchuk for pull request with bug fixes and new features.
+
+
+### 0.34.0
+
+#### Bug fixes and memory efficient backprop
+
+Features:
+ - Linear8bitLt layer now supports `memory_efficient_backward=True` which enables backprop of gradients through frozen weights.
+
+Bug fixes:
+ - fixed an issue where too many threads were created in blockwise quantization on the CPU for large tensors
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index be975f6..2ddb406 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -1,4 +1,6 @@
 import operator
+import warnings
+
 import torch
 import bitsandbytes.functional as F
@@ -184,6 +186,7 @@ class MatmulLtState:
     idx = None
     is_training = True
     has_fp16_weights = True
+    memory_efficient_backward = False
     use_pool = False
     formatB = F.get_special_format_str()
@@ -209,31 +212,29 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.B = B
             ctx.bias = bias
             if A.shape[-1] == B.shape[0]:
-                return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
+                return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device)
             else:
-                return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
+                return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device)

         # 1. Quantize A
         # 2. Quantize B
         # 3. Matmul
         # 4. Mixed-precision decomposition matmul
         # 5. Save state
-        requires_gradA = A.requires_grad
-        requires_gradB = B.requires_grad
-        requires_gradBias = bias is not None and bias.requires_grad
         formatB = state.formatB
         input_shape = A.shape
         if state.outlier_pool is None:
             state.outlier_pool = GlobalOutlierPooler.get_instance()
-        assert (
-            A.dtype == torch.float16
-        ), f"The input data type needs to be fp16 but {A.dtype} was found!"
+
+        # Cast A to fp16
+        if A.dtype != torch.float16:
+            warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")

         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.view(-1, A.shape[-1]).contiguous()
         CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
-            A, threshold=state.threshold
+            A.to(torch.float16), threshold=state.threshold
         )

         if state.threshold > 0.0 and coo_tensorA is not None:
@@ -269,7 +270,7 @@ class MatMul8bitLt(torch.autograd.Function):
                     state.SCB,
                     state.SCBt,
                     coo_tensorB,
-                ) = F.double_quant(B)
+                ) = F.double_quant(B.to(torch.float16))
                 state.CxB, state.SB = F.transform(CB, to_order=formatB)
         else:
             has_grad = False
@@ -290,7 +291,7 @@ class MatMul8bitLt(torch.autograd.Function):
                 (outliers * state.SCB.view(-1, 1) / 127.0)
                 .t()
                 .contiguous()
-                .half()
+                .to(A.dtype)
             )
             CA[:, state.idx.long()] = 0
             CAt[:, state.idx.long()] = 0
@@ -307,7 +308,13 @@ class MatMul8bitLt(torch.autograd.Function):
         C32A, SA = F.transform(CA, "col32")
         out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
         # we apply the fused bias here
-        output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+
+        if bias is None or bias.dtype == torch.float16:
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+            output = output.to(A.dtype)
+        else:  # apply bias separately
+            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
+            output = output.to(A.dtype).add_(bias)

         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
@@ -318,9 +325,9 @@ class MatMul8bitLt(torch.autograd.Function):
         ctx.formatB = formatB
         ctx.grad_shape = input_shape
-        ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias]
+        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

-        if requires_gradA or requires_gradB:
+        if any(ctx.needs_input_grad[:2]):
             ctx.tensors = (CAt, subA)
             ctx.tensor_states = (SCAt, state.idx)
         else:
@@ -328,8 +335,8 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)

+        clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
-        #clone_func = torch.clone
         return clone_func(output.view(output_shape))

     @staticmethod
@@ -337,23 +344,24 @@ def backward(ctx, grad_output):
         if ctx.is_empty:
             bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
-        req_gradA, req_gradB, req_gradBias = ctx.req_grads
+        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
         CAt, subA = ctx.tensors
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB
         state = ctx.state
-        assert (
-            state.has_fp16_weights
-        ), "Backprop only supported for fp16 weights."
+        grad_A = grad_B = grad_bias = None
+
+        if req_gradBias:
+            # compute grad_bias first before changing grad_output dtype
+            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)

+        # Cast grad_output to fp16
         if len(grad_output.shape) == 3:
-            grad_output = grad_output.view(
+            grad_output = grad_output.reshape(
                 -1, grad_output.shape[-1]
             ).contiguous()

-        grad_A = grad_B = grad_bias = None
-
-        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
+        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
         if req_gradB:
             CxAt, SAt = F.transform(CAt, formatB, transpose=True)
             C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
@@ -363,16 +371,20 @@ class MatMul8bitLt(torch.autograd.Function):
                 grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

         if req_gradA:
-            C32grad, Sgrad = F.transform(Cgrad, "col32")
-            if state.CxBt is None:
-                state.CxBt, state.SBt = F.transform(
-                    state.CBt, to_order=formatB, transpose=True
-                )
-            gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
-            grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
+            if state.CBt is not None:
+                C32grad, Sgrad = F.transform(Cgrad, "col32")
+                if state.CxBt is None:
+                    state.CxBt, state.SBt = F.transform(
+                        state.CBt, to_order=formatB, transpose=True
+                    )
+                gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
+                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)

-        if req_gradBias:
-            grad_bias = grad_output.sum(0)
+            elif state.CB is not None:
+                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
+                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
+            else:
+                raise Exception('State must contain either CBt or CB matrix for backward')

         return grad_A, grad_B, None, grad_bias, None
diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py
index 78a2844..f11b430 100644
--- a/bitsandbytes/cuda_setup/main.py
+++ b/bitsandbytes/cuda_setup/main.py
@@ -103,7 +103,7 @@ def get_compute_capability(cuda):
         None.
     """
     ccs = get_compute_capabilities(cuda)
-    if ccs is not None:
+    if ccs:  # TODO: handle different compute capabilities; for now, take the max
         return ccs[-1]
     return None
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index b222f54..9250fec 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -221,6 +221,7 @@ class Linear8bitLt(nn.Linear):
         output_features,
         bias=True,
         has_fp16_weights=True,
+        memory_efficient_backward=False,
         threshold=0.0,
         index=None,
     ):
@@ -232,10 +233,13 @@ class Linear8bitLt(nn.Linear):

         self.state.threshold = threshold
         self.state.has_fp16_weights = has_fp16_weights
+        self.state.memory_efficient_backward = memory_efficient_backward
         if threshold > 0.0 and not has_fp16_weights:
             self.state.use_pool = True

-        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
+        self.weight = Int8Params(
+            self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
+        )

     def init_8bit_state(self):
         self.state.CB = self.weight.CB
@@ -255,11 +259,16 @@ class Linear8bitLt(nn.Linear):

         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

-        if not self.state.has_fp16_weights and self.state.CB is not None:
-            # we converted 8-bit row major to turing/ampere format in the first inference pass
-            # we no longer need the row-major weight
-            del self.state.CB
-            self.weight.data = self.state.CxB
+        if not self.state.has_fp16_weights:
+            if not self.state.memory_efficient_backward and self.state.CB is not None:
+                # we converted 8-bit row major to turing/ampere format in the first inference pass
+                # we no longer need the row-major weight
+                del self.state.CB
+                self.weight.data = self.state.CxB
+            elif self.state.memory_efficient_backward and self.state.CxB is not None:
+                # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
+                # Thus, we delete CxB from the state.
+                del self.state.CxB

         return out
diff --git a/setup.py b/setup.py
@@ -18,13 +18,13 @@ def read(fname):

 setup(
     name=f"bitsandbytes",
-    version=f"0.33.1",
+    version=f"0.34.0",
     author="Tim Dettmers",
     author_email="dettmers@cs.washington.edu",
     description="8-bit optimizers and matrix multiplication routines.",
     license="MIT",
     keywords="gpu optimizers optimization 8-bit quantization compression",
-    url="http://packages.python.org/bitsandbytes",
+    url="https://github.com/TimDettmers/bitsandbytes",
     packages=find_packages(),
     entry_points={
         "console_scripts": ["debug_cuda = bitsandbytes.debug_cli:cli"],
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index bae26de..40bb441 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -253,7 +253,7 @@ for c in req_grad:

 transpose = [(False, True), (False, False)]
 str_transpose = ["NT", "NN"]
-dtype = [torch.float16]
+dtype = [torch.float16, torch.bfloat16, torch.float32]
 has_fp16_weights = [True, False]
 has_bias = [True, False]
 values = list(
@@ -354,7 +354,7 @@ def test_matmullt(
             state.SCB,
             SCBt,
             coo_tensorB,
-        ) = bnb.functional.double_quant(B2)
+        ) = bnb.functional.double_quant(B2.to(torch.float16))
         B2 = state.CB

     if not transpose[0] and transpose[1]:
@@ -367,11 +367,14 @@ def test_matmullt(
     if has_bias:
         out_torch += bias

+    assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
+
     n = out_bnb.numel()
     err = torch.abs(out_bnb - out_torch).mean().item()
     # print(f'abs error {err:.4f}')
+
     idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
-    assert (idx == 0).sum().item() <= n * 0.0175
+    assert (idx == 0).sum().item() <= n * (0.0175 if dtype == torch.float16 else 0.021)
     idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
     assert (idx == 0).sum().item() <= n * 0.001
diff --git a/tests/test_modules.py b/tests/test_modules.py
index c0b3311..2879846 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -14,13 +14,15 @@ class MockArgs(object):

 class MLP8bit(torch.nn.Module):
-    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
+    def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
         super(MLP8bit, self).__init__()
         self.fc1 = bnb.nn.Linear8bitLt(
-            dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
+            dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold
         )
         self.fc2 = bnb.nn.Linear8bitLt(
-            dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
+            dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold
         )

     def forward(self, x):
@@ -451,9 +453,12 @@ names = ["threshold_{0}".format(vals) for vals in values]

 @pytest.mark.parametrize("threshold", values, ids=names)
-def test_linear8bitlt_no_fp16_weights(threshold):
+@pytest.mark.parametrize("memory_efficient_backward", [True, False])
+def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     l1 = (
-        bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
+        bnb.nn.Linear8bitLt(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
         .cuda()
         .half()
     )
@@ -513,7 +518,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.dtype == torch.int8

     mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+        MLP8bit(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
         .half()
         .to("cuda")
     )
@@ -531,11 +538,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"

-    mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
-        .to(torch.float16)
-        .to("cuda")
-    )
+    mlp = MLP8bit(
+        32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+    )
+    w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda()  # grab weights before quantization,
+    mlp = mlp.cuda().half()  # and this line triggers quantization

     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
@@ -545,11 +552,28 @@ def test_linear8bitlt_no_fp16_weights(threshold):
             assert mlp.fc1.state.idx is not None
         if threshold > 0:
             assert mlp.fc2.state.idx is not None
+
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"

+    if memory_efficient_backward:
+        b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
+        o1 = mlp(b1)
+        assert o1.dtype == torch.float16
+        assert o1.requires_grad
+        grad_proj = torch.randn_like(o1)
+
+        mlp.zero_grad()
+        (o1 * grad_proj).sum().backward()
+        grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
+        scale = grad_ref.abs().mean()
+
+        torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
+        idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
+        assert (idx == 0).sum().item() <= b1.numel() * 0.005
+

 def test_linear8bitlt_fp32_bias():
     # casts model to fp16 -> int8 automatically
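
Usage note: the new `memory_efficient_backward` flag on Linear8bitLt enables backprop through frozen int8 weights, as exercised by the test above. The following is a minimal sketch of that use case, not part of the commit; the layer sizes and the trainable adapter module are made up for illustration, and a CUDA device is assumed.

import torch
import bitsandbytes as bnb

# Frozen int8 layer: the weight is quantized when the module is moved to the GPU
# and receives no gradient, but memory_efficient_backward=True keeps the int8
# matrix around so gradients can still flow *through* the layer.
frozen = bnb.nn.Linear8bitLt(
    1024, 1024,
    has_fp16_weights=False,
    memory_efficient_backward=True,
    threshold=6.0,
).cuda().half()

# Hypothetical trainable adapter placed in front of the frozen layer.
adapter = torch.nn.Linear(1024, 1024).cuda().half()

x = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
out = frozen(adapter(x))
out.float().mean().backward()

assert frozen.weight.grad is None        # the int8 weight stays frozen
assert adapter.weight.grad is not None   # but upstream modules still train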
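
The fallback branch added to MatMul8bitLt.backward dequantizes the row-major int8 matrix CB with its per-row scales SCB and computes grad_A as a plain matmul, i.e. grad_A = grad_output @ (CB * SCB / 127). Below is a self-contained toy sketch of that reconstruction with hand-rolled row-wise quantization and made-up shapes; it mirrors the formula in the diff, not the library's internal kernel path.

import torch

out_features, in_features = 4, 8
W = torch.randn(out_features, in_features)

# Row-wise absmax quantization to int8 (stand-in for F.double_quant).
SCB = W.abs().amax(dim=1)                                   # per-row scales
CB = torch.round(W / SCB.unsqueeze(1) * 127).to(torch.int8)  # int8 rows

# What the new backward fallback reconstructs when only CB/SCB are available:
W_dequant = CB.float() * SCB.unsqueeze(1) / 127.0

grad_output = torch.randn(3, out_features)
grad_A = grad_output @ W_dequant          # approximates grad_output @ W

print((W - W_dequant).abs().max())        # small row-wise quantization error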
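
The forward-pass changes also replace the old fp16-only assert: non-fp16 inputs are cast to fp16 around the int8 kernels (with a warning) and the output is cast back to the input dtype, matching the new dtype parametrization in test_matmullt. A small sketch of that observable behaviour, assuming a CUDA device and this commit's build; the layer sizes are illustrative.

import torch
import bitsandbytes as bnb

linear = bnb.nn.Linear8bitLt(64, 32, has_fp16_weights=False).cuda().half()

for dtype in (torch.float16, torch.bfloat16, torch.float32):
    x = torch.randn(2, 64, device="cuda", dtype=dtype)
    out = linear(x)            # non-fp16 inputs are quantized via an fp16 cast
    assert out.dtype == dtype  # ...and the result is cast back to the input dtype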