path: root/bitsandbytes/optim/optimizer.py
author    Titus von Koeller <titus@vonkoeller.com>  2022-08-01 03:31:48 -0700
committer Titus von Koeller <titus@vonkoeller.com>  2022-08-01 03:31:48 -0700
commit    bfa0e33294f2b1dc25e65a33be2397f989824298 (patch)
tree      396b5d722fdd79da068882ca7376e3636fcb3bb8 /bitsandbytes/optim/optimizer.py
parent    597a8521b29e90958c31e47421016494da998648 (diff)
ran black and isort for coherent code formatting
Diffstat (limited to 'bitsandbytes/optim/optimizer.py')
-rw-r--r--  bitsandbytes/optim/optimizer.py  565
1 file changed, 380 insertions(+), 185 deletions(-)
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index 5a5bb1e..b942e34 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -1,13 +1,16 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+from collections import abc as container_abcs
+from collections import defaultdict
+from copy import deepcopy
+from itertools import chain
+
import torch
+
import bitsandbytes.functional as F
-from copy import deepcopy
-from itertools import chain
-from collections import defaultdict, abc as container_abcs
class MockArgs(object):
def __init__(self, initial_data):
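MockArgs exists so the optimizers can read hyperparameters through attribute access (self.args.optim_bits) whether they were handed a real argparse namespace or a plain dict. A minimal sketch of the intended behavior, assuming __init__ simply copies each key of initial_data onto the instance:

    args = MockArgs({"optim_bits": 8, "percentile_clipping": 100})
    assert args.optim_bits == 8            # dict keys become attributes
    assert args.percentile_clipping == 100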
@@ -19,7 +22,7 @@ class GlobalOptimManager(object):
_instance = None
def __init__(self):
- raise RuntimeError('Call get_instance() instead')
+ raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.pid2config = {}
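GlobalOptimManager is a singleton: direct construction is blocked, and callers go through get_instance(), which creates and initializes the one shared instance on first use. Usage sketch:

    mng = GlobalOptimManager.get_instance()
    assert mng is GlobalOptimManager.get_instance()  # always the same object
    # GlobalOptimManager() raises RuntimeError("Call get_instance() instead")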
@@ -38,15 +41,15 @@ class GlobalOptimManager(object):
def register_parameters(self, params):
param_groups = list(params)
if not isinstance(param_groups[0], dict):
- param_groups = [{'params': param_groups}]
+ param_groups = [{"params": param_groups}]
for group_index, group in enumerate(param_groups):
- for p_index, p in enumerate(group['params']):
+ for p_index, p in enumerate(group["params"]):
if id(p) in self.pid2config:
self.index2config[(group_index, p_index)] = self.pid2config[id(p)]
def override_config(self, parameters, key=None, value=None, key_value_dict=None):
- '''
+ """
Overrides initial optimizer config for specific parameters.
    The key-values of the optimizer config for the input parameters are overridden
@@ -63,7 +66,7 @@ class GlobalOptimManager(object):
        The value for the hyperparameters.
key_value_dict : dict
A dictionary with multiple key-values to override.
- '''
+ """
self.uses_config_override = True
if isinstance(parameters, torch.nn.Parameter):
parameters = [parameters]
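In practice override_config pins selected parameters to different settings than the rest of the model, e.g. keeping a sensitive embedding weight in 32-bit optimizer state while everything else runs in 8 bits. A hedged sketch (the module is illustrative; parameters are registered with the manager before the optimizer is constructed so the overrides can be matched up):

    import torch
    import bitsandbytes as bnb

    emb = torch.nn.Embedding(10000, 512)  # illustrative module
    mng = bnb.optim.GlobalOptimManager.get_instance()
    mng.register_parameters(emb.parameters())
    mng.override_config(emb.weight, "optim_bits", 32)                   # single key
    mng.override_config(emb.weight, key_value_dict={"optim_bits": 32})  # dict form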
@@ -75,16 +78,16 @@ class GlobalOptimManager(object):
if key_value_dict is not None:
for p in parameters:
- if id(p) in self.pid2config:self.pid2config[id(p)].update(key_value_dict)
- else: self.pid2config[id(p)] = key_value_dict
+ if id(p) in self.pid2config:
+ self.pid2config[id(p)].update(key_value_dict)
+ else:
+ self.pid2config[id(p)] = key_value_dict
def register_module_override(self, module, param_name, config):
self.module_weight_config_triple.append((module, param_name, config))
-
class Optimizer8bit(torch.optim.Optimizer):
-
def __init__(self, params, defaults, optim_bits=32):
super(Optimizer8bit, self).__init__(params, defaults)
self.initialized = False
@@ -92,23 +95,32 @@ class Optimizer8bit(torch.optim.Optimizer):
self.mng = GlobalOptimManager.get_instance()
self.non_castable_tensor_keys = set(
- ['qmap1', 'qmap2',
- 'max1', 'max2',
- 'new_max1', 'new_max2',
- 'state1', 'state2',
- 'gnorm_vec', 'absmax1', 'absmax2',
- 'unorm_vec'])
-
- if optim_bits == 8: self.fill_qmap()
+ [
+ "qmap1",
+ "qmap2",
+ "max1",
+ "max2",
+ "new_max1",
+ "new_max2",
+ "state1",
+ "state2",
+ "gnorm_vec",
+ "absmax1",
+ "absmax2",
+ "unorm_vec",
+ ]
+ )
+
+ if optim_bits == 8:
+ self.fill_qmap()
def fill_qmap(self):
- self.name2qmap['dynamic'] = F.create_dynamic_map(signed=True)
- self.name2qmap['udynamic'] = F.create_dynamic_map(signed=False)
+ self.name2qmap["dynamic"] = F.create_dynamic_map(signed=True)
+ self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False)
def __setstate__(self, state):
super(Optimizer8bit, self).__setstate__(state)
-
def load_state_dict(self, state_dict):
r"""Loads the optimizer state.
@@ -120,21 +132,28 @@ class Optimizer8bit(torch.optim.Optimizer):
state_dict = deepcopy(state_dict)
# Validate the state_dict
groups = self.param_groups
- saved_groups = state_dict['param_groups']
+ saved_groups = state_dict["param_groups"]
if len(groups) != len(saved_groups):
- raise ValueError("loaded state dict has a different number of "
- "parameter groups")
- param_lens = (len(g['params']) for g in groups)
- saved_lens = (len(g['params']) for g in saved_groups)
+            raise ValueError(
+                "loaded state dict has a different number of parameter groups"
+            )
+ param_lens = (len(g["params"]) for g in groups)
+ saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
- raise ValueError("loaded state dict contains a parameter group "
- "that doesn't match the size of optimizer's group")
+ raise ValueError(
+ "loaded state dict contains a parameter group "
+                "that doesn't match the size of the optimizer's group"
+ )
# Update the state
- id_map = {old_id: p for old_id, p in
- zip(chain.from_iterable((g['params'] for g in saved_groups)),
- chain.from_iterable((g['params'] for g in groups)))}
+ id_map = {
+ old_id: p
+ for old_id, p in zip(
+ chain.from_iterable((g["params"] for g in saved_groups)),
+ chain.from_iterable((g["params"] for g in groups)),
+ )
+ }
def cast(param, value):
r"""Make a deep copy of value, casting all tensors to device of param."""
@@ -161,7 +180,7 @@ class Optimizer8bit(torch.optim.Optimizer):
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
state = defaultdict(dict)
- for k, v in state_dict['state'].items():
+ for k, v in state_dict["state"].items():
if k in id_map:
param = id_map[k]
state[param] = cast(param, v)
@@ -170,15 +189,15 @@ class Optimizer8bit(torch.optim.Optimizer):
# Update parameter groups, setting their 'params' value
def update_group(group, new_group):
- new_group['params'] = group['params']
+ new_group["params"] = group["params"]
return new_group
- param_groups = [
- update_group(g, ng) for g, ng in zip(groups, saved_groups)]
- self.__setstate__({'state': state, 'param_groups': param_groups})
+
+ param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
+ self.__setstate__({"state": state, "param_groups": param_groups})
def to_gpu(self):
for gindex, group in enumerate(self.param_groups):
- for pindex, p in enumerate(group['params']):
+ for pindex, p in enumerate(group["params"]):
if p in self.state:
values = self.state[p]
for k, v in values.items():
@@ -189,17 +208,23 @@ class Optimizer8bit(torch.optim.Optimizer):
for module, attr, config in self.mng.module_weight_config_triple:
pmodule = getattr(module, attr)
assert pmodule is not None
- assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter)
+            assert isinstance(pmodule, torch.Tensor) or isinstance(
+                pmodule, torch.nn.Parameter
+            )
found = False
for gindex, group in enumerate(self.param_groups):
- if found: break
- for pindex, p in enumerate(group['params']):
- if found: break
+ if found:
+ break
+ for pindex, p in enumerate(group["params"]):
+ if found:
+ break
if id(p) == id(pmodule):
# found the matching parameter
# init override
self.mng.pid2config[id(p)] = config
- self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)]
+ self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[
+ id(p)
+ ]
found = True
@torch.no_grad()
@@ -219,11 +244,11 @@ class Optimizer8bit(torch.optim.Optimizer):
if not self.initialized:
self.check_overrides()
- self.to_gpu() # needed for fairseq pure fp16 training
+ self.to_gpu() # needed for fairseq pure fp16 training
self.initialized = True
for gindex, group in enumerate(self.param_groups):
- for pindex, p in enumerate(group['params']):
+ for pindex, p in enumerate(group["params"]):
if p.grad is None:
continue
state = self.state[p]
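step() resolves module overrides and moves existing state to the GPU once, then allocates per-parameter state lazily via init_state on the first update, so the optimizer can be constructed before any state buffers exist. From the user's side it remains a drop-in torch.optim optimizer; a standard loop, with model and batch illustrative:

    opt = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)
    loss = model(batch).sum()
    loss.backward()
    opt.step()       # first call allocates the 8-bit state on p.device
    opt.zero_grad()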
@@ -236,58 +261,70 @@ class Optimizer8bit(torch.optim.Optimizer):
def get_config(self, gindex, pindex, group):
config = {}
- config['betas'] = group['betas']
- config['eps'] = group['eps']
- config['weight_decay'] = group['weight_decay']
- config['lr'] = group['lr']
- config['optim_bits'] = self.args.optim_bits
- config['min_8bit_size'] = self.args.min_8bit_size
- config['percentile_clipping'] = self.args.percentile_clipping
- config['block_wise'] = self.args.block_wise
- config['max_unorm'] = self.args.max_unorm
- config['skip_zeros'] = self.args.skip_zeros
+ config["betas"] = group["betas"]
+ config["eps"] = group["eps"]
+ config["weight_decay"] = group["weight_decay"]
+ config["lr"] = group["lr"]
+ config["optim_bits"] = self.args.optim_bits
+ config["min_8bit_size"] = self.args.min_8bit_size
+ config["percentile_clipping"] = self.args.percentile_clipping
+ config["block_wise"] = self.args.block_wise
+ config["max_unorm"] = self.args.max_unorm
+ config["skip_zeros"] = self.args.skip_zeros
if (gindex, pindex) in self.mng.index2config:
config.update(self.mng.index2config[(gindex, pindex)])
return config
def init_state(self, group, p, gindex, pindex):
- raise NotImplementedError(f'init_state method needs to be overidden')
+        raise NotImplementedError("The init_state method needs to be overridden")
def update_step(self, group, p, gindex, pindex):
- raise NotImplementedError(f'The update_step method needs to be overidden')
+        raise NotImplementedError("The update_step method needs to be overridden")
+
class Optimizer2State(Optimizer8bit):
- def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0.0, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0,
- skip_zeros=False):
+ def __init__(
+ self,
+ optimizer_name,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0.0,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ max_unorm=0.0,
+ skip_zeros=False,
+ ):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if isinstance(betas, str):
# format: '(beta1, beta2)'
- betas = betas.replace('(', '').replace(')', '').strip().split(',')
+ betas = betas.replace("(", "").replace(")", "").strip().split(",")
betas = [float(b) for b in betas]
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
- defaults = dict(lr=lr, betas=betas, eps=eps,
- weight_decay=weight_decay)
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer2State, self).__init__(params, defaults, optim_bits)
if args is None:
args = {}
- args['optim_bits'] = optim_bits
- args['percentile_clipping'] = 100
- args['min_8bit_size'] = min_8bit_size
- args['percentile_clipping'] = percentile_clipping
- args['block_wise'] = block_wise
- args['max_unorm'] = max_unorm
- args['skip_zeros'] = skip_zeros
+ args["optim_bits"] = optim_bits
+ args["min_8bit_size"] = min_8bit_size
+ args["percentile_clipping"] = percentile_clipping
+ args["block_wise"] = block_wise
+ args["max_unorm"] = max_unorm
+ args["skip_zeros"] = skip_zeros
self.args = MockArgs(args)
else:
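Optimizer2State is the base for update rules that track two moments per parameter, such as Adam's running averages of the gradient and the squared gradient; concrete subclasses mostly just fix optimizer_name and the defaults. A hedged construction sketch, roughly what bnb.optim.Adam does under the hood:

    opt = Optimizer2State("adam", model.parameters(), lr=1e-3,
                          betas=(0.9, 0.999), eps=1e-8, optim_bits=8)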
@@ -299,50 +336,83 @@ class Optimizer2State(Optimizer8bit):
def init_state(self, group, p, gindex, pindex):
config = self.get_config(gindex, pindex, group)
- if config['optim_bits'] == 32:
+ if config["optim_bits"] == 32:
dtype = torch.float32
- elif config['optim_bits'] == 8:
+ elif config["optim_bits"] == 8:
dtype = torch.uint8
- else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
+ else:
+ raise NotImplementedError(
+                f'Unsupported number of optimizer bits: {config["optim_bits"]}'
+ )
- if p.numel() < config['min_8bit_size']: dtype = torch.float32
+ if p.numel() < config["min_8bit_size"]:
+ dtype = torch.float32
state = self.state[p]
- state['step'] = 0
+ state["step"] = 0
if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
- state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
- state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
+ state["state1"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.float32,
+ device=p.device,
+ )
+ state["state2"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.float32,
+ device=p.device,
+ )
elif dtype == torch.uint8:
- if state['step'] == 0:
- if 'dynamic' not in self.name2qmap: self.fill_qmap()
- self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device)
- self.name2qmap['udynamic'] = self.name2qmap['udynamic'].to(p.device)
-
- state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
- state['qmap1'] = self.name2qmap['dynamic']
-
- state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
- state['qmap2'] = self.name2qmap['udynamic']
-
- if config['block_wise']:
+ if state["step"] == 0:
+ if "dynamic" not in self.name2qmap:
+ self.fill_qmap()
+ self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
+ self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device)
+
+ state["state1"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.uint8,
+ device=p.device,
+ )
+ state["qmap1"] = self.name2qmap["dynamic"]
+
+ state["state2"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.uint8,
+ device=p.device,
+ )
+ state["qmap2"] = self.name2qmap["udynamic"]
+
+ if config["block_wise"]:
n = p.numel()
- blocks = n//2048
+ blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0
- state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
- state['absmax2'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
+ state["absmax1"] = torch.zeros(
+ (blocks,), dtype=torch.float32, device=p.device
+ )
+ state["absmax2"] = torch.zeros(
+ (blocks,), dtype=torch.float32, device=p.device
+ )
else:
- state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
- state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
- state['max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
- state['new_max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state["new_max1"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
+ state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state["new_max2"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
- if config['percentile_clipping'] < 100:
- state['gnorm_vec'] = torch.zeros((100,), device=p.device)
+ if config["percentile_clipping"] < 100:
+ state["gnorm_vec"] = torch.zeros((100,), device=p.device)
- if config['max_unorm'] > 0.0:
- state['unorm_vec'] = torch.zeros((1,), device=p.device)
+ if config["max_unorm"] > 0.0:
+ state["unorm_vec"] = torch.zeros((1,), device=p.device)
@torch.no_grad()
def update_step(self, group, p, gindex, pindex):
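With block_wise enabled the optimizer state is quantized in chunks of 2048 values, so init_state above allocates one absmax scale per chunk, i.e. ceil(n / 2048) floats per state tensor:

    n = 1_000_000
    blocks = n // 2048 + (1 if n % 2048 > 0 else 0)
    assert blocks == 489  # same as math.ceil(n / 2048)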
@@ -351,41 +421,101 @@ class Optimizer2State(Optimizer8bit):
config = self.get_config(gindex, pindex, group)
- state['step'] += 1
- step = state['step']
+ state["step"] += 1
+ step = state["step"]
- if config['percentile_clipping'] < 100:
- current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping'])
+ if config["percentile_clipping"] < 100:
+ current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
+ grad, state["gnorm_vec"], step, config["percentile_clipping"]
+ )
else:
gnorm_scale = 1.0
- if state['state1'].dtype == torch.float:
- F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
- state['state2'], config['betas'][1], config['weight_decay'], gnorm_scale,
- state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'], skip_zeros=config['skip_zeros'])
-
- elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
- F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
- config['eps'], step, config['lr'],
- state['qmap1'], state['qmap2'], state['max1'], state['max2'], state['new_max1'], state['new_max2'],
- config['weight_decay'], gnorm_scale=gnorm_scale,
- unorm_vec=state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
+ if state["state1"].dtype == torch.float:
+ F.optimizer_update_32bit(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ config["betas"][0],
+ config["eps"],
+ step,
+ config["lr"],
+ state["state2"],
+ config["betas"][1],
+ config["weight_decay"],
+ gnorm_scale,
+ state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
+ max_unorm=config["max_unorm"],
+ skip_zeros=config["skip_zeros"],
+ )
+
+ elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
+ F.optimizer_update_8bit(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ state["state2"],
+ config["betas"][0],
+ config["betas"][1],
+ config["eps"],
+ step,
+ config["lr"],
+ state["qmap1"],
+ state["qmap2"],
+ state["max1"],
+ state["max2"],
+ state["new_max1"],
+ state["new_max2"],
+ config["weight_decay"],
+ gnorm_scale=gnorm_scale,
+ unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
+ max_unorm=config["max_unorm"],
+ )
# swap maxes
- state['max1'], state['new_max1'] = state['new_max1'], state['max1']
- state['max2'], state['new_max2'] = state['new_max2'], state['max2']
- elif state['state1'].dtype == torch.uint8 and config['block_wise']:
- F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
- config['eps'], step, config['lr'],
- state['qmap1'], state['qmap2'], state['absmax1'], state['absmax2'],
- config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros'])
+ state["max1"], state["new_max1"] = state["new_max1"], state["max1"]
+ state["max2"], state["new_max2"] = state["new_max2"], state["max2"]
+ elif state["state1"].dtype == torch.uint8 and config["block_wise"]:
+ F.optimizer_update_8bit_blockwise(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ state["state2"],
+ config["betas"][0],
+ config["betas"][1],
+ config["eps"],
+ step,
+ config["lr"],
+ state["qmap1"],
+ state["qmap2"],
+ state["absmax1"],
+ state["absmax2"],
+ config["weight_decay"],
+ gnorm_scale=gnorm_scale,
+ skip_zeros=config["skip_zeros"],
+ )
class Optimizer1State(Optimizer8bit):
- def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.0), eps=1e-8,
- weight_decay=0.0, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0,
- skip_zeros=False):
+ def __init__(
+ self,
+ optimizer_name,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.0),
+ eps=1e-8,
+ weight_decay=0.0,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ max_unorm=0.0,
+ skip_zeros=False,
+ ):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
@@ -395,19 +525,18 @@ class Optimizer1State(Optimizer8bit):
raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
- defaults = dict(lr=lr, betas=betas, eps=eps,
- weight_decay=weight_decay)
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer1State, self).__init__(params, defaults, optim_bits)
if args is None:
args = {}
- args['optim_bits'] = optim_bits
- args['percentile_clipping'] = 100
- args['min_8bit_size'] = min_8bit_size
- args['percentile_clipping'] = percentile_clipping
- args['block_wise'] = block_wise
- args['max_unorm'] = max_unorm
- args['skip_zeros'] = skip_zeros
+ args["optim_bits"] = optim_bits
+ args["min_8bit_size"] = min_8bit_size
+ args["percentile_clipping"] = percentile_clipping
+ args["block_wise"] = block_wise
+ args["max_unorm"] = max_unorm
+ args["skip_zeros"] = skip_zeros
self.args = MockArgs(args)
else:
@@ -419,43 +548,61 @@ class Optimizer1State(Optimizer8bit):
def init_state(self, group, p, gindex, pindex):
config = self.get_config(gindex, pindex, group)
- if config['optim_bits'] == 32:
+ if config["optim_bits"] == 32:
dtype = torch.float32
- elif config['optim_bits'] == 8:
+ elif config["optim_bits"] == 8:
dtype = torch.uint8
- else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
+ else:
+ raise NotImplementedError(
+                f'Unsupported number of optimizer bits: {config["optim_bits"]}'
+ )
- if p.numel() < config['min_8bit_size']: dtype = torch.float32
+ if p.numel() < config["min_8bit_size"]:
+ dtype = torch.float32
state = self.state[p]
- state['step'] = 0
+ state["step"] = 0
if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
- state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
+ state["state1"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.float32,
+ device=p.device,
+ )
elif dtype == torch.uint8:
- if state['step'] == 0:
- if 'dynamic' not in self.name2qmap: self.fill_qmap()
- self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device)
-
- state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
- state['qmap1'] = self.name2qmap['dynamic']
-
- if config['block_wise']:
+ if state["step"] == 0:
+ if "dynamic" not in self.name2qmap:
+ self.fill_qmap()
+ self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
+
+ state["state1"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.uint8,
+ device=p.device,
+ )
+ state["qmap1"] = self.name2qmap["dynamic"]
+
+ if config["block_wise"]:
n = p.numel()
- blocks = n//2048
+ blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0
- state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
+ state["absmax1"] = torch.zeros(
+ (blocks,), dtype=torch.float32, device=p.device
+ )
else:
- state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
- state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
-
- if config['percentile_clipping'] < 100:
- state['gnorm_vec'] = torch.zeros((100,), device=p.device)
+ state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state["new_max1"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
- if config['max_unorm'] > 0.0:
- state['unorm_vec'] = torch.zeros((1,), device=p.device)
+ if config["percentile_clipping"] < 100:
+ state["gnorm_vec"] = torch.zeros((100,), device=p.device)
+ if config["max_unorm"] > 0.0:
+ state["unorm_vec"] = torch.zeros((1,), device=p.device)
@torch.no_grad()
def update_step(self, group, p, gindex, pindex):
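Optimizer1State mirrors Optimizer2State for single-moment rules (momentum SGD, RMSprop, Adagrad-style updates): update_step below passes None for every second-state argument, so the same kernels serve both classes. A hedged sketch, roughly how bnb.optim.SGD with momentum maps onto it:

    opt = Optimizer1State("momentum", model.parameters(), lr=1e-2,
                          betas=(0.9, 0.0), optim_bits=8)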
@@ -464,29 +611,77 @@ class Optimizer1State(Optimizer8bit):
config = self.get_config(gindex, pindex, group)
- state['step'] += 1
- step = state['step']
+ state["step"] += 1
+ step = state["step"]
- if config['percentile_clipping'] < 100:
- current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping'])
+ if config["percentile_clipping"] < 100:
+ current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
+ grad, state["gnorm_vec"], step, config["percentile_clipping"]
+ )
else:
gnorm_scale = 1.0
- if state['state1'].dtype == torch.float:
- F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
- None, 0.0, config['weight_decay'], gnorm_scale,
- state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'],
- skip_zeros=config['skip_zeros'])
-
- elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
- F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
- config['eps'], step, config['lr'], state['qmap1'], None, state['max1'], None, state['new_max1'], None,
- config['weight_decay'], gnorm_scale,
- state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
-
- state['max1'], state['new_max1'] = state['new_max1'], state['max1']
- elif state['state1'].dtype == torch.uint8 and config['block_wise']:
- F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
- config['eps'], step, config['lr'],
- state['qmap1'], None, state['absmax1'], None,
- config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros'])
+ if state["state1"].dtype == torch.float:
+ F.optimizer_update_32bit(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ config["betas"][0],
+ config["eps"],
+ step,
+ config["lr"],
+ None,
+ 0.0,
+ config["weight_decay"],
+ gnorm_scale,
+ state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
+ max_unorm=config["max_unorm"],
+ skip_zeros=config["skip_zeros"],
+ )
+
+ elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
+ F.optimizer_update_8bit(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ None,
+ config["betas"][0],
+ config["betas"][1],
+ config["eps"],
+ step,
+ config["lr"],
+ state["qmap1"],
+ None,
+ state["max1"],
+ None,
+ state["new_max1"],
+ None,
+ config["weight_decay"],
+ gnorm_scale,
+ state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
+ max_unorm=config["max_unorm"],
+ )
+
+ state["max1"], state["new_max1"] = state["new_max1"], state["max1"]
+ elif state["state1"].dtype == torch.uint8 and config["block_wise"]:
+ F.optimizer_update_8bit_blockwise(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ None,
+ config["betas"][0],
+ config["betas"][1],
+ config["eps"],
+ step,
+ config["lr"],
+ state["qmap1"],
+ None,
+ state["absmax1"],
+ None,
+ config["weight_decay"],
+ gnorm_scale=gnorm_scale,
+ skip_zeros=config["skip_zeros"],
+ )
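Putting the pieces together, a hedged end-to-end sketch of the intended workflow: register the parameters with the manager, override whatever should stay in 32 bits, then construct the 8-bit optimizer:

    import torch
    import bitsandbytes as bnb

    model = torch.nn.Linear(1024, 1024).cuda()
    mng = bnb.optim.GlobalOptimManager.get_instance()
    mng.register_parameters(model.parameters())
    mng.override_config(model.weight, "optim_bits", 32)  # keep fp32 state here
    opt = bnb.optim.Adam(model.parameters(), lr=1e-3, optim_bits=8)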