From ea7c14f8ef64924f2d0ff80df3cdabf2c7299848 Mon Sep 17 00:00:00 2001 From: Titus von Koeller Date: Mon, 1 Aug 2022 09:32:47 -0700 Subject: reran black with linelength 80 for greater readability --- bitsandbytes/nn/modules.py | 34 ++++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) (limited to 'bitsandbytes/nn') diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 9ce3ac8..454dba5 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -2,8 +2,19 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import (Any, Callable, Dict, Iterator, Mapping, Optional, Set, - Tuple, TypeVar, Union, overload) +from typing import ( + Any, + Callable, + Dict, + Iterator, + Mapping, + Optional, + Set, + Tuple, + TypeVar, + Union, + overload, +) import torch import torch.nn.functional as F @@ -131,7 +142,12 @@ class Embedding(torch.nn.Embedding): class Int8Params(torch.nn.Parameter): def __new__( - cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None + cls, + data=None, + requires_grad=True, + has_fp16_weights=False, + CB=None, + SCB=None, ): cls.has_fp16_weights = has_fp16_weights cls.CB = None @@ -186,7 +202,9 @@ class Int8Params(torch.nn.Parameter): return self.cuda(device) else: new_param = Int8Params( - super().to(device=device, dtype=dtype, non_blocking=non_blocking), + super().to( + device=device, dtype=dtype, non_blocking=non_blocking + ), requires_grad=self.requires_grad, has_fp16_weights=self.has_fp16_weights, ) @@ -206,7 +224,9 @@ class Linear8bitLt(nn.Linear): threshold=0.0, index=None, ): - super(Linear8bitLt, self).__init__(input_features, output_features, bias) + super(Linear8bitLt, self).__init__( + input_features, output_features, bias + ) self.state = bnb.MatmulLtState() self.index = index @@ -215,7 +235,9 @@ class Linear8bitLt(nn.Linear): if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True - self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights) + self.weight = Int8Params( + self.weight.data, has_fp16_weights=has_fp16_weights + ) def init_8bit_state(self): self.state.CB = self.weight.CB -- cgit v1.2.3