summary | refs | log | tree | commit | diff
path: root/bitsandbytes/optim/optimizer.py
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes/optim/optimizer.py')
-rw-r--r-- bitsandbytes/optim/optimizer.py | 77
1 files changed, 56 insertions, 21 deletions
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index b942e34..4fb30cd 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -46,9 +46,13 @@ class GlobalOptimManager(object):
for group_index, group in enumerate(param_groups):
for p_index, p in enumerate(group["params"]):
if id(p) in self.pid2config:
- self.index2config[(group_index, p_index)] = self.pid2config[id(p)]
+ self.index2config[(group_index, p_index)] = self.pid2config[
+ id(p)
+ ]
- def override_config(self, parameters, key=None, value=None, key_value_dict=None):
+ def override_config(
+ self, parameters, key=None, value=None, key_value_dict=None
+ ):
"""
Overrides initial optimizer config for specific parameters.
@@ -136,7 +140,8 @@ class Optimizer8bit(torch.optim.Optimizer):
if len(groups) != len(saved_groups):
raise ValueError(
- "loaded state dict has a different number of " "parameter groups"
+ "loaded state dict has a different number of "
+ "parameter groups"
)
param_lens = (len(g["params"]) for g in groups)
saved_lens = (len(g["params"]) for g in saved_groups)
@@ -192,7 +197,9 @@ class Optimizer8bit(torch.optim.Optimizer):
new_group["params"] = group["params"]
return new_group
- param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
+ param_groups = [
+ update_group(g, ng) for g, ng in zip(groups, saved_groups)
+ ]
self.__setstate__({"state": state, "param_groups": param_groups})
def to_gpu(self):
@@ -222,9 +229,9 @@ class Optimizer8bit(torch.optim.Optimizer):
# found the matching parameter
# init override
self.mng.pid2config[id(p)] = config
- self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[
- id(p)
- ]
+ self.mng.index2config[
+ (gindex, pindex)
+ ] = self.mng.pid2config[id(p)]
found = True
@torch.no_grad()
@@ -280,7 +287,9 @@ class Optimizer8bit(torch.optim.Optimizer):
raise NotImplementedError(f"init_state method needs to be overidden")
def update_step(self, group, p, gindex, pindex):
- raise NotImplementedError(f"The update_step method needs to be overidden")
+ raise NotImplementedError(
+ f"The update_step method needs to be overidden"
+ )
class Optimizer2State(Optimizer8bit):
@@ -310,9 +319,13 @@ class Optimizer2State(Optimizer8bit):
betas = [float(b) for b in betas]
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
- raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
+ raise ValueError(
+ f"Invalid beta parameter at index {i}: {betas[i]}"
+ )
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer2State, self).__init__(params, defaults, optim_bits)
@@ -351,7 +364,9 @@ class Optimizer2State(Optimizer8bit):
state = self.state[p]
state["step"] = 0
- if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+ if dtype == torch.float32 or (
+ dtype == torch.uint8 and p.numel() < 4096
+ ):
state["state1"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
@@ -368,8 +383,12 @@ class Optimizer2State(Optimizer8bit):
if state["step"] == 0:
if "dynamic" not in self.name2qmap:
self.fill_qmap()
- self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
- self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device)
+ self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
+ p.device
+ )
+ self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(
+ p.device
+ )
state["state1"] = torch.zeros_like(
p,
@@ -399,11 +418,15 @@ class Optimizer2State(Optimizer8bit):
(blocks,), dtype=torch.float32, device=p.device
)
else:
- state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state["max1"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
state["new_max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
- state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state["max2"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
state["new_max2"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
@@ -470,7 +493,9 @@ class Optimizer2State(Optimizer8bit):
state["new_max2"],
config["weight_decay"],
gnorm_scale=gnorm_scale,
- unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
+ unorm_vec=state["unorm_vec"]
+ if config["max_unorm"] > 0.0
+ else None,
max_unorm=config["max_unorm"],
)
@@ -522,9 +547,13 @@ class Optimizer1State(Optimizer8bit):
raise ValueError("Invalid epsilon value: {}".format(eps))
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
- raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
+ raise ValueError(
+ f"Invalid beta parameter at index {i}: {betas[i]}"
+ )
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer1State, self).__init__(params, defaults, optim_bits)
@@ -563,7 +592,9 @@ class Optimizer1State(Optimizer8bit):
state = self.state[p]
state["step"] = 0
- if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+ if dtype == torch.float32 or (
+ dtype == torch.uint8 and p.numel() < 4096
+ ):
state["state1"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
@@ -574,7 +605,9 @@ class Optimizer1State(Optimizer8bit):
if state["step"] == 0:
if "dynamic" not in self.name2qmap:
self.fill_qmap()
- self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
+ self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
+ p.device
+ )
state["state1"] = torch.zeros_like(
p,
@@ -593,7 +626,9 @@ class Optimizer1State(Optimizer8bit):
(blocks,), dtype=torch.float32, device=p.device
)
else:
- state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state["max1"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
state["new_max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)