diff options
Diffstat (limited to 'bitsandbytes/nn/modules.py')
-rw-r--r-- | bitsandbytes/nn/modules.py | 197 |
1 files changed, 147 insertions, 50 deletions
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 5013d0b..9ce3ac8 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -1,39 +1,59 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import torch -import bitsandbytes as bnb +from typing import (Any, Callable, Dict, Iterator, Mapping, Optional, Set, + Tuple, TypeVar, Union, overload) -from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict - -from torch import Tensor, device, dtype -from torch import nn -from torch.nn.parameter import Parameter +import torch import torch.nn.functional as F +from torch import Tensor, device, dtype, nn +from torch.nn.parameter import Parameter +import bitsandbytes as bnb from bitsandbytes.optim import GlobalOptimManager -T = TypeVar('T', bound='torch.nn.Module') +T = TypeVar("T", bound="torch.nn.Module") + class StableEmbedding(torch.nn.Embedding): - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False, - sparse: bool = False, _weight: Optional[Tensor] = None) -> None: - super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight) + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Optional[Tensor] = None, + ) -> None: + super(StableEmbedding, self).__init__( + num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + _weight, + ) self.norm = torch.nn.LayerNorm(embedding_dim) - GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32}) + GlobalOptimManager.get_instance().register_module_override( + self, "weight", {"optim_bits": 32} + ) def reset_parameters(self) -> None: torch.nn.init.xavier_uniform_(self.weight) self._fill_padding_idx_with_zero() - ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding + """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding to make the Layer compatible with Pytorch < 1.9. This means that if this changes in future PyTorch releases this need to change too which is cumbersome. However, with this we can ensure compatibility with previous PyTorch releases. - ''' + """ + def _fill_padding_idx_with_zero(self) -> None: if self.padding_idx is not None: with torch.no_grad(): @@ -41,29 +61,55 @@ class StableEmbedding(torch.nn.Embedding): def forward(self, input: Tensor) -> Tensor: emb = F.embedding( - input, self.weight, self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.sparse) + input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) return self.norm(emb) class Embedding(torch.nn.Embedding): - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False, - sparse: bool = False, _weight: Optional[Tensor] = None) -> None: - super(Embedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight) - GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32}) + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Optional[Tensor] = None, + ) -> None: + super(Embedding, self).__init__( + num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + _weight, + ) + GlobalOptimManager.get_instance().register_module_override( + self, "weight", {"optim_bits": 32} + ) def reset_parameters(self) -> None: torch.nn.init.xavier_uniform_(self.weight) self._fill_padding_idx_with_zero() - ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding + """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding to make the Layer compatible with Pytorch < 1.9. This means that if this changes in future PyTorch releases this need to change too which is cumbersome. However, with this we can ensure compatibility with previous PyTorch releases. - ''' + """ + def _fill_padding_idx_with_zero(self) -> None: if self.padding_idx is not None: with torch.no_grad(): @@ -71,13 +117,22 @@ class Embedding(torch.nn.Embedding): def forward(self, input: Tensor) -> Tensor: emb = F.embedding( - input, self.weight, self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.sparse) + input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) return emb + class Int8Params(torch.nn.Parameter): - def __new__(cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None): + def __new__( + cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None + ): cls.has_fp16_weights = has_fp16_weights cls.CB = None cls.SCB = None @@ -96,14 +151,18 @@ class Int8Params(torch.nn.Parameter): del CBt del SCBt self.data = CB - setattr(self, 'CB', CB) - setattr(self, 'SCB', SCB) + setattr(self, "CB", CB) + setattr(self, "SCB", SCB) return self @overload - def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., - non_blocking: bool = ...) -> T: + def to( + self: T, + device: Optional[Union[int, device]] = ..., + dtype: Optional[Union[dtype, str]] = ..., + non_blocking: bool = ..., + ) -> T: ... @overload @@ -115,23 +174,41 @@ class Int8Params(torch.nn.Parameter): ... def to(self, *args, **kwargs): - device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - - if device is not None and device.type == 'cuda' and self.data.device.type == 'cpu': return self.cuda(device) + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to( + *args, **kwargs + ) + + if ( + device is not None + and device.type == "cuda" + and self.data.device.type == "cpu" + ): + return self.cuda(device) else: - new_param = Int8Params(super().to(device=device, dtype=dtype, non_blocking=non_blocking), requires_grad=self.requires_grad, has_fp16_weights=self.has_fp16_weights) + new_param = Int8Params( + super().to(device=device, dtype=dtype, non_blocking=non_blocking), + requires_grad=self.requires_grad, + has_fp16_weights=self.has_fp16_weights, + ) new_param.CB = self.CB new_param.SCB = self.SCB return new_param - class Linear8bitLt(nn.Linear): - def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, threshold=0.0, index=None): + def __init__( + self, + input_features, + output_features, + bias=True, + has_fp16_weights=True, + threshold=0.0, + index=None, + ): super(Linear8bitLt, self).__init__(input_features, output_features, bias) self.state = bnb.MatmulLtState() - self.index=index + self.index = index self.state.threshold = threshold self.state.has_fp16_weights = has_fp16_weights @@ -149,9 +226,10 @@ class Linear8bitLt(nn.Linear): def forward(self, x): self.state.is_training = self.training - if self.weight.CB is not None: self.init_8bit_state() - #assert not self.state.has_fp16_weights - #if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None + if self.weight.CB is not None: + self.init_8bit_state() + # assert not self.state.has_fp16_weights + # if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None out = bnb.matmul(x, self.weight, state=self.state) @@ -166,8 +244,18 @@ class Linear8bitLt(nn.Linear): return out + class Linear8bit(nn.Linear): - def __init__(self, input_features, output_features, bias=True, quant_type='vector', index=None, args=None, sparse_decomp=False): + def __init__( + self, + input_features, + output_features, + bias=True, + quant_type="vector", + index=None, + args=None, + sparse_decomp=False, + ): super(Linear8bit, self).__init__(input_features, output_features, bias) self.quant_type = quant_type self.index = index @@ -178,15 +266,24 @@ class Linear8bit(nn.Linear): self.iter += 1 if self.iter % self.args.clip_freq == 0: with torch.no_grad(): - maxval, maxidx = torch.topk(torch.abs(self.weight.flatten()), k=self.args.clip_idx) + maxval, maxidx = torch.topk( + torch.abs(self.weight.flatten()), k=self.args.clip_idx + ) if not dist.is_initialized() or dist.get_rank() == 0: - print('clip', maxval[-1].item()) + print("clip", maxval[-1].item()) self.weight.clip_(-maxval[-1], maxval[-1]) - if self.args is not None: - out = bnb.nn.functional.sparse_decomposed_linear8bit(x, self.weight, self.bias, qval=self.args.sparse_decomp_val, quant_type=self.args.quant_type) + out = bnb.nn.functional.sparse_decomposed_linear8bit( + x, + self.weight, + self.bias, + qval=self.args.sparse_decomp_val, + quant_type=self.args.quant_type, + ) else: - out = bnb.nn.functional.linear8bit(x, self.weight, self.bias, quant_type=self.args.quant_type) + out = bnb.nn.functional.linear8bit( + x, self.weight, self.bias, quant_type=self.args.quant_type + ) return out |