Diffstat (limited to 'bitsandbytes/nn')
-rw-r--r--  bitsandbytes/nn/modules.py  34
1 file changed, 28 insertions(+), 6 deletions(-)
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 9ce3ac8..454dba5 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -2,8 +2,19 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import (Any, Callable, Dict, Iterator, Mapping, Optional, Set,
-                    Tuple, TypeVar, Union, overload)
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+    overload,
+)
 
 import torch
 import torch.nn.functional as F
@@ -131,7 +142,12 @@ class Embedding(torch.nn.Embedding):
 class Int8Params(torch.nn.Parameter):
     def __new__(
-        cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None
+        cls,
+        data=None,
+        requires_grad=True,
+        has_fp16_weights=False,
+        CB=None,
+        SCB=None,
     ):
         cls.has_fp16_weights = has_fp16_weights
         cls.CB = None
@@ -186,7 +202,9 @@ class Int8Params(torch.nn.Parameter):
             return self.cuda(device)
         else:
             new_param = Int8Params(
-                super().to(device=device, dtype=dtype, non_blocking=non_blocking),
+                super().to(
+                    device=device, dtype=dtype, non_blocking=non_blocking
+                ),
                 requires_grad=self.requires_grad,
                 has_fp16_weights=self.has_fp16_weights,
             )
@@ -206,7 +224,9 @@ class Linear8bitLt(nn.Linear):
         threshold=0.0,
         index=None,
     ):
-        super(Linear8bitLt, self).__init__(input_features, output_features, bias)
+        super(Linear8bitLt, self).__init__(
+            input_features, output_features, bias
+        )
         self.state = bnb.MatmulLtState()
         self.index = index
@@ -215,7 +235,9 @@ class Linear8bitLt(nn.Linear):
         if threshold > 0.0 and not has_fp16_weights:
             self.state.use_pool = True
-        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
+        self.weight = Int8Params(
+            self.weight.data, has_fp16_weights=has_fp16_weights
+        )
     def init_8bit_state(self):
         self.state.CB = self.weight.CB
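
For reference, a minimal usage sketch of the Linear8bitLt layer reformatted above; it is not part of the diff. The constructor argument names follow the signature shown in the hunks, while the concrete sizes and the threshold value are illustrative choices that assume a CUDA-capable bitsandbytes install.

import torch
import bitsandbytes as bnb

# Drop-in replacement for torch.nn.Linear; the weight is stored as Int8Params.
linear = bnb.nn.Linear8bitLt(
    input_features=64,
    output_features=128,
    bias=True,
    has_fp16_weights=False,  # keep quantized int8 weights (inference setting)
    threshold=6.0,           # illustrative outlier threshold for the mixed-precision path
)
linear = linear.cuda()  # Int8Params.cuda() quantizes the weight on the GPU

x = torch.randn(4, 64, dtype=torch.float16, device="cuda")
y = linear(x)  # shape: (4, 128)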