summaryrefslogtreecommitdiff
path: root/bitsandbytes/nn
diff options
context:
space:
mode:
authorTim Dettmers <tim.dettmers@gmail.com>2021-11-29 09:32:13 -0800
committerTim Dettmers <tim.dettmers@gmail.com>2021-11-29 09:32:13 -0800
commit20e1677dfdc4495038fd780807c8cbc253adf921 (patch)
tree42011169e55eab3f4226ff171d84edac84ec6f8f /bitsandbytes/nn
parent3cff6795fb70dd99b4802593f3c70d291e0cd1dc (diff)
Added module override, bnb.nn.Embedding #13 #15 #19
Diffstat (limited to 'bitsandbytes/nn')
-rw-r--r--bitsandbytes/nn/__init__.py2
-rw-r--r--bitsandbytes/nn/modules.py33
2 files changed, 32 insertions, 3 deletions
diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py
index 177540f..27ad6ca 100644
--- a/bitsandbytes/nn/__init__.py
+++ b/bitsandbytes/nn/__init__.py
@@ -2,4 +2,4 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from .modules import StableEmbedding
+from .modules import StableEmbedding, Embedding
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index ce2f3a4..dc0a171 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -18,8 +18,7 @@ class StableEmbedding(torch.nn.Embedding):
sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight)
self.norm = torch.nn.LayerNorm(embedding_dim)
- GlobalOptimManager.get_instance().register_parameters(self.weight)
- GlobalOptimManager.get_instance().override_config(self.weight, 'optim_bits', 32)
+ GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32})
def reset_parameters(self) -> None:
torch.nn.init.xavier_uniform_(self.weight)
@@ -42,3 +41,33 @@ class StableEmbedding(torch.nn.Embedding):
self.norm_type, self.scale_grad_by_freq, self.sparse)
return self.norm(emb)
+
+
+class Embedding(torch.nn.Embedding):
+ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
+ max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
+ sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
+ super(Embedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight)
+ GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32})
+
+ def reset_parameters(self) -> None:
+ torch.nn.init.xavier_uniform_(self.weight)
+ self._fill_padding_idx_with_zero()
+
+ ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
+ to make the Layer compatible with Pytorch < 1.9.
+ This means that if this changes in future PyTorch releases this need to change too
+ which is cumbersome. However, with this we can ensure compatibility with previous
+ PyTorch releases.
+ '''
+ def _fill_padding_idx_with_zero(self) -> None:
+ if self.padding_idx is not None:
+ with torch.no_grad():
+ self.weight[self.padding_idx].fill_(0)
+
+ def forward(self, input: Tensor) -> Tensor:
+ emb = F.embedding(
+ input, self.weight, self.padding_idx, self.max_norm,
+ self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+ return emb