summaryrefslogtreecommitdiff
path: root/tests/test_modules.py
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2022-10-24 11:54:25 -0700
committerTim Dettmers <tim.dettmers@gmail.com>2022-10-24 11:54:25 -0700
commitdf86625a9399d16d6fb2e3bab6bb7bcc729f3b7d (patch)
tree34278a2cfd443d8e6f62aaba0f7a469db2807571 /tests/test_modules.py
parentb844e104b79ddc06161ff975aa93ffa9a7ec4801 (diff)
Isolated CUDASetup logging; all tests green.
Diffstat (limited to 'tests/test_modules.py')
-rw-r--r--tests/test_modules.py71
1 files changed, 0 insertions, 71 deletions
diff --git a/tests/test_modules.py b/tests/test_modules.py
index 2879846..ccbf670 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -310,77 +310,6 @@ class Linear8bit(nn.Module):
return LinearFunction.apply(x, self.weight, self.bias, self.args)
-def test_linear8bit():
- l0 = torch.nn.Linear(32, 64).cuda().half()
- l1 = bnb.nn.Linear8bit(32, 64, args=get_args()).cuda().half()
- l2 = Linear8bit(32, 64, args=get_args()).cuda().half()
- l3 = bnb.nn.Linear8bitLt(32, 64).cuda().half()
-
- l0.weight.data = l2.weight.data.clone()
- l0.bias.data = l2.bias.data.clone()
-
- l1.weight.data = l2.weight.data.clone()
- l1.bias.data = l2.bias.data.clone()
-
- l3.weight.data = l2.weight.data.clone()
- l3.bias.data = l2.bias.data.clone()
-
- for i in range(100):
- b1 = torch.randn(16, 8, 32, device="cuda").half()
- t = torch.randn(16, 8, 64, device="cuda").half()
- b2 = b1.clone()
- b3 = b1.clone()
- b0 = b1.clone()
-
- o0 = l0(b0)
- o1 = l1(b1)
- o2 = l2(b2)
- o3 = l3(b3)
-
- assert_all_approx_close(o1, o2, atol=0.013, rtol=0.05, count=1)
- assert_all_approx_close(o3, o2, atol=0.013, rtol=0.05, count=1)
-
- loss0 = torch.nn.functional.mse_loss(o0, t)
- loss1 = torch.nn.functional.mse_loss(o1, t)
- loss2 = torch.nn.functional.mse_loss(o2, t)
- loss3 = torch.nn.functional.mse_loss(o3, t)
-
- loss0.backward()
- loss1.backward()
- loss2.backward()
- loss3.backward()
-
- assert_all_approx_close(
- l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2
- )
- assert_all_approx_close(
- l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2
- )
- assert_all_approx_close(
- l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
- )
- assert_all_approx_close(
- l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
- )
-
- err1 = torch.abs(l0.weight.grad - l1.weight.grad).mean().item()
- err2 = torch.abs(l0.weight.grad - l2.weight.grad).mean().item()
- err3 = torch.abs(l0.weight.grad - l3.weight.grad).mean().item()
-
- assert err1 * 0.8 < err2
- assert err2 * 0.8 < err3
- assert err3 * 0.8 < err1
-
- l0.weight.grad = None
- l1.weight.grad = None
- l2.weight.grad = None
- l3.weight.grad = None
- l0.bias.grad = None
- l1.bias.grad = None
- l2.bias.grad = None
- l3.bias.grad = None
-
-
threshold = [0.0, 3.0]
values = threshold
names = ["threshold_{0}".format(vals) for vals in values]