summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2021-11-29 09:32:13 -0800
committerTim Dettmers <tim.dettmers@gmail.com>2021-11-29 09:32:13 -0800
commit20e1677dfdc4495038fd780807c8cbc253adf921 (patch)
tree42011169e55eab3f4226ff171d84edac84ec6f8f /tests
parent3cff6795fb70dd99b4802593f3c70d291e0cd1dc (diff)
Added module override, bnb.nn.Embedding #13 #15 #19
Diffstat (limited to 'tests')
-rw-r--r--tests/test_modules.py46
1 file changed, 46 insertions, 0 deletions
diff --git a/tests/test_modules.py b/tests/test_modules.py
new file mode 100644
index 0000000..6cbee7b
--- /dev/null
+++ b/tests/test_modules.py
@@ -0,0 +1,46 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import pytest
+import torch
+import bitsandbytes as bnb
+
+from itertools import product
+
+from bitsandbytes import functional as F
+
+
@pytest.mark.parametrize("embcls", [bnb.nn.Embedding, bnb.nn.StableEmbedding], ids=['Embedding', 'StableEmbedding'])
def test_embeddings(embcls):
    """Verify bnb embedding modules keep full-precision optimizer state under Adam8bit.

    A plain ``torch.nn.Embedding`` trained with ``Adam8bit`` ends up with
    quantized (uint8) optimizer state, while the bnb embedding classes are
    expected to be overridden so their weight's state stays float32.
    """
    # Reset the global optimizer-override registry so state from earlier
    # tests cannot leak into this one.
    bnb.optim.GlobalOptimManager.get_instance().initialize()
    emb_ref = torch.nn.Embedding(100, 512).cuda()
    emb_bnb = embcls(100, 512).cuda()

    opt_ref = bnb.optim.Adam8bit(emb_ref.parameters())
    opt_bnb = bnb.optim.Adam8bit(emb_bnb.parameters())

    # 100 batches of token indices of shape (4, 32), drawn from [1, 100)
    # so every index is valid for the 100-row embedding tables.
    all_batches = torch.randint(1, 100, size=(100, 4, 32)).cuda()

    for batch in all_batches:
        loss_ref = emb_ref(batch).mean()
        loss_bnb = emb_bnb(batch).mean()

        loss_ref.backward()
        loss_bnb.backward()

        opt_ref.step()
        opt_bnb.step()

        opt_ref.zero_grad()
        opt_bnb.zero_grad()

    # Reference module: Adam8bit quantizes its first-moment state to uint8.
    # bnb module: the override keeps its state in float32.
    assert opt_ref.state[emb_ref.weight]['state1'].dtype == torch.uint8
    assert opt_bnb.state[emb_bnb.weight]['state1'].dtype == torch.float32
+
+