From df86625a9399d16d6fb2e3bab6bb7bcc729f3b7d Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Mon, 24 Oct 2022 11:54:25 -0700 Subject: Isolated CUDASetup logging; all tests green. --- tests/test_modules.py | 71 --------------------------------------------------- 1 file changed, 71 deletions(-) (limited to 'tests/test_modules.py') diff --git a/tests/test_modules.py b/tests/test_modules.py index 2879846..ccbf670 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -310,77 +310,6 @@ class Linear8bit(nn.Module): return LinearFunction.apply(x, self.weight, self.bias, self.args) -def test_linear8bit(): - l0 = torch.nn.Linear(32, 64).cuda().half() - l1 = bnb.nn.Linear8bit(32, 64, args=get_args()).cuda().half() - l2 = Linear8bit(32, 64, args=get_args()).cuda().half() - l3 = bnb.nn.Linear8bitLt(32, 64).cuda().half() - - l0.weight.data = l2.weight.data.clone() - l0.bias.data = l2.bias.data.clone() - - l1.weight.data = l2.weight.data.clone() - l1.bias.data = l2.bias.data.clone() - - l3.weight.data = l2.weight.data.clone() - l3.bias.data = l2.bias.data.clone() - - for i in range(100): - b1 = torch.randn(16, 8, 32, device="cuda").half() - t = torch.randn(16, 8, 64, device="cuda").half() - b2 = b1.clone() - b3 = b1.clone() - b0 = b1.clone() - - o0 = l0(b0) - o1 = l1(b1) - o2 = l2(b2) - o3 = l3(b3) - - assert_all_approx_close(o1, o2, atol=0.013, rtol=0.05, count=1) - assert_all_approx_close(o3, o2, atol=0.013, rtol=0.05, count=1) - - loss0 = torch.nn.functional.mse_loss(o0, t) - loss1 = torch.nn.functional.mse_loss(o1, t) - loss2 = torch.nn.functional.mse_loss(o2, t) - loss3 = torch.nn.functional.mse_loss(o3, t) - - loss0.backward() - loss1.backward() - loss2.backward() - loss3.backward() - - assert_all_approx_close( - l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2 - ) - assert_all_approx_close( - l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2 - ) - assert_all_approx_close( - l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2 - ) - assert_all_approx_close( - l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2 - ) - - err1 = torch.abs(l0.weight.grad - l1.weight.grad).mean().item() - err2 = torch.abs(l0.weight.grad - l2.weight.grad).mean().item() - err3 = torch.abs(l0.weight.grad - l3.weight.grad).mean().item() - - assert err1 * 0.8 < err2 - assert err2 * 0.8 < err3 - assert err3 * 0.8 < err1 - - l0.weight.grad = None - l1.weight.grad = None - l2.weight.grad = None - l3.weight.grad = None - l0.bias.grad = None - l1.bias.grad = None - l2.bias.grad = None - l3.bias.grad = None - - threshold = [0.0, 3.0] values = threshold names = ["threshold_{0}".format(vals) for vals in values] -- cgit v1.2.3