author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-08-17 03:45:57 -0700
---|---|---
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-08-17 03:45:57 -0700
commit | 9d60b3c5279641ba936facd710c722ebe52fcf40 (patch) |
tree | afc62a7ae35224acba5a03150623ef5a82830599 /tests |
parent | b00cc9137fce41359318741106df92747aa14796 (diff) |
Fixed a bug in Linear8bitLt when the bias is None.
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_modules.py | 23 |
1 file changed, 23 insertions, 0 deletions
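This view is limited to tests/, so the fix itself (inside Linear8bitLt's forward pass) does not appear in the diff below. As a minimal sketch, assuming the layer lazily casts its fp32 bias to fp16 on the first forward pass, the guard presumably looks like the hypothetical helper below; the function name and standalone shape are illustrative, not the committed code:

```python
import torch
from typing import Optional

def maybe_cast_bias_to_half(bias: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Hypothetical helper: in the library the equivalent check lives inside
    # Linear8bitLt.forward. Before the fix, the cast presumably ran
    # unconditionally, so a layer created with bias=False would crash as soon
    # as it touched `bias.dtype` on a None bias.
    if bias is not None and bias.dtype != torch.float16:
        bias.data = bias.data.half()  # lazily match the fp16 activations
    return bias

# Quick check of both paths the new test exercises:
assert maybe_cast_bias_to_half(None) is None
assert maybe_cast_bias_to_half(torch.zeros(64)).dtype == torch.float16
```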
```diff
diff --git a/tests/test_modules.py b/tests/test_modules.py
index 7faadb8..c0b3311 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -549,3 +549,26 @@ def test_linear8bitlt_no_fp16_weights(threshold):
     assert mlp.fc2.weight.dtype == torch.int8
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"
+
+
+def test_linear8bitlt_fp32_bias():
+    # casts model to fp16 -> int8 automatically
+    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
+    assert l1.weight.dtype == torch.int8
+    assert l1.bias.dtype == torch.float32
+
+    for i in range(100):
+        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        # the forward pass casts the fp32 bias to fp16
+        o1 = l1(b1)
+        assert l1.bias.dtype == torch.float16
+
+    # casts model to fp16 -> int8 automatically
+    l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, bias=False).cuda()
+    assert l1.weight.dtype == torch.int8
+    assert l1.bias is None
+
+    for i in range(100):
+        b1 = torch.randn(16, 8, 32, device="cuda").half()
+        o1 = l1(b1)
+        assert l1.bias is None
```
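Read top to bottom, the asserts pin down the intended lifecycle: the weights are quantized to int8 eagerly when the module is moved to the GPU, while the bias is cast lazily on the first forward pass. A minimal usage sketch of that behavior, assuming a CUDA device and an installed bitsandbytes (`bnb`):

```python
import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
print(layer.weight.dtype)  # torch.int8: weights are quantized on .cuda()
print(layer.bias.dtype)    # torch.float32: bias is left alone until forward

x = torch.randn(16, 8, 32, device="cuda").half()
out = layer(x)             # the first forward pass casts the bias ...
print(layer.bias.dtype)    # ... so this now prints torch.float16
```

With `bias=False`, `layer.bias` stays `None` through the same sequence, which is exactly the case this commit fixes.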