summary refs log tree commit diff
path: root/bitsandbytes/optim/adam.py
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes/optim/adam.py')
-rw-r--r--  bitsandbytes/optim/adam.py  27
1 file changed, 21 insertions, 6 deletions
diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
index 5cfaa28..3634971 100644
--- a/bitsandbytes/optim/adam.py
+++ b/bitsandbytes/optim/adam.py
@@ -140,7 +140,11 @@ class AnalysisAdam(torch.optim.Optimizer):
savedir=None,
):
defaults = dict(
- lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
+ lr=lr,
+ betas=betas,
+ eps=eps,
+ weight_decay=weight_decay,
+ amsgrad=amsgrad,
)
super(AnalysisAdam, self).__init__(params, defaults)
self.analysis = bnb_analysis
@@ -198,7 +202,9 @@ class AnalysisAdam(torch.optim.Optimizer):
state["relerrors"] = torch.zeros(
(256, 256), device=p_data_fp32.device
)
- state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device)
+ state["counts"] = torch.zeros(
+ (256, 256), device=p_data_fp32.device
+ )
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
@@ -214,7 +220,9 @@ class AnalysisAdam(torch.optim.Optimizer):
beta1, beta2 = group["betas"]
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
- step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+ step_size = (
+ group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+ )
e = state["abserrors"]
rele = state["relerrors"]
counts = state["counts"]
@@ -235,7 +243,10 @@ class AnalysisAdam(torch.optim.Optimizer):
denom = exp_avg_sq.sqrt().add_(group["eps"])
update_fp32 = exp_avg / denom
- if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000:
+ if (
+ p_data_fp32.numel() <= 8192
+ or p_data_fp32.numel() > 50000 * 1000
+ ):
# embedding layer or too small
p_data_fp32 += -step_size * update_fp32
else:
@@ -274,7 +285,9 @@ class AnalysisAdam(torch.optim.Optimizer):
# 3. dequantize
# Error will be calculated automatically!
else:
- raise ValueError(f"Invalid analysis value: {self.analysis}!")
+ raise ValueError(
+ f"Invalid analysis value: {self.analysis}!"
+ )
denom = state2.sqrt().add_(group["eps"])
update_8bit = state1 / denom
@@ -296,7 +309,9 @@ class AnalysisAdam(torch.optim.Optimizer):
if self.savedir != "" and state["step"] % 100 == 0:
if not os.path.exists(self.savedir):
os.makedirs(self.savedir)
- shapestr = "_".join([str(dim) for dim in p_data_fp32.shape])
+ shapestr = "_".join(
+ [str(dim) for dim in p_data_fp32.shape]
+ )
pathe = os.path.join(
self.savedir, f"{p_id}_{shapestr}_abserr.pkl"
)