summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/test_modules.py11
1 files changed, 11 insertions, 0 deletions
diff --git a/tests/test_modules.py b/tests/test_modules.py
index 53a675f..d3992a9 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -554,11 +554,22 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
assert mlp.fc1.state.idx is not None
if threshold > 0:
assert mlp.fc2.state.idx is not None
+
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
assert mlp.fc1.weight.device.type == "cuda"
assert mlp.fc2.weight.device.type == "cuda"
+ if memory_efficient_backward:
+ b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
+ o1 = mlp(b1)
+ assert o1.dtype == torch.float16
+ assert o1.requires_grad
+ grad_proj = torch.randn_like(o1)
+
+ (o1 * grad_proj).sum().backward()
+
+
def test_linear8bitlt_fp32_bias():