summaryrefslogtreecommitdiff
path: root/bitsandbytes/nn/modules.py
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes/nn/modules.py')
-rw-r--r--bitsandbytes/nn/modules.py44
1 files changed, 0 insertions, 44 deletions
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 9250fec..4f82cdc 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -271,47 +271,3 @@ class Linear8bitLt(nn.Linear):
del self.state.CxB
return out
-
-
-class Linear8bit(nn.Linear):
- def __init__(
- self,
- input_features,
- output_features,
- bias=True,
- quant_type="vector",
- index=None,
- args=None,
- sparse_decomp=False,
- ):
- super(Linear8bit, self).__init__(input_features, output_features, bias)
- self.quant_type = quant_type
- self.index = index
- self.args = args
- self.iter = 0
-
- def forward(self, x):
- self.iter += 1
- if self.iter % self.args.clip_freq == 0:
- with torch.no_grad():
- maxval, maxidx = torch.topk(
- torch.abs(self.weight.flatten()), k=self.args.clip_idx
- )
- if not dist.is_initialized() or dist.get_rank() == 0:
- print("clip", maxval[-1].item())
- self.weight.clip_(-maxval[-1], maxval[-1])
-
- if self.args is not None:
- out = bnb.nn.functional.sparse_decomposed_linear8bit(
- x,
- self.weight,
- self.bias,
- qval=self.args.sparse_decomp_val,
- quant_type=self.args.quant_type,
- )
- else:
- out = bnb.nn.functional.linear8bit(
- x, self.weight, self.bias, quant_type=self.args.quant_type
- )
-
- return out