author     Tim Dettmers <tim.dettmers@gmail.com>    2022-07-22 14:41:05 -0700
committer  Tim Dettmers <tim.dettmers@gmail.com>    2022-07-22 14:41:05 -0700
commit     c771b3a75a6ebbfbfc398a028a477246b0799cf0 (patch)
tree       158353d531766ed133be34d3c5085da6e8a4d01e /bitsandbytes/nn
parent     4cd7ea62b2f51c68aacde2f62e7141765e476111 (diff)
Most tests passing.
Diffstat (limited to 'bitsandbytes/nn')
-rw-r--r--  bitsandbytes/nn/__init__.py    2
-rw-r--r--  bitsandbytes/nn/modules.py   124
2 files changed, 123 insertions(+), 3 deletions(-)
diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py
index 27ad6ca..03b4655 100644
--- a/bitsandbytes/nn/__init__.py
+++ b/bitsandbytes/nn/__init__.py
@@ -2,4 +2,4 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from .modules import StableEmbedding, Embedding
+from .modules import StableEmbedding, Linear8bit, Linear8bitLt, Int8Params
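Note: with this change the subpackage re-exports the new 8-bit classes, while Embedding (still defined in modules.py) is no longer re-exported. A minimal sketch of the resulting import surface, assuming the package is installed as bitsandbytes:

    from bitsandbytes.nn import StableEmbedding, Linear8bitLt, Linear8bit, Int8Params
    from bitsandbytes.nn.modules import Embedding   # Embedding now has to come from the submodule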
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index c5460fb..5013d0b 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -3,14 +3,19 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
+import bitsandbytes as bnb
+import torch.distributed as dist
-from typing import Optional
+from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict
-from torch import Tensor
+from torch import Tensor, device, dtype
+from torch import nn
+from torch.nn.parameter import Parameter
import torch.nn.functional as F
from bitsandbytes.optim import GlobalOptimManager
+T = TypeVar('T', bound='torch.nn.Module')
+
class StableEmbedding(torch.nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
@@ -70,3 +75,118 @@ class Embedding(torch.nn.Embedding):
            self.norm_type, self.scale_grad_by_freq, self.sparse)
        return emb
+
+class Int8Params(torch.nn.Parameter):
+    def __new__(cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None):
+        if data is None:
+            data = torch.empty(0)
+        # attach the quantization state (CB/SCB) to this instance rather than to the class
+        obj = torch.Tensor._make_subclass(cls, data, requires_grad)
+        obj.has_fp16_weights = has_fp16_weights
+        obj.CB = CB
+        obj.SCB = SCB
+        return obj
+
+    def cuda(self, device):
+        if self.has_fp16_weights:
+            return super().cuda(device)
+        else:
+            # we store the 8-bit row-major weight
+            # we convert this weight to the turing/ampere layout during the first inference pass
+            B = self.data.contiguous().half().cuda(device)
+            CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
+            del CBt
+            del SCBt
+            self.data = CB
+            self.CB = CB
+            self.SCB = SCB
+
+        return self
+
+    @overload
+    def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ...,
+           non_blocking: bool = ...) -> T:
+        ...
+
+    @overload
+    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
+        ...
+
+    @overload
+    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
+        ...
+
+    def to(self, *args, **kwargs):
+        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
+
+        if device is not None and device.type == 'cuda' and self.data.device.type == 'cpu':
+            return self.cuda(device)
+        else:
+            new_param = Int8Params(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
+                                   requires_grad=self.requires_grad, has_fp16_weights=self.has_fp16_weights)
+            new_param.CB = self.CB
+            new_param.SCB = self.SCB
+
+            return new_param
+
+
+
+class Linear8bitLt(nn.Linear):
+    def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, threshold=0.0, index=None):
+        super(Linear8bitLt, self).__init__(input_features, output_features, bias)
+        self.state = bnb.MatmulLtState()
+        self.index = index
+
+        self.state.threshold = threshold
+        self.state.has_fp16_weights = has_fp16_weights
+        if threshold > 0.0 and not has_fp16_weights:
+            self.state.use_pool = True
+
+        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
+
+    def init_8bit_state(self):
+        self.state.CB = self.weight.CB
+        self.state.SCB = self.weight.SCB
+        self.weight.CB = None
+        self.weight.SCB = None
+
+    def forward(self, x):
+        self.state.is_training = self.training
+
+        if self.weight.CB is not None:
+            self.init_8bit_state()
+        #assert not self.state.has_fp16_weights
+        #if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None
+
+        out = bnb.matmul(x, self.weight, state=self.state)
+
+        if self.bias is not None:
+            out += self.bias.unsqueeze(0).expand_as(out)
+
+        if not self.state.has_fp16_weights and self.state.CB is not None:
+            # we converted the 8-bit row-major weight to the turing/ampere format in the first inference pass
+            # we no longer need the row-major weight
+            del self.state.CB
+            self.weight.data = self.state.CxB
+
+        return out
+
+class Linear8bit(nn.Linear):
+    def __init__(self, input_features, output_features, bias=True, quant_type='vector', index=None, args=None, sparse_decomp=False):
+        super(Linear8bit, self).__init__(input_features, output_features, bias)
+        self.quant_type = quant_type
+        self.index = index
+        self.args = args
+        self.iter = 0
+
+    def forward(self, x):
+        self.iter += 1
+        # periodic weight clipping only applies when an args namespace with clip_freq/clip_idx is supplied
+        if self.args is not None and self.iter % self.args.clip_freq == 0:
+            with torch.no_grad():
+                maxval, maxidx = torch.topk(torch.abs(self.weight.flatten()), k=self.args.clip_idx)
+                if not dist.is_initialized() or dist.get_rank() == 0:
+                    print('clip', maxval[-1].item())
+                self.weight.clip_(-maxval[-1], maxval[-1])
+
+        if self.args is not None:
+            out = bnb.nn.functional.sparse_decomposed_linear8bit(x, self.weight, self.bias, qval=self.args.sparse_decomp_val, quant_type=self.args.quant_type)
+        else:
+            out = bnb.nn.functional.linear8bit(x, self.weight, self.bias, quant_type=self.quant_type)
+
+        return out
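
Usage sketch (illustrative, not part of this patch; assumes a CUDA device, and the layer sizes and threshold below are made up): moving a Linear8bitLt module to the GPU triggers Int8Params.cuda(), which quantizes the fp16 weight with double_quant, and the first forward pass converts the row-major int8 weight to the turing/ampere layout kept in state.CxB.

    import torch
    import bitsandbytes as bnb

    fp16_linear = torch.nn.Linear(1024, 1024).half().cuda()

    int8_linear = bnb.nn.Linear8bitLt(1024, 1024, bias=True, has_fp16_weights=False, threshold=6.0)
    int8_linear.weight = bnb.nn.Int8Params(fp16_linear.weight.data.clone(),
                                           requires_grad=False, has_fp16_weights=False)
    int8_linear.bias = fp16_linear.bias
    int8_linear = int8_linear.cuda()    # Int8Params.cuda() quantizes the fp16 weight to int8 (CB/SCB)

    x = torch.randn(4, 1024, dtype=torch.float16, device='cuda')
    out = int8_linear(x)                # first pass calls init_8bit_state() and builds state.CxB

With has_fp16_weights=True, Int8Params.cuda() instead falls back to a regular fp16 transfer and the fp16 weight is kept.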