summaryrefslogtreecommitdiff
path: root/bitsandbytes/nn
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes/nn')
-rw-r--r--bitsandbytes/nn/__init__.py8
-rw-r--r--bitsandbytes/nn/modules.py197
2 files changed, 151 insertions, 54 deletions
diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py
index 03b4655..98d4aa0 100644
--- a/bitsandbytes/nn/__init__.py
+++ b/bitsandbytes/nn/__init__.py
@@ -1,5 +1,5 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from .modules import StableEmbedding, Linear8bit, Linear8bitLt, Int8Params
+from .modules import Int8Params, Linear8bit, Linear8bitLt, StableEmbedding
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 5013d0b..9ce3ac8 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -1,39 +1,59 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-import torch
-import bitsandbytes as bnb
+from typing import (Any, Callable, Dict, Iterator, Mapping, Optional, Set,
+ Tuple, TypeVar, Union, overload)
-from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict
-
-from torch import Tensor, device, dtype
-from torch import nn
-from torch.nn.parameter import Parameter
+import torch
import torch.nn.functional as F
+from torch import Tensor, device, dtype, nn
+from torch.nn.parameter import Parameter
+import bitsandbytes as bnb
from bitsandbytes.optim import GlobalOptimManager
-T = TypeVar('T', bound='torch.nn.Module')
+T = TypeVar("T", bound="torch.nn.Module")
+
class StableEmbedding(torch.nn.Embedding):
- def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
- max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
- sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
- super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight)
+ def __init__(
+ self,
+ num_embeddings: int,
+ embedding_dim: int,
+ padding_idx: Optional[int] = None,
+ max_norm: Optional[float] = None,
+ norm_type: float = 2.0,
+ scale_grad_by_freq: bool = False,
+ sparse: bool = False,
+ _weight: Optional[Tensor] = None,
+ ) -> None:
+ super(StableEmbedding, self).__init__(
+ num_embeddings,
+ embedding_dim,
+ padding_idx,
+ max_norm,
+ norm_type,
+ scale_grad_by_freq,
+ sparse,
+ _weight,
+ )
self.norm = torch.nn.LayerNorm(embedding_dim)
- GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32})
+ GlobalOptimManager.get_instance().register_module_override(
+ self, "weight", {"optim_bits": 32}
+ )
def reset_parameters(self) -> None:
torch.nn.init.xavier_uniform_(self.weight)
self._fill_padding_idx_with_zero()
- ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
+ """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
to make the Layer compatible with Pytorch < 1.9.
This means that if this changes in future PyTorch releases this need to change too
which is cumbersome. However, with this we can ensure compatibility with previous
PyTorch releases.
- '''
+ """
+
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
@@ -41,29 +61,55 @@ class StableEmbedding(torch.nn.Embedding):
def forward(self, input: Tensor) -> Tensor:
emb = F.embedding(
- input, self.weight, self.padding_idx, self.max_norm,
- self.norm_type, self.scale_grad_by_freq, self.sparse)
+ input,
+ self.weight,
+ self.padding_idx,
+ self.max_norm,
+ self.norm_type,
+ self.scale_grad_by_freq,
+ self.sparse,
+ )
return self.norm(emb)
class Embedding(torch.nn.Embedding):
- def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
- max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
- sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
- super(Embedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight)
- GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32})
+ def __init__(
+ self,
+ num_embeddings: int,
+ embedding_dim: int,
+ padding_idx: Optional[int] = None,
+ max_norm: Optional[float] = None,
+ norm_type: float = 2.0,
+ scale_grad_by_freq: bool = False,
+ sparse: bool = False,
+ _weight: Optional[Tensor] = None,
+ ) -> None:
+ super(Embedding, self).__init__(
+ num_embeddings,
+ embedding_dim,
+ padding_idx,
+ max_norm,
+ norm_type,
+ scale_grad_by_freq,
+ sparse,
+ _weight,
+ )
+ GlobalOptimManager.get_instance().register_module_override(
+ self, "weight", {"optim_bits": 32}
+ )
def reset_parameters(self) -> None:
torch.nn.init.xavier_uniform_(self.weight)
self._fill_padding_idx_with_zero()
- ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
+ """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
to make the Layer compatible with Pytorch < 1.9.
This means that if this changes in future PyTorch releases this need to change too
which is cumbersome. However, with this we can ensure compatibility with previous
PyTorch releases.
- '''
+ """
+
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
@@ -71,13 +117,22 @@ class Embedding(torch.nn.Embedding):
def forward(self, input: Tensor) -> Tensor:
emb = F.embedding(
- input, self.weight, self.padding_idx, self.max_norm,
- self.norm_type, self.scale_grad_by_freq, self.sparse)
+ input,
+ self.weight,
+ self.padding_idx,
+ self.max_norm,
+ self.norm_type,
+ self.scale_grad_by_freq,
+ self.sparse,
+ )
return emb
+
class Int8Params(torch.nn.Parameter):
- def __new__(cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None):
+ def __new__(
+ cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None
+ ):
cls.has_fp16_weights = has_fp16_weights
cls.CB = None
cls.SCB = None
@@ -96,14 +151,18 @@ class Int8Params(torch.nn.Parameter):
del CBt
del SCBt
self.data = CB
- setattr(self, 'CB', CB)
- setattr(self, 'SCB', SCB)
+ setattr(self, "CB", CB)
+ setattr(self, "SCB", SCB)
return self
@overload
- def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ...,
- non_blocking: bool = ...) -> T:
+ def to(
+ self: T,
+ device: Optional[Union[int, device]] = ...,
+ dtype: Optional[Union[dtype, str]] = ...,
+ non_blocking: bool = ...,
+ ) -> T:
...
@overload
@@ -115,23 +174,41 @@ class Int8Params(torch.nn.Parameter):
...
def to(self, *args, **kwargs):
- device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
-
- if device is not None and device.type == 'cuda' and self.data.device.type == 'cpu': return self.cuda(device)
+ device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
+ *args, **kwargs
+ )
+
+ if (
+ device is not None
+ and device.type == "cuda"
+ and self.data.device.type == "cpu"
+ ):
+ return self.cuda(device)
else:
- new_param = Int8Params(super().to(device=device, dtype=dtype, non_blocking=non_blocking), requires_grad=self.requires_grad, has_fp16_weights=self.has_fp16_weights)
+ new_param = Int8Params(
+ super().to(device=device, dtype=dtype, non_blocking=non_blocking),
+ requires_grad=self.requires_grad,
+ has_fp16_weights=self.has_fp16_weights,
+ )
new_param.CB = self.CB
new_param.SCB = self.SCB
return new_param
-
class Linear8bitLt(nn.Linear):
- def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, threshold=0.0, index=None):
+ def __init__(
+ self,
+ input_features,
+ output_features,
+ bias=True,
+ has_fp16_weights=True,
+ threshold=0.0,
+ index=None,
+ ):
super(Linear8bitLt, self).__init__(input_features, output_features, bias)
self.state = bnb.MatmulLtState()
- self.index=index
+ self.index = index
self.state.threshold = threshold
self.state.has_fp16_weights = has_fp16_weights
@@ -149,9 +226,10 @@ class Linear8bitLt(nn.Linear):
def forward(self, x):
self.state.is_training = self.training
- if self.weight.CB is not None: self.init_8bit_state()
- #assert not self.state.has_fp16_weights
- #if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None
+ if self.weight.CB is not None:
+ self.init_8bit_state()
+ # assert not self.state.has_fp16_weights
+ # if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None
out = bnb.matmul(x, self.weight, state=self.state)
@@ -166,8 +244,18 @@ class Linear8bitLt(nn.Linear):
return out
+
class Linear8bit(nn.Linear):
- def __init__(self, input_features, output_features, bias=True, quant_type='vector', index=None, args=None, sparse_decomp=False):
+ def __init__(
+ self,
+ input_features,
+ output_features,
+ bias=True,
+ quant_type="vector",
+ index=None,
+ args=None,
+ sparse_decomp=False,
+ ):
super(Linear8bit, self).__init__(input_features, output_features, bias)
self.quant_type = quant_type
self.index = index
@@ -178,15 +266,24 @@ class Linear8bit(nn.Linear):
self.iter += 1
if self.iter % self.args.clip_freq == 0:
with torch.no_grad():
- maxval, maxidx = torch.topk(torch.abs(self.weight.flatten()), k=self.args.clip_idx)
+ maxval, maxidx = torch.topk(
+ torch.abs(self.weight.flatten()), k=self.args.clip_idx
+ )
if not dist.is_initialized() or dist.get_rank() == 0:
- print('clip', maxval[-1].item())
+ print("clip", maxval[-1].item())
self.weight.clip_(-maxval[-1], maxval[-1])
-
if self.args is not None:
- out = bnb.nn.functional.sparse_decomposed_linear8bit(x, self.weight, self.bias, qval=self.args.sparse_decomp_val, quant_type=self.args.quant_type)
+ out = bnb.nn.functional.sparse_decomposed_linear8bit(
+ x,
+ self.weight,
+ self.bias,
+ qval=self.args.sparse_decomp_val,
+ quant_type=self.args.quant_type,
+ )
else:
- out = bnb.nn.functional.linear8bit(x, self.weight, self.bias, quant_type=self.args.quant_type)
+ out = bnb.nn.functional.linear8bit(
+ x, self.weight, self.bias, quant_type=self.args.quant_type
+ )
return out