path: root/bitsandbytes/nn
author    Tim Dettmers <tim.dettmers@gmail.com>  2021-10-05 19:16:20 -0700
committer Tim Dettmers <tim.dettmers@gmail.com>  2021-10-05 19:16:20 -0700
commit    7439924891496025edf60c9da6a782f362a50c70 (patch)
tree      90476984d2c267f89232577a2ea40eb172387475 /bitsandbytes/nn
Initial commit
Diffstat (limited to 'bitsandbytes/nn')
-rw-r--r--  bitsandbytes/nn/__init__.py    5
-rw-r--r--  bitsandbytes/nn/modules.py    44
2 files changed, 49 insertions, 0 deletions
diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py
new file mode 100644
index 0000000..177540f
--- /dev/null
+++ b/bitsandbytes/nn/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .modules import StableEmbedding
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
new file mode 100644
index 0000000..bf0945c
--- /dev/null
+++ b/bitsandbytes/nn/modules.py
@@ -0,0 +1,44 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+
+from typing import Optional
+
+from torch import Tensor
+from torch.nn.parameter import Parameter
+import torch.nn.functional as F
+
+from bitsandbytes.optim import GlobalOptimManager
+
+class StableEmbedding(torch.nn.Embedding):
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
+                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
+                 sparse: bool = True, _weight: Optional[Tensor] = None) -> None:
+        super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, False, _weight)
+        self.norm = torch.nn.LayerNorm(embedding_dim)
+        GlobalOptimManager.get_instance().register_parameters(self.weight)
+        GlobalOptimManager.get_instance().override_config(self.weight, 'optim_bits', 32)
+
+    def reset_parameters(self) -> None:
+        torch.nn.init.xavier_uniform_(self.weight)
+        self._fill_padding_idx_with_zero()
+
+    ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
+        to make the layer compatible with PyTorch < 1.9.
+        This means that if this changes in future PyTorch releases, this needs to change too,
+        which is cumbersome. However, with this we can ensure compatibility with previous
+        PyTorch releases.
+    '''
+    def _fill_padding_idx_with_zero(self) -> None:
+        if self.padding_idx is not None:
+            with torch.no_grad():
+                self.weight[self.padding_idx].fill_(0)
+
+    def forward(self, input: Tensor) -> Tensor:
+        emb = F.embedding(
+            input, self.weight, self.padding_idx, self.max_norm,
+            self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+        return self.norm(emb)
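For context, a minimal usage sketch of the new StableEmbedding layer (not part of the commit above). It drops in for torch.nn.Embedding, layer-normalizes the looked-up vectors, and registers its weight with GlobalOptimManager so a bitsandbytes optimizer keeps 32-bit state for it (the optim_bits=32 override in __init__). The Adam8bit optimizer class used below is an assumption about the rest of the library and is not shown in this diff.

# Minimal usage sketch, assuming bitsandbytes is installed.
# bitsandbytes.optim.Adam8bit is an assumed class name; it is not part of this diff.
import torch

from bitsandbytes.nn import StableEmbedding
import bitsandbytes.optim as bnb_optim

# sparse=False keeps gradients dense; the layer's default is sparse=True,
# which requires an optimizer that supports sparse gradients.
emb = StableEmbedding(num_embeddings=10_000, embedding_dim=128, sparse=False)

# Because the constructor overrode optim_bits to 32 for emb.weight, a
# bitsandbytes optimizer keeps full-precision state for the embedding even
# when other parameters use 8-bit optimizer state.
optimizer = bnb_optim.Adam8bit(emb.parameters(), lr=1e-3)  # assumed optimizer class

token_ids = torch.randint(0, 10_000, (4, 16))  # (batch, seq_len) of token ids
out = emb(token_ids)                           # (4, 16, 128), layer-normalized
out.sum().backward()
optimizer.step()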