| author | Tim Dettmers <TimDettmers@users.noreply.github.com> | 2022-09-19 21:09:25 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-09-19 21:09:25 -0700 |
| commit | 439f2b0c10abd3e9aade386d92810b074c69e9ec (patch) | |
| tree | 75454081c86ba1c96c07e83defc9fc5f4de840cf /tests | |
| parent | 9b5f2eda8fbd3f042c4af7ed1b870525d4668f2a (diff) | |
| parent | 76ce9aa6da7d68d2463f0f3e99532ab5b6db58a8 (diff) | |
Merge pull request #33 from dbaranchuk/memory-efficient-backward
Memory efficient backward
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_autograd.py | 9
-rw-r--r-- | tests/test_modules.py | 46
2 files changed, 41 insertions, 14 deletions
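The merge adds a `memory_efficient_backward` flag to `bnb.nn.Linear8bitLt` and widens `test_matmullt` dtype coverage to bfloat16 and float32. The sketch below only illustrates how the tests in this diff exercise the flag: the constructor keywords and the `.cuda().half()` ordering are taken from the test code, while the threshold value and input shape are arbitrary assumptions, not documented usage.

```python
import torch
import bitsandbytes as bnb

# Illustrative sketch (not documented API): mirror the keyword arguments the
# tests below pass to Linear8bitLt, with memory_efficient_backward enabled.
layer = bnb.nn.Linear8bitLt(
    32, 64,
    has_fp16_weights=False,           # weights stay quantized as int8
    threshold=6.0,                    # outlier threshold; value chosen arbitrarily here
    memory_efficient_backward=True,   # the flag added by this PR
).cuda().half()                       # .cuda().half() triggers quantization, as in the tests

x = torch.randn(16, 8, 32, device="cuda", dtype=torch.half, requires_grad=True)
out = layer(x)                        # fp16 activations in, fp16 activations out
out.sum().backward()                  # backward produces a gradient w.r.t. the input
print(x.grad.dtype, x.grad.shape)
```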
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index bae26de..40bb441 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -253,7 +253,7 @@ for c in req_grad:
 
 transpose = [(False, True), (False, False)]
 str_transpose = ["NT", "NN"]
-dtype = [torch.float16]
+dtype = [torch.float16, torch.bfloat16, torch.float32]
 has_fp16_weights = [True, False]
 has_bias = [True, False]
 values = list(
@@ -354,7 +354,7 @@ def test_matmullt(
                 state.SCB,
                 SCBt,
                 coo_tensorB,
-            ) = bnb.functional.double_quant(B2)
+            ) = bnb.functional.double_quant(B2.to(torch.float16))
             B2 = state.CB
 
         if not transpose[0] and transpose[1]:
@@ -367,11 +367,14 @@ def test_matmullt(
 
         if has_bias:
             out_torch += bias
 
+        assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
+
         n = out_bnb.numel()
         err = torch.abs(out_bnb - out_torch).mean().item()
         # print(f'abs error {err:.4f}')
+
         idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
-        assert (idx == 0).sum().item() <= n * 0.0175
+        assert (idx == 0).sum().item() <= n * (0.0175 if dtype == torch.float16 else 0.021)
         idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
         assert (idx == 0).sum().item() <= n * 0.001
diff --git a/tests/test_modules.py b/tests/test_modules.py
index c0b3311..2879846 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -14,13 +14,15 @@ class MockArgs(object):
 
 
 class MLP8bit(torch.nn.Module):
-    def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
+    def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
         super(MLP8bit, self).__init__()
         self.fc1 = bnb.nn.Linear8bitLt(
-            dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
+            dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold
         )
         self.fc2 = bnb.nn.Linear8bitLt(
-            dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
+            dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold
         )
 
     def forward(self, x):
@@ -451,9 +453,12 @@ names = ["threshold_{0}".format(vals) for vals in values]
 
 
 @pytest.mark.parametrize("threshold", values, ids=names)
-def test_linear8bitlt_no_fp16_weights(threshold):
+@pytest.mark.parametrize("memory_efficient_backward", [True, False])
+def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     l1 = (
-        bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
+        bnb.nn.Linear8bitLt(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
         .cuda()
         .half()
     )
@@ -513,7 +518,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.dtype == torch.int8
 
     mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+        MLP8bit(
+            32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+        )
         .half()
         .to("cuda")
     )
@@ -531,11 +538,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"
 
-    mlp = (
-        MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
-        .to(torch.float16)
-        .to("cuda")
-    )
+    mlp = MLP8bit(
+        32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+    )
+    w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda() # grab weights before quantization,
+    mlp = mlp.cuda().half() # and this line triggers quantization
 
     for i in range(100):
         b1 = torch.randn(16, 8, 32, device="cuda").half()
@@ -545,11 +552,28 @@ def test_linear8bitlt_no_fp16_weights(threshold):
             assert mlp.fc1.state.idx is not None
         if threshold > 0:
             assert mlp.fc2.state.idx is not None
+
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"
 
+    if memory_efficient_backward:
+        b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
+        o1 = mlp(b1)
+        assert o1.dtype == torch.float16
+        assert o1.requires_grad
+        grad_proj = torch.randn_like(o1)
+
+        mlp.zero_grad()
+        (o1 * grad_proj).sum().backward()
+        grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
+        scale = grad_ref.abs().mean()
+
+        torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
+        idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
+        assert (idx == 0).sum().item() <= b1.numel() * 0.005
+
 
 def test_linear8bitlt_fp32_bias():
     # casts model to fp16 -> int8 automatically
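The `grad_ref` check in the last hunk is the chain rule for two stacked linear layers: with `y = fc2(fc1(x))` and `fc(x) = x @ W.T + b`, the input gradient is `dL/dy @ W2 @ W1`, which the test computes from the fp16 copies `w1`/`w2` captured before quantization (the bias does not affect the input gradient). Below is a minimal self-check of that identity using plain `torch.nn.Linear`, with no bitsandbytes involved; the shapes are illustrative only.

```python
import torch

# Reference-gradient identity used by the new test:
# for y = fc2(fc1(x)), dL/dx = (dL/dy) @ W2 @ W1.
fc1 = torch.nn.Linear(32, 64, bias=False)
fc2 = torch.nn.Linear(64, 32, bias=False)

x = torch.randn(16, 8, 32, requires_grad=True)
out = fc2(fc1(x))
grad_proj = torch.randn_like(out)               # arbitrary upstream gradient

(out * grad_proj).sum().backward()              # autograd gradient w.r.t. x
grad_ref = grad_proj @ fc2.weight @ fc1.weight  # analytic reference

torch.testing.assert_close(x.grad, grad_ref)    # agrees up to float rounding
```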