From 20e1677dfdc4495038fd780807c8cbc253adf921 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Mon, 29 Nov 2021 09:32:13 -0800 Subject: Added module override, bnb.nn.Embedding #13 #15 #19 --- bitsandbytes/nn/__init__.py | 2 +- bitsandbytes/nn/modules.py | 33 +++++++++++++++++++++++++++++++-- bitsandbytes/optim/optimizer.py | 31 ++++++++++++++++++++++++++++--- 3 files changed, 60 insertions(+), 6 deletions(-) (limited to 'bitsandbytes') diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py index 177540f..27ad6ca 100644 --- a/bitsandbytes/nn/__init__.py +++ b/bitsandbytes/nn/__init__.py @@ -2,4 +2,4 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .modules import StableEmbedding +from .modules import StableEmbedding, Embedding diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index ce2f3a4..dc0a171 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -18,8 +18,7 @@ class StableEmbedding(torch.nn.Embedding): sparse: bool = False, _weight: Optional[Tensor] = None) -> None: super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight) self.norm = torch.nn.LayerNorm(embedding_dim) - GlobalOptimManager.get_instance().register_parameters(self.weight) - GlobalOptimManager.get_instance().override_config(self.weight, 'optim_bits', 32) + GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32}) def reset_parameters(self) -> None: torch.nn.init.xavier_uniform_(self.weight) @@ -42,3 +41,33 @@ class StableEmbedding(torch.nn.Embedding): self.norm_type, self.scale_grad_by_freq, self.sparse) return self.norm(emb) + + +class Embedding(torch.nn.Embedding): + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False, + sparse: bool = False, _weight: Optional[Tensor] = None) -> None: + super(Embedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight) + GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32}) + + def reset_parameters(self) -> None: + torch.nn.init.xavier_uniform_(self.weight) + self._fill_padding_idx_with_zero() + + ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding + to make the Layer compatible with Pytorch < 1.9. + This means that if this changes in future PyTorch releases this need to change too + which is cumbersome. However, with this we can ensure compatibility with previous + PyTorch releases. + ''' + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + emb = F.embedding( + input, self.weight, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse) + + return emb diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index cfbd72e..5a5bb1e 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -26,6 +26,7 @@ class GlobalOptimManager(object): self.index2config = {} self.optimizer = None self.uses_config_override = False + self.module_weight_config_triple = [] @classmethod def get_instance(cls): @@ -77,12 +78,16 @@ class GlobalOptimManager(object): if id(p) in self.pid2config:self.pid2config[id(p)].update(key_value_dict) else: self.pid2config[id(p)] = key_value_dict + def register_module_override(self, module, param_name, config): + self.module_weight_config_triple.append((module, param_name, config)) + + class Optimizer8bit(torch.optim.Optimizer): def __init__(self, params, defaults, optim_bits=32): super(Optimizer8bit, self).__init__(params, defaults) - self.checked_if_on_gpu = False + self.initialized = False self.name2qmap = {} self.mng = GlobalOptimManager.get_instance() @@ -172,7 +177,6 @@ class Optimizer8bit(torch.optim.Optimizer): self.__setstate__({'state': state, 'param_groups': param_groups}) def to_gpu(self): - self.checked_if_on_gpu = True for gindex, group in enumerate(self.param_groups): for pindex, p in enumerate(group['params']): if p in self.state: @@ -181,6 +185,23 @@ class Optimizer8bit(torch.optim.Optimizer): if isinstance(v, torch.Tensor): self.state[p][k] = v.to(p.device) + def check_overrides(self): + for module, attr, config in self.mng.module_weight_config_triple: + pmodule = getattr(module, attr) + assert pmodule is not None + assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter) + found = False + for gindex, group in enumerate(self.param_groups): + if found: break + for pindex, p in enumerate(group['params']): + if found: break + if id(p) == id(pmodule): + # found the matching parameter + # init override + self.mng.pid2config[id(p)] = config + self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)] + found = True + @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. @@ -196,7 +217,11 @@ class Optimizer8bit(torch.optim.Optimizer): overflows = [] - if not self.checked_if_on_gpu: self.to_gpu() # needed for fairseq pure fp16 training + if not self.initialized: + self.check_overrides() + self.to_gpu() # needed for fairseq pure fp16 training + self.initialized = True + for gindex, group in enumerate(self.param_groups): for pindex, p in enumerate(group['params']): if p.grad is None: -- cgit v1.2.3