diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2021-11-29 09:32:13 -0800 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2021-11-29 09:32:13 -0800 |
commit | 20e1677dfdc4495038fd780807c8cbc253adf921 (patch) | |
tree | 42011169e55eab3f4226ff171d84edac84ec6f8f /tests | |
parent | 3cff6795fb70dd99b4802593f3c70d291e0cd1dc (diff) |
Added module override, bnb.nn.Embedding #13 #15 #19
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_modules.py | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/tests/test_modules.py b/tests/test_modules.py new file mode 100644 index 0000000..6cbee7b --- /dev/null +++ b/tests/test_modules.py @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import pytest +import torch +import bitsandbytes as bnb + +from itertools import product + +from bitsandbytes import functional as F + + +@pytest.mark.parametrize("embcls", [bnb.nn.Embedding, bnb.nn.StableEmbedding], ids=['Embedding', 'StableEmbedding']) +def test_embeddings(embcls): + bnb.optim.GlobalOptimManager.get_instance().initialize() + emb1 = torch.nn.Embedding(100, 512).cuda() + emb2 = embcls(100, 512).cuda() + + adam1 = bnb.optim.Adam8bit(emb1.parameters()) + adam2 = bnb.optim.Adam8bit(emb2.parameters()) + + batches = torch.randint(1, 100, size=(100, 4, 32)).cuda() + + for i in range(100): + batch = batches[i] + + embedded1 = emb1(batch) + embedded2 = emb2(batch) + + l1 = embedded1.mean() + l2 = embedded2.mean() + + l1.backward() + l2.backward() + + adam1.step() + adam2.step() + + adam1.zero_grad() + adam2.zero_grad() + + assert adam1.state[emb1.weight]['state1'].dtype == torch.uint8 + assert adam2.state[emb2.weight]['state1'].dtype == torch.float32 + + |