author    justheuristic <justheuristic@gmail.com>  2022-09-18 00:52:53 +0300
committer justheuristic <justheuristic@gmail.com>  2022-09-18 00:52:53 +0300
commit    591f60395a1e9c62f291e23c91af45cc699f072c (patch)
tree      1f4ff32a1e490d9a872286cdf2d4f43eb0f1df2a /tests
parent    579b8c782f5240d589ca65ef950054734db97ae1 (diff)
add memory efficient backward
Diffstat (limited to 'tests')
-rw-r--r--  tests/test_modules.py  24
1 file changed, 17 insertions(+), 7 deletions(-)
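The diff below threads a new memory_efficient_backward flag through the test helpers and parametrizes the no-fp16-weights test over it. A minimal usage sketch of the flag on bnb.nn.Linear8bitLt: the flag name and its place in the constructor are taken from this diff; threshold=6.0, the tensor shapes, and the memory-saving semantics (backward running against the int8 weights without an fp16 copy) are my reading of the commit message, not confirmed by the diff itself.

    import torch
    import bitsandbytes as bnb

    # Int8 layer with quantized weights (has_fp16_weights=False); the
    # memory_efficient_backward flag is the option this commit adds.
    linear = (
        bnb.nn.Linear8bitLt(
            32, 64,
            has_fp16_weights=False,
            memory_efficient_backward=True,  # new flag from this commit
            threshold=6.0,  # illustrative outlier threshold
        )
        .cuda()
        .half()
    )

    x = torch.randn(4, 32, dtype=torch.float16, device="cuda", requires_grad=True)
    out = linear(x)
    out.sum().backward()  # gradient should still reach the fp16 input
    assert x.grad is not None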
diff --git a/tests/test_modules.py b/tests/test_modules.py
index c0b3311..53a675f 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -14,13 +14,15 @@ class MockArgs(object):
class MLP8bit(torch.nn.Module):
- def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
+ def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
super(MLP8bit, self).__init__()
self.fc1 = bnb.nn.Linear8bitLt(
- dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
+ dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
+ threshold=threshold
)
self.fc2 = bnb.nn.Linear8bitLt(
- dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
+ dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
+ threshold=threshold
)
def forward(self, x):
@@ -451,9 +453,12 @@ names = ["threshold_{0}".format(vals) for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names)
-def test_linear8bitlt_no_fp16_weights(threshold):
+@pytest.mark.parametrize("memory_efficient_backward", [True, False])
+def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
l1 = (
- bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
+ bnb.nn.Linear8bitLt(
+ 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+ )
.cuda()
.half()
)
@@ -513,7 +518,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc2.weight.dtype == torch.int8
mlp = (
- MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+ MLP8bit(
+ 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+ )
.half()
.to("cuda")
)
@@ -532,7 +539,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc2.weight.device.type == "cuda"
mlp = (
- MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+ MLP8bit(
+ 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+ )
.to(torch.float16)
.to("cuda")
)
@@ -551,6 +560,7 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc2.weight.device.type == "cuda"
+
def test_linear8bitlt_fp32_bias():
# casts model to fp16 -> int8 automatically
l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
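A note on the test plumbing above: stacking two pytest.mark.parametrize decorators runs the cross product of both parameter sets, so every threshold value is exercised both with and without the memory-efficient backward. A standalone illustration of that mechanism, with illustrative values and test name (only the decorator pattern is taken from this diff):

    import pytest

    values = [0.0, 3.0]
    names = ["threshold_{0}".format(vals) for vals in values]

    @pytest.mark.parametrize("threshold", values, ids=names)
    @pytest.mark.parametrize("memory_efficient_backward", [True, False])
    def test_cross_product(threshold, memory_efficient_backward):
        # pytest collects 4 cases here: 2 thresholds x 2 backward modes.
        assert threshold in values
        assert isinstance(memory_efficient_backward, bool)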