Diffstat (limited to 'bitsandbytes/nn')

 bitsandbytes/nn/__init__.py |  2 +-
 bitsandbytes/nn/modules.py  | 44 --------------------------------------------
 2 files changed, 1 insertion(+), 45 deletions(-)
diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py
index 98d4aa0..edc595a 100644
--- a/bitsandbytes/nn/__init__.py
+++ b/bitsandbytes/nn/__init__.py
@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bit, Linear8bitLt, StableEmbedding
+from .modules import Int8Params, Linear8bitLt, StableEmbedding
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 9250fec..4f82cdc 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -271,47 +271,3 @@ class Linear8bitLt(nn.Linear):
                 del self.state.CxB

         return out
-
-
-class Linear8bit(nn.Linear):
-    def __init__(
-        self,
-        input_features,
-        output_features,
-        bias=True,
-        quant_type="vector",
-        index=None,
-        args=None,
-        sparse_decomp=False,
-    ):
-        super(Linear8bit, self).__init__(input_features, output_features, bias)
-        self.quant_type = quant_type
-        self.index = index
-        self.args = args
-        self.iter = 0
-
-    def forward(self, x):
-        self.iter += 1
-        if self.iter % self.args.clip_freq == 0:
-            with torch.no_grad():
-                maxval, maxidx = torch.topk(
-                    torch.abs(self.weight.flatten()), k=self.args.clip_idx
-                )
-                if not dist.is_initialized() or dist.get_rank() == 0:
-                    print("clip", maxval[-1].item())
-                self.weight.clip_(-maxval[-1], maxval[-1])
-
-        if self.args is not None:
-            out = bnb.nn.functional.sparse_decomposed_linear8bit(
-                x,
-                self.weight,
-                self.bias,
-                qval=self.args.sparse_decomp_val,
-                quant_type=self.args.quant_type,
-            )
-        else:
-            out = bnb.nn.functional.linear8bit(
-                x, self.weight, self.bias, quant_type=self.args.quant_type
-            )
-
-        return out
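For reference (not part of the commit): after this change, Linear8bitLt is the remaining 8-bit linear layer exported from bitsandbytes.nn, as the updated __init__.py shows. Below is a minimal usage sketch, assuming the Linear8bitLt constructor signature at this revision; the layer sizes, threshold value, and input shape are illustrative only.

import torch
import bitsandbytes as bnb

# Illustrative sizes. threshold=6.0 turns on the mixed-precision outlier
# decomposition; has_fp16_weights=False keeps the weight in int8 for inference.
linear = bnb.nn.Linear8bitLt(
    64, 64, bias=True, has_fp16_weights=False, threshold=6.0
)
linear = linear.cuda()  # the weight is quantized to int8 when moved to the GPU

x = torch.randn(8, 64, dtype=torch.float16, device="cuda")
out = linear(x)  # fp16 activations, int8 matmul internally

Unlike the removed Linear8bit, which routed through module-level args and bnb.nn.functional helpers, Linear8bitLt carries its quantization state on the module itself (self.state in the hunk above).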