summaryrefslogtreecommitdiff
path: root/tests/test_modules.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_modules.py')
-rw-r--r--tests/test_modules.py7
1 files changed, 2 insertions, 5 deletions
diff --git a/tests/test_modules.py b/tests/test_modules.py
index c6e7f85..01c9389 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -538,14 +538,11 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
assert mlp.fc1.weight.device.type == "cuda"
assert mlp.fc2.weight.device.type == "cuda"
- mlp = (
- MLP8bit(
+ mlp = MLP8bit(
32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
)
- .to(torch.float16)
- .to("cuda")
- )
w1, w2 = mlp.fc1.weight.clone(), mlp.fc2.weight.clone()
+ mlp = mlp.cuda().half()
for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half()