From 20e1677dfdc4495038fd780807c8cbc253adf921 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Mon, 29 Nov 2021 09:32:13 -0800 Subject: Added module override, bnb.nn.Embedding #13 #15 #19 --- tests/test_modules.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 tests/test_modules.py (limited to 'tests') diff --git a/tests/test_modules.py b/tests/test_modules.py new file mode 100644 index 0000000..6cbee7b --- /dev/null +++ b/tests/test_modules.py @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import pytest +import torch +import bitsandbytes as bnb + +from itertools import product + +from bitsandbytes import functional as F + + +@pytest.mark.parametrize("embcls", [bnb.nn.Embedding, bnb.nn.StableEmbedding], ids=['Embedding', 'StableEmbedding']) +def test_embeddings(embcls): + bnb.optim.GlobalOptimManager.get_instance().initialize() + emb1 = torch.nn.Embedding(100, 512).cuda() + emb2 = embcls(100, 512).cuda() + + adam1 = bnb.optim.Adam8bit(emb1.parameters()) + adam2 = bnb.optim.Adam8bit(emb2.parameters()) + + batches = torch.randint(1, 100, size=(100, 4, 32)).cuda() + + for i in range(100): + batch = batches[i] + + embedded1 = emb1(batch) + embedded2 = emb2(batch) + + l1 = embedded1.mean() + l2 = embedded2.mean() + + l1.backward() + l2.backward() + + adam1.step() + adam2.step() + + adam1.zero_grad() + adam2.zero_grad() + + assert adam1.state[emb1.weight]['state1'].dtype == torch.uint8 + assert adam2.state[emb2.weight]['state1'].dtype == torch.float32 + + -- cgit v1.2.3