author     Tim Dettmers <tim.dettmers@gmail.com>    2021-10-05 19:16:20 -0700
committer  Tim Dettmers <tim.dettmers@gmail.com>    2021-10-05 19:16:20 -0700
commit     7439924891496025edf60c9da6a782f362a50c70 (patch)
tree       90476984d2c267f89232577a2ea40eb172387475 /bitsandbytes/nn
Initial commit
Diffstat (limited to 'bitsandbytes/nn')
-rw-r--r--  bitsandbytes/nn/__init__.py |  5
-rw-r--r--  bitsandbytes/nn/modules.py  | 44
2 files changed, 49 insertions, 0 deletions
diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py
new file mode 100644
index 0000000..177540f
--- /dev/null
+++ b/bitsandbytes/nn/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .modules import StableEmbedding
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
new file mode 100644
index 0000000..bf0945c
--- /dev/null
+++ b/bitsandbytes/nn/modules.py
@@ -0,0 +1,44 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+
+from typing import Optional
+
+from torch import Tensor
+from torch.nn.parameter import Parameter
+import torch.nn.functional as F
+
+from bitsandbytes.optim import GlobalOptimManager
+
+class StableEmbedding(torch.nn.Embedding):
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
+                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
+                 sparse: bool = True, _weight: Optional[Tensor] = None) -> None:
+        super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, False, _weight)
+        self.norm = torch.nn.LayerNorm(embedding_dim)
+        GlobalOptimManager.get_instance().register_parameters(self.weight)
+        GlobalOptimManager.get_instance().override_config(self.weight, 'optim_bits', 32)
+
+    def reset_parameters(self) -> None:
+        torch.nn.init.xavier_uniform_(self.weight)
+        self._fill_padding_idx_with_zero()
+
+    ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
+        to make the Layer compatible with Pytorch < 1.9.
+        This means that if this changes in future PyTorch releases this need to change too
+        which is cumbersome. However, with this we can ensure compatibility with previous
+        PyTorch releases.
+    '''
+    def _fill_padding_idx_with_zero(self) -> None:
+        if self.padding_idx is not None:
+            with torch.no_grad():
+                self.weight[self.padding_idx].fill_(0)
+
+    def forward(self, input: Tensor) -> Tensor:
+        emb = F.embedding(
+            input, self.weight, self.padding_idx, self.max_norm,
+            self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+        return self.norm(emb)
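
A minimal usage sketch of the module added by this commit. The import path follows the new __init__.py above; the tensor sizes and the print statement are illustrative assumptions, not part of the commit. Note that the constructor registers the embedding weight with GlobalOptimManager and overrides 'optim_bits' to 32, so when the model is later trained with a bitsandbytes 8-bit optimizer this parameter's optimizer state is kept in 32 bit.

    # Illustrative sketch; sizes are hypothetical.
    import torch
    from bitsandbytes.nn import StableEmbedding

    emb = StableEmbedding(num_embeddings=1000, embedding_dim=64)

    # Token ids of shape (batch, sequence); the module looks up embeddings
    # with F.embedding and then applies the LayerNorm created in __init__.
    tokens = torch.randint(0, 1000, (8, 16))
    out = emb(tokens)
    print(out.shape)  # torch.Size([8, 16, 64])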