summaryrefslogtreecommitdiff
path: root/tests/test_modules.py
blob: a0379cb5273276bbba44b68c5b2dfb4e387b707d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# Copyright (c) Facebook, Inc. and its affiliates. 
#   
# This source code is licensed under the MIT license found in the 
# LICENSE file in the root directory of this source tree.
import pytest
import torch
import bitsandbytes as bnb


@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a CUDA device")
@pytest.mark.parametrize("embcls", [bnb.nn.Embedding, bnb.nn.StableEmbedding], ids=['Embedding', 'StableEmbedding'])
def test_embeddings(embcls):
    """Check optimizer-state precision for bnb embeddings under Adam8bit.

    A plain ``torch.nn.Embedding`` and an instance of ``embcls`` are trained
    side by side with ``bnb.optim.Adam8bit`` on identical random index
    batches. After every optimizer step the test asserts that:

    * the plain torch embedding's ``state1`` is ``torch.uint8`` (8-bit state);
    * the bnb embedding's ``state1`` is ``torch.float32`` (32-bit state).

    Args:
        embcls: bitsandbytes embedding class under test (parametrized).
    """
    num_embeddings, embedding_dim = 100, 512
    num_steps, batch_shape = 100, (4, 32)

    # Reset the global optimizer manager so configuration from other tests
    # does not leak in. NOTE(review): this presumably must happen *before* the
    # bnb module is constructed, in case the module registers per-parameter
    # optimizer overrides with the manager on construction — confirm.
    bnb.optim.GlobalOptimManager.get_instance().initialize()

    emb1 = torch.nn.Embedding(num_embeddings, embedding_dim).cuda()
    emb2 = embcls(num_embeddings, embedding_dim).cuda()

    adam1 = bnb.optim.Adam8bit(emb1.parameters())
    adam2 = bnb.optim.Adam8bit(emb2.parameters())

    # Indices are drawn from [1, num_embeddings); index 0 is never used
    # (presumably to stay clear of a padding index — confirm). Tying the
    # upper bound to num_embeddings keeps every index in range.
    batches = torch.randint(1, num_embeddings, size=(num_steps, *batch_shape)).cuda()

    # Iterating a tensor yields its slices along dim 0, one batch per step.
    for batch in batches:
        emb1(batch).mean().backward()
        emb2(batch).mean().backward()

        adam1.step()
        adam2.step()

        adam1.zero_grad()
        adam2.zero_grad()

        # Optimizer state exists only after the first step, so check here.
        assert adam1.state[emb1.weight]['state1'].dtype == torch.uint8
        assert adam2.state[emb2.weight]['state1'].dtype == torch.float32