summaryrefslogtreecommitdiff
path: root/tests/test_modules.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_modules.py')
-rw-r--r--tests/test_modules.py46
1 files changed, 35 insertions, 11 deletions
diff --git a/tests/test_modules.py b/tests/test_modules.py
index c0b3311..2879846 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -14,13 +14,15 @@ class MockArgs(object):
class MLP8bit(torch.nn.Module):
- def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
+ def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
super(MLP8bit, self).__init__()
self.fc1 = bnb.nn.Linear8bitLt(
- dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
+ dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
+ threshold=threshold
)
self.fc2 = bnb.nn.Linear8bitLt(
- dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
+ dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
+ threshold=threshold
)
def forward(self, x):
@@ -451,9 +453,12 @@ names = ["threshold_{0}".format(vals) for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names)
-def test_linear8bitlt_no_fp16_weights(threshold):
+@pytest.mark.parametrize("memory_efficient_backward", [True, False])
+def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
l1 = (
- bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
+ bnb.nn.Linear8bitLt(
+ 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+ )
.cuda()
.half()
)
@@ -513,7 +518,9 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc2.weight.dtype == torch.int8
mlp = (
- MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+ MLP8bit(
+ 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+ )
.half()
.to("cuda")
)
@@ -531,11 +538,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc1.weight.device.type == "cuda"
assert mlp.fc2.weight.device.type == "cuda"
- mlp = (
- MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
- .to(torch.float16)
- .to("cuda")
- )
+ mlp = MLP8bit(
+ 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
+ )
+ w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda() # grab weights before quantization,
+ mlp = mlp.cuda().half() # and this line triggers quantization
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()
@@ -545,11 +552,28 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc1.state.idx is not None
if threshold > 0:
assert mlp.fc2.state.idx is not None
+
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
assert mlp.fc1.weight.device.type == "cuda"
assert mlp.fc2.weight.device.type == "cuda"
+ if memory_efficient_backward:
+ b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
+ o1 = mlp(b1)
+ assert o1.dtype == torch.float16
+ assert o1.requires_grad
+ grad_proj = torch.randn_like(o1)
+
+ mlp.zero_grad()
+ (o1 * grad_proj).sum().backward()
+ grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
+ scale = grad_ref.abs().mean()
+
+ torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
+ idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
+ assert (idx == 0).sum().item() <= b1.numel() * 0.005
+
def test_linear8bitlt_fp32_bias():
# casts model to fp16 -> int8 automatically