From ea7c14f8ef64924f2d0ff80df3cdabf2c7299848 Mon Sep 17 00:00:00 2001
From: Titus von Koeller
Date: Mon, 1 Aug 2022 09:32:47 -0700
Subject: reran black with linelength 80 for greater readability

---
 bitsandbytes/optim/adam.py | 27 +++++++++++++++++++++------
 1 file changed, 21 insertions(+), 6 deletions(-)

(limited to 'bitsandbytes/optim/adam.py')

diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
index 5cfaa28..3634971 100644
--- a/bitsandbytes/optim/adam.py
+++ b/bitsandbytes/optim/adam.py
@@ -140,7 +140,11 @@ class AnalysisAdam(torch.optim.Optimizer):
         savedir=None,
     ):
         defaults = dict(
-            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            amsgrad=amsgrad,
         )
         super(AnalysisAdam, self).__init__(params, defaults)
         self.analysis = bnb_analysis
@@ -198,7 +202,9 @@ class AnalysisAdam(torch.optim.Optimizer):
                     state["relerrors"] = torch.zeros(
                         (256, 256), device=p_data_fp32.device
                     )
-                    state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device)
+                    state["counts"] = torch.zeros(
+                        (256, 256), device=p_data_fp32.device
+                    )
                     if amsgrad:
                         # Maintains max of all exp. moving avg. of sq. grad. values
                         state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
@@ -214,7 +220,9 @@ class AnalysisAdam(torch.optim.Optimizer):
                 beta1, beta2 = group["betas"]
                 bias_correction1 = 1 - beta1 ** state["step"]
                 bias_correction2 = 1 - beta2 ** state["step"]
-                step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+                step_size = (
+                    group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+                )
                 e = state["abserrors"]
                 rele = state["relerrors"]
                 counts = state["counts"]
@@ -235,7 +243,10 @@ class AnalysisAdam(torch.optim.Optimizer):
                 denom = exp_avg_sq.sqrt().add_(group["eps"])
                 update_fp32 = exp_avg / denom

-                if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000:
+                if (
+                    p_data_fp32.numel() <= 8192
+                    or p_data_fp32.numel() > 50000 * 1000
+                ):
                     # embedding layer or too small
                     p_data_fp32 += -step_size * update_fp32
                 else:
@@ -274,7 +285,9 @@ class AnalysisAdam(torch.optim.Optimizer):
                         # 3. dequantize
                         # Error will be calculated automatically!
                     else:
-                        raise ValueError(f"Invalid analysis value: {self.analysis}!")
+                        raise ValueError(
+                            f"Invalid analysis value: {self.analysis}!"
+                        )

                     denom = state2.sqrt().add_(group["eps"])
                     update_8bit = state1 / denom
@@ -296,7 +309,9 @@ class AnalysisAdam(torch.optim.Optimizer):
                     if self.savedir != "" and state["step"] % 100 == 0:
                         if not os.path.exists(self.savedir):
                             os.makedirs(self.savedir)
-                        shapestr = "_".join([str(dim) for dim in p_data_fp32.shape])
+                        shapestr = "_".join(
+                            [str(dim) for dim in p_data_fp32.shape]
+                        )
                         pathe = os.path.join(
                             self.savedir, f"{p_id}_{shapestr}_abserr.pkl"
                         )
--
cgit v1.2.3
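
For reference, this is a pure formatting change: black exposes the same formatter through its CLI ("black --line-length 80 bitsandbytes/optim/adam.py") and through a Python API. Below is a minimal sketch of the API route; the exact command the author ran is not recorded in the patch, and the snippet assumes a black release whose Mode accepts line_length (true for black 22.x, current at the commit date):

    # Sketch: re-run black at line length 80 on the file touched by this
    # patch. Assumes black is installed (pip install black); the command
    # actually used by the author is not part of the patch.
    import black

    path = "bitsandbytes/optim/adam.py"

    with open(path) as f:
        src = f.read()

    # Mode(line_length=80) mirrors the "linelength 80" in the subject line.
    formatted = black.format_str(src, mode=black.Mode(line_length=80))

    with open(path, "w") as f:
        f.write(formatted)

Either route is idempotent: once the file is formatted at the configured line length, rerunning black produces no further changes, which is why this diff contains only line-wrapping edits and no behavioral ones.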