From 591f60395a1e9c62f291e23c91af45cc699f072c Mon Sep 17 00:00:00 2001 From: justheuristic Date: Sun, 18 Sep 2022 00:52:53 +0300 Subject: add memory efficient backward --- bitsandbytes/autograd/_functions.py | 1 - tests/test_modules.py | 24 +++++++++++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 6674a82..daf9ba0 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -381,7 +381,6 @@ class MatMul8bitLt(torch.autograd.Function): grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) elif state.CB is not None: - raise NotImplementedError("WIP") CB = state.CB.to(ctx.dtype_B) CB.mul_(state.SCB.unsqueeze(1).div_(127.0).to(CB.dtype)) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) diff --git a/tests/test_modules.py b/tests/test_modules.py index c0b3311..53a675f 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -14,13 +14,15 @@ class MockArgs(object): class MLP8bit(torch.nn.Module): - def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0): + def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0): super(MLP8bit, self).__init__() self.fc1 = bnb.nn.Linear8bitLt( - dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold + dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward, + threshold=threshold ) self.fc2 = bnb.nn.Linear8bitLt( - dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold + dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward, + threshold=threshold ) def forward(self, x): @@ -451,9 +453,12 @@ names = ["threshold_{0}".format(vals) for vals in values] @pytest.mark.parametrize("threshold", values, ids=names) -def test_linear8bitlt_no_fp16_weights(threshold): +@pytest.mark.parametrize("memory_efficient_backward", [True, False]) +def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): l1 = ( - bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False) + bnb.nn.Linear8bitLt( + 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward + ) .cuda() .half() ) @@ -513,7 +518,9 @@ def test_linear8bitlt_no_fp16_weights(threshold): assert mlp.fc2.weight.dtype == torch.int8 mlp = ( - MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False) + MLP8bit( + 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward + ) .half() .to("cuda") ) @@ -532,7 +539,9 @@ def test_linear8bitlt_no_fp16_weights(threshold): assert mlp.fc2.weight.device.type == "cuda" mlp = ( - MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False) + MLP8bit( + 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward + ) .to(torch.float16) .to("cuda") ) @@ -551,6 +560,7 @@ def test_linear8bitlt_no_fp16_weights(threshold): assert mlp.fc2.weight.device.type == "cuda" + def test_linear8bitlt_fp32_bias(): # casts model to fp16 -> int8 automatically l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda() -- cgit v1.2.3