# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.parameter import Parameter

from bitsandbytes.optim import GlobalOptimManager


class StableEmbedding(torch.nn.Embedding):
    """Embedding layer whose output is passed through a ``LayerNorm``.

    Differences from ``torch.nn.Embedding`` that are visible in this class:

    * the weight is initialised with Xavier-uniform initialisation,
    * the looked-up embeddings are normalised with
      ``torch.nn.LayerNorm(embedding_dim)`` before being returned,
    * the weight is registered with ``GlobalOptimManager`` and its
      ``'optim_bits'`` config is overridden to ``32`` (presumably so that
      the optimizer keeps 32-bit state for this parameter -- confirm
      against ``GlobalOptimManager``).
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
    ) -> None:
        """Build the embedding table, the layer norm and the optimizer override.

        All arguments are forwarded unchanged to ``torch.nn.Embedding``.

        Args:
            num_embeddings: Number of rows in the embedding table.
            embedding_dim: Size of each embedding vector; also the
                normalised shape of the ``LayerNorm``.
            padding_idx: Optional index whose row is zero-filled on
                initialisation.
            max_norm: Forwarded to ``torch.nn.Embedding``.
            norm_type: Forwarded to ``torch.nn.Embedding``.
            scale_grad_by_freq: Forwarded to ``torch.nn.Embedding``.
            sparse: Forwarded to ``torch.nn.Embedding``. This argument used
                to be accepted but silently replaced by ``False``; it is now
                honoured. The default is ``False`` so that callers relying
                on the default keep the dense-gradient behaviour they
                always had.
            _weight: Optional pre-built weight tensor, forwarded to
                ``torch.nn.Embedding``.
        """
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
        )
        self.norm = torch.nn.LayerNorm(embedding_dim)

        # Register the embedding weight and override its optimizer config.
        optim_manager = GlobalOptimManager.get_instance()
        optim_manager.register_parameters(self.weight)
        optim_manager.override_config(self.weight, 'optim_bits', 32)

    def reset_parameters(self) -> None:
        """Xavier-uniform initialise the weight, then zero the padding row."""
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    def _fill_padding_idx_with_zero(self) -> None:
        """Zero the weight row at ``padding_idx`` (no-op when it is ``None``).

        !!! This is a redefinition of ``_fill_padding_idx_with_zero`` in
        ``torch.nn.Embedding`` to make the layer compatible with
        PyTorch < 1.9. This means that if this changes in future PyTorch
        releases, this needs to change too, which is cumbersome. However,
        with this we can ensure compatibility with previous PyTorch
        releases.
        """
        if self.padding_idx is None:
            return
        # In-place edit of a parameter must not be tracked by autograd.
        with torch.no_grad():
            self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        """Look up ``input`` in the embedding table and layer-normalise it.

        Args:
            input: Tensor of indices, passed straight to ``F.embedding``.

        Returns:
            The result of ``F.embedding`` after ``self.norm``.
        """
        emb = F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
        return self.norm(emb)