summaryrefslogtreecommitdiff
path: root/bitsandbytes/nn/modules.py
diff options
context:
space:
mode:
authorTim Dettmers <TimDettmers@users.noreply.github.com>2021-11-29 08:22:16 -0800
committerGitHub <noreply@github.com>2021-11-29 08:22:16 -0800
commit262350c10f4174d0c775b61be2dbf526afa69cd2 (patch)
tree3b46cf5ac603ba1342565390294fc3c296cbc916 /bitsandbytes/nn/modules.py
parent037022e878974b5dfeb354098a467a46618f9d85 (diff)
parent67a1283501fa24d346f8e8efb4fc888a9ed8d193 (diff)
Merge pull request #14 from SirRob1997/main
[FIX] passing of sparse in StableEmbedding
Diffstat (limited to 'bitsandbytes/nn/modules.py')
-rw-r--r--bitsandbytes/nn/modules.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index bf0945c..ce2f3a4 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -15,8 +15,8 @@ from bitsandbytes.optim import GlobalOptimManager
class StableEmbedding(torch.nn.Embedding):
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
- sparse: bool = True, _weight: Optional[Tensor] = None) -> None:
- super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, False, _weight)
+ sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
+ super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight)
self.norm = torch.nn.LayerNorm(embedding_dim)
GlobalOptimManager.get_instance().register_parameters(self.weight)
GlobalOptimManager.get_instance().override_config(self.weight, 'optim_bits', 32)