Diffstat (limited to 'bitsandbytes/optim')
-rw-r--r--  bitsandbytes/optim/adagrad.py    12
-rw-r--r--  bitsandbytes/optim/adam.py       27
-rw-r--r--  bitsandbytes/optim/lars.py       20
-rw-r--r--  bitsandbytes/optim/optimizer.py  77
-rw-r--r--  bitsandbytes/optim/rmsprop.py    12
5 files changed, 110 insertions, 38 deletions
diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py
index 43e3973..7e2f566 100644
--- a/bitsandbytes/optim/adagrad.py
+++ b/bitsandbytes/optim/adagrad.py
@@ -23,7 +23,9 @@ class Adagrad(Optimizer1State):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if initial_accumulator_value != 0.0:
@@ -63,7 +65,9 @@ class Adagrad8bit(Optimizer1State):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if initial_accumulator_value != 0.0:
@@ -104,7 +108,9 @@ class Adagrad32bit(Optimizer1State):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if initial_accumulator_value != 0.0:
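The hunks above only rewrap the argument validation; for context, a minimal usage sketch for the 8-bit variant, assuming a CUDA device and illustrative hyperparameters (only lr, eps, weight_decay and initial_accumulator_value appear in the checks):

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(64, 64).cuda()
# lr, eps and weight_decay must be >= 0; a non-zero
# initial_accumulator_value is rejected by the checks above.
optim = bnb.optim.Adagrad8bit(model.parameters(), lr=1e-2, eps=1e-10, weight_decay=0.0)

loss = model(torch.randn(8, 64, device="cuda")).sum()
loss.backward()
optim.step()
optim.zero_grad()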
diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
index 5cfaa28..3634971 100644
--- a/bitsandbytes/optim/adam.py
+++ b/bitsandbytes/optim/adam.py
@@ -140,7 +140,11 @@ class AnalysisAdam(torch.optim.Optimizer):
savedir=None,
):
defaults = dict(
- lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
+ lr=lr,
+ betas=betas,
+ eps=eps,
+ weight_decay=weight_decay,
+ amsgrad=amsgrad,
)
super(AnalysisAdam, self).__init__(params, defaults)
self.analysis = bnb_analysis
@@ -198,7 +202,9 @@ class AnalysisAdam(torch.optim.Optimizer):
state["relerrors"] = torch.zeros(
(256, 256), device=p_data_fp32.device
)
- state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device)
+ state["counts"] = torch.zeros(
+ (256, 256), device=p_data_fp32.device
+ )
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
@@ -214,7 +220,9 @@ class AnalysisAdam(torch.optim.Optimizer):
beta1, beta2 = group["betas"]
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
- step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+ step_size = (
+ group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+ )
e = state["abserrors"]
rele = state["relerrors"]
counts = state["counts"]
@@ -235,7 +243,10 @@ class AnalysisAdam(torch.optim.Optimizer):
denom = exp_avg_sq.sqrt().add_(group["eps"])
update_fp32 = exp_avg / denom
- if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000:
+ if (
+ p_data_fp32.numel() <= 8192
+ or p_data_fp32.numel() > 50000 * 1000
+ ):
# embedding layer or too small
p_data_fp32 += -step_size * update_fp32
else:
@@ -274,7 +285,9 @@ class AnalysisAdam(torch.optim.Optimizer):
# 3. dequantize
# Error will be calculated automatically!
else:
- raise ValueError(f"Invalid analysis value: {self.analysis}!")
+ raise ValueError(
+ f"Invalid analysis value: {self.analysis}!"
+ )
denom = state2.sqrt().add_(group["eps"])
update_8bit = state1 / denom
@@ -296,7 +309,9 @@ class AnalysisAdam(torch.optim.Optimizer):
if self.savedir != "" and state["step"] % 100 == 0:
if not os.path.exists(self.savedir):
os.makedirs(self.savedir)
- shapestr = "_".join([str(dim) for dim in p_data_fp32.shape])
+ shapestr = "_".join(
+ [str(dim) for dim in p_data_fp32.shape]
+ )
pathe = os.path.join(
self.savedir, f"{p_id}_{shapestr}_abserr.pkl"
)
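The hunks above wrap the fp32 reference path that AnalysisAdam measures the quantized update against. A sketch of that reference step, reconstructed from the lines shown (the exp_avg/exp_avg_sq bookkeeping is standard Adam and does not appear in this diff, so treat it as an assumption):

import math
import torch

def adam_fp32_update(p, grad, exp_avg, exp_avg_sq, step,
                     lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
    # standard Adam moment updates (assumed, not part of the diff)
    beta1, beta2 = betas
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    # bias-corrected step size, as in the hunks above
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    step_size = lr * math.sqrt(bias_correction2) / bias_correction1
    # update_fp32 = exp_avg / denom, applied as p += -step_size * update_fp32
    denom = exp_avg_sq.sqrt().add_(eps)
    p += -step_size * (exp_avg / denom)
    return p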
diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py
index c6cf5c6..8a89fb0 100644
--- a/bitsandbytes/optim/lars.py
+++ b/bitsandbytes/optim/lars.py
@@ -24,7 +24,9 @@ class LARS(Optimizer1State):
max_unorm=0.02,
):
if momentum == 0:
- raise NotImplementedError(f"LARS without momentum is not supported!")
+ raise NotImplementedError(
+ f"LARS without momentum is not supported!"
+ )
super(LARS, self).__init__(
"lars",
params,
@@ -56,7 +58,9 @@ class LARS8bit(Optimizer1State):
max_unorm=0.02,
):
if momentum == 0:
- raise NotImplementedError(f"LARS without momentum is not supported!")
+ raise NotImplementedError(
+ f"LARS without momentum is not supported!"
+ )
super(LARS8bit, self).__init__(
"lars",
params,
@@ -88,7 +92,9 @@ class LARS32bit(Optimizer1State):
max_unorm=0.02,
):
if momentum == 0:
- raise NotImplementedError(f"LARS without momentum is not supported!")
+ raise NotImplementedError(
+ f"LARS without momentum is not supported!"
+ )
super(LARS32bit, self).__init__(
"lars",
params,
@@ -121,7 +127,9 @@ class PytorchLARS(Optimizer):
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
defaults = dict(
lr=lr,
@@ -132,7 +140,9 @@ class PytorchLARS(Optimizer):
max_unorm=max_unorm,
)
if nesterov and (momentum <= 0 or dampening != 0):
- raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+ raise ValueError(
+ "Nesterov momentum requires a momentum and zero dampening"
+ )
super(PytorchLARS, self).__init__(params, defaults)
def __setstate__(self, state):
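A usage sketch for the LARS variants touched above; the constructors raise when momentum == 0, and max_unorm=0.02 is the default visible in the hunks. The model and remaining hyperparameters are illustrative:

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(128, 128).cuda()
# momentum must be non-zero, otherwise NotImplementedError is raised
optim = bnb.optim.LARS8bit(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

loss = model(torch.randn(16, 128, device="cuda")).sum()
loss.backward()
optim.step()
optim.zero_grad()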
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index b942e34..4fb30cd 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -46,9 +46,13 @@ class GlobalOptimManager(object):
for group_index, group in enumerate(param_groups):
for p_index, p in enumerate(group["params"]):
if id(p) in self.pid2config:
- self.index2config[(group_index, p_index)] = self.pid2config[id(p)]
+ self.index2config[(group_index, p_index)] = self.pid2config[
+ id(p)
+ ]
- def override_config(self, parameters, key=None, value=None, key_value_dict=None):
+ def override_config(
+ self, parameters, key=None, value=None, key_value_dict=None
+ ):
"""
Overrides initial optimizer config for specific parameters.
@@ -136,7 +140,8 @@ class Optimizer8bit(torch.optim.Optimizer):
if len(groups) != len(saved_groups):
raise ValueError(
- "loaded state dict has a different number of " "parameter groups"
+ "loaded state dict has a different number of "
+ "parameter groups"
)
param_lens = (len(g["params"]) for g in groups)
saved_lens = (len(g["params"]) for g in saved_groups)
@@ -192,7 +197,9 @@ class Optimizer8bit(torch.optim.Optimizer):
new_group["params"] = group["params"]
return new_group
- param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
+ param_groups = [
+ update_group(g, ng) for g, ng in zip(groups, saved_groups)
+ ]
self.__setstate__({"state": state, "param_groups": param_groups})
def to_gpu(self):
@@ -222,9 +229,9 @@ class Optimizer8bit(torch.optim.Optimizer):
# found the matching parameter
# init override
self.mng.pid2config[id(p)] = config
- self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[
- id(p)
- ]
+ self.mng.index2config[
+ (gindex, pindex)
+ ] = self.mng.pid2config[id(p)]
found = True
@torch.no_grad()
@@ -280,7 +287,9 @@ class Optimizer8bit(torch.optim.Optimizer):
raise NotImplementedError(f"init_state method needs to be overidden")
def update_step(self, group, p, gindex, pindex):
- raise NotImplementedError(f"The update_step method needs to be overidden")
+ raise NotImplementedError(
+ f"The update_step method needs to be overidden"
+ )
class Optimizer2State(Optimizer8bit):
@@ -310,9 +319,13 @@ class Optimizer2State(Optimizer8bit):
betas = [float(b) for b in betas]
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
- raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
+ raise ValueError(
+ f"Invalid beta parameter at index {i}: {betas[i]}"
+ )
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer2State, self).__init__(params, defaults, optim_bits)
@@ -351,7 +364,9 @@ class Optimizer2State(Optimizer8bit):
state = self.state[p]
state["step"] = 0
- if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+ if dtype == torch.float32 or (
+ dtype == torch.uint8 and p.numel() < 4096
+ ):
state["state1"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
@@ -368,8 +383,12 @@ class Optimizer2State(Optimizer8bit):
if state["step"] == 0:
if "dynamic" not in self.name2qmap:
self.fill_qmap()
- self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
- self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device)
+ self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
+ p.device
+ )
+ self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(
+ p.device
+ )
state["state1"] = torch.zeros_like(
p,
@@ -399,11 +418,15 @@ class Optimizer2State(Optimizer8bit):
(blocks,), dtype=torch.float32, device=p.device
)
else:
- state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state["max1"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
state["new_max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
- state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state["max2"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
state["new_max2"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
@@ -470,7 +493,9 @@ class Optimizer2State(Optimizer8bit):
state["new_max2"],
config["weight_decay"],
gnorm_scale=gnorm_scale,
- unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
+ unorm_vec=state["unorm_vec"]
+ if config["max_unorm"] > 0.0
+ else None,
max_unorm=config["max_unorm"],
)
@@ -522,9 +547,13 @@ class Optimizer1State(Optimizer8bit):
raise ValueError("Invalid epsilon value: {}".format(eps))
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
- raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
+ raise ValueError(
+ f"Invalid beta parameter at index {i}: {betas[i]}"
+ )
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer1State, self).__init__(params, defaults, optim_bits)
@@ -563,7 +592,9 @@ class Optimizer1State(Optimizer8bit):
state = self.state[p]
state["step"] = 0
- if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+ if dtype == torch.float32 or (
+ dtype == torch.uint8 and p.numel() < 4096
+ ):
state["state1"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
@@ -574,7 +605,9 @@ class Optimizer1State(Optimizer8bit):
if state["step"] == 0:
if "dynamic" not in self.name2qmap:
self.fill_qmap()
- self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
+ self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
+ p.device
+ )
state["state1"] = torch.zeros_like(
p,
@@ -593,7 +626,9 @@ class Optimizer1State(Optimizer8bit):
(blocks,), dtype=torch.float32, device=p.device
)
else:
- state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state["max1"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
state["new_max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
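The override_config signature rewrapped above is the entry point for per-parameter optimizer settings. A sketch of that workflow, registering parameters before building the optimizer and then overriding a single tensor; the call order and the "optim_bits" key follow the library's per-parameter override usage, so treat the exact sequence as an assumption:

import torch
import bitsandbytes as bnb

mng = bnb.optim.GlobalOptimManager.get_instance()

model = torch.nn.Sequential(torch.nn.Embedding(10000, 128), torch.nn.Linear(128, 2))
# register parameters before the optimizer is constructed (assumed order)
mng.register_parameters(model.parameters())
model = model.cuda()

optim = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)
# keep the embedding's optimizer state in 32-bit precision; matches the
# override_config(parameters, key, value, key_value_dict) signature above
mng.override_config(model[0].weight, "optim_bits", 32)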
diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py
index 679f783..7ddb12c 100644
--- a/bitsandbytes/optim/rmsprop.py
+++ b/bitsandbytes/optim/rmsprop.py
@@ -22,7 +22,9 @@ class RMSprop(Optimizer1State):
block_wise=True,
):
if alpha == 0:
- raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
+ raise NotImplementedError(
+ f"RMSprop with alpha==0.0 is not supported!"
+ )
if centered:
raise NotImplementedError(f"Centered RMSprop is not supported!")
super(RMSprop, self).__init__(
@@ -56,7 +58,9 @@ class RMSprop8bit(Optimizer1State):
block_wise=True,
):
if alpha == 0:
- raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
+ raise NotImplementedError(
+ f"RMSprop with alpha==0.0 is not supported!"
+ )
if centered:
raise NotImplementedError(f"Centered RMSprop is not supported!")
super(RMSprop8bit, self).__init__(
@@ -91,7 +95,9 @@ class RMSprop32bit(Optimizer1State):
):
if alpha == 0:
- raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
+ raise NotImplementedError(
+ f"RMSprop with alpha==0.0 is not supported!"
+ )
if centered:
raise NotImplementedError(f"Centered RMSprop is not supported!")
super(RMSprop32bit, self).__init__(
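Finally, a usage sketch for the RMSprop variants; alpha == 0 and centered=True are both rejected by the checks above, everything else here is illustrative:

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(32, 32).cuda()
# alpha must be non-zero and centered must stay False per the hunks above
optim = bnb.optim.RMSprop8bit(model.parameters(), lr=1e-3, alpha=0.99, centered=False)

loss = model(torch.randn(4, 32, device="cuda")).sum()
loss.backward()
optim.step()
optim.zero_grad()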