author     Tim Dettmers <tim.dettmers@gmail.com>  2021-10-05 19:16:20 -0700
committer  Tim Dettmers <tim.dettmers@gmail.com>  2021-10-05 19:16:20 -0700
commit     7439924891496025edf60c9da6a782f362a50c70 (patch)
tree       90476984d2c267f89232577a2ea40eb172387475 /bitsandbytes/optim
Initial commit
Diffstat (limited to 'bitsandbytes/optim')
-rw-r--r--  bitsandbytes/optim/__init__.py    10
-rw-r--r--  bitsandbytes/optim/adam.py        28
-rw-r--r--  bitsandbytes/optim/lamb.py        29
-rw-r--r--  bitsandbytes/optim/lars.py       115
-rw-r--r--  bitsandbytes/optim/optimizer.py  460
-rw-r--r--  bitsandbytes/optim/rmsprop.py     37
-rw-r--r--  bitsandbytes/optim/sgd.py         32
7 files changed, 711 insertions, 0 deletions
diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py
new file mode 100644
index 0000000..92c83b1
--- /dev/null
+++ b/bitsandbytes/optim/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from .adam import Adam, Adam8bit, Adam32bit
+from .sgd import SGD, SGD8bit, SGD32bit
+from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
+from .lamb import LAMB, LAMB8bit, LAMB32bit
+from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
+from .optimizer import GlobalOptimManager
diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
new file mode 100644
index 0000000..99a6d10
--- /dev/null
+++ b/bitsandbytes/optim/adam.py
@@ -0,0 +1,28 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from bitsandbytes.optim.optimizer import Optimizer2State
+
+class Adam(Optimizer2State):
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=0, amsgrad=False, optim_bits=32, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ super(Adam, self).__init__('adam', params, lr, betas, eps,
+ weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+
+class Adam8bit(Optimizer2State):
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=0, amsgrad=False, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ super(Adam8bit, self).__init__('adam', params, lr, betas, eps,
+ weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+
+class Adam32bit(Optimizer2State):
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=0, amsgrad=False, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ super(Adam32bit, self).__init__('adam', params, lr, betas, eps,
+ weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+
+
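A minimal usage sketch for the 8-bit Adam defined above (illustrative only, not part of the commit; the module, sizes, and training step are hypothetical placeholders, and the 8-bit state kernels in bitsandbytes.functional operate on CUDA tensors):

import torch
from bitsandbytes.optim import Adam8bit

# Hypothetical module: the 2048x2048 weight exceeds min_8bit_size (4096 elements), so its
# optimizer states are quantized to 8 bits; the 2048-element bias stays in 32-bit state.
model = torch.nn.Linear(2048, 2048).cuda()
optimizer = Adam8bit(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

x = torch.randn(16, 2048, device='cuda')
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()        # block-wise 8-bit update (block_wise=True is the default)
optimizer.zero_grad()
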
diff --git a/bitsandbytes/optim/lamb.py b/bitsandbytes/optim/lamb.py
new file mode 100644
index 0000000..b8d4b1e
--- /dev/null
+++ b/bitsandbytes/optim/lamb.py
@@ -0,0 +1,29 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import apex
+from bitsandbytes.optim.optimizer import Optimizer2State
+
+class LAMB(Optimizer2State):
+ def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=0, amsgrad=False, adam_w_mode=True, optim_bits=32, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
+ super(LAMB, self).__init__('lamb', params, lr, betas, eps,
+ weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=max_unorm)
+
+class LAMB8bit(Optimizer2State):
+ def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=0, amsgrad=False, adam_w_mode=True, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
+ super(LAMB8bit, self).__init__('lamb', params, lr, betas, eps,
+ weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=max_unorm)
+
+class LAMB32bit(Optimizer2State):
+ def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=0, amsgrad=False, adam_w_mode=True, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
+ super(LAMB32bit, self).__init__('lamb', params, lr, betas, eps,
+ weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=max_unorm)
+
+
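For reference, a hypothetical LAMB8bit invocation (illustrative, not part of the commit; note that lamb.py imports apex at module level, so apex must be installed for the import to succeed):

import torch
from bitsandbytes.optim import LAMB8bit

model = torch.nn.Linear(2048, 2048).cuda()   # placeholder module
# max_unorm caps the update norm relative to the parameter norm
# (the same trust-ratio clipping spelled out in PytorchLARS.step in lars.py below)
optimizer = LAMB8bit(model.parameters(), lr=1e-3, weight_decay=0.01, max_unorm=1.0)

model(torch.randn(8, 2048, device='cuda')).sum().backward()
optimizer.step()
optimizer.zero_grad()
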
diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py
new file mode 100644
index 0000000..40dede7
--- /dev/null
+++ b/bitsandbytes/optim/lars.py
@@ -0,0 +1,115 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+
+from torch.optim import Optimizer
+from bitsandbytes.optim.optimizer import Optimizer1State
+
+class LARS(Optimizer1State):
+ def __init__(self, params, lr, momentum=0, dampening=0,
+ weight_decay=0, nesterov=False, optim_bits=32, args=None,
+ min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
+ if momentum == 0:
+ raise NotImplementedError('LARS without momentum is not supported!')
+ super(LARS, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
+ weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
+
+class LARS8bit(Optimizer1State):
+ def __init__(self, params, lr, momentum=0, dampening=0,
+ weight_decay=0, nesterov=False, args=None,
+ min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
+ if momentum == 0:
+ raise NotImplementedError('LARS without momentum is not supported!')
+ super(LARS8bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
+ weight_decay, 8, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
+
+class LARS32bit(Optimizer1State):
+ def __init__(self, params, lr, momentum=0, dampening=0,
+ weight_decay=0, nesterov=False, args=None,
+ min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
+ if momentum == 0:
+ raise NotImplementedError('LARS without momentum is not supported!')
+ super(LARS32bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
+ weight_decay, 32, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
+
+
+class PytorchLARS(Optimizer):
+ def __init__(self, params, lr=0.01, momentum=0, dampening=0,
+ weight_decay=0, nesterov=False, max_unorm=0.02):
+ if lr < 0.0:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if momentum < 0.0:
+ raise ValueError("Invalid momentum value: {}".format(momentum))
+ if weight_decay < 0.0:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+
+ defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
+ weight_decay=weight_decay, nesterov=nesterov, max_unorm=max_unorm)
+ if nesterov and (momentum <= 0 or dampening != 0):
+ raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+ super(PytorchLARS, self).__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super(PytorchLARS, self).__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault('nesterov', False)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ params_with_grad = []
+ d_p_list = []
+ momentum_buffer_list = []
+ weight_decay = group['weight_decay']
+ momentum = group['momentum']
+ dampening = group['dampening']
+ nesterov = group['nesterov']
+ max_unorm = group['max_unorm']
+ lr = group['lr']
+
+ for p in group['params']:
+ if p.grad is None: continue
+
+ state = self.state[p]
+ d_p = p.grad
+ if weight_decay != 0:
+ d_p = d_p.add(p, alpha=weight_decay)
+
+ if momentum != 0:
+ buf = state.get('momentum_buffer', None)
+
+ if buf is None:
+ buf = torch.clone(d_p).detach()
+ state['momentum_buffer'] = buf
+ else:
+ buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
+
+ if nesterov:
+ update = d_p + buf*momentum
+ else:
+ update = buf
+
+ update_scale = 1.0
+ if max_unorm > 0.0:
+ assert p.dtype == torch.float32
+ pnorm = torch.norm(p.detach())
+ unorm = torch.norm(update)
+ if unorm > max_unorm*pnorm:
+ update_scale = max_unorm*pnorm/unorm
+
+ p.add_(update, alpha=-lr*update_scale)
+
+ return loss
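A sketch of the LARS variants above (illustrative, not part of the commit): a nonzero momentum is required, and max_unorm implements the trust-ratio clipping shown in PytorchLARS.step, rescaling the update whenever ||update|| > max_unorm * ||p||. The module below is a hypothetical placeholder.

import torch
from bitsandbytes.optim import LARS8bit, PytorchLARS

model = torch.nn.Linear(2048, 2048).cuda()   # placeholder module

# 8-bit LARS: momentum must be nonzero, otherwise NotImplementedError is raised
opt = LARS8bit(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4, max_unorm=0.02)

# PytorchLARS is the plain 32-bit reference implementation of the same update rule
ref = PytorchLARS(model.parameters(), lr=0.1, momentum=0.9, max_unorm=0.02)

model(torch.randn(8, 2048, device='cuda')).sum().backward()
opt.step()
opt.zero_grad()
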
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
new file mode 100644
index 0000000..6743c15
--- /dev/null
+++ b/bitsandbytes/optim/optimizer.py
@@ -0,0 +1,460 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+import bitsandbytes.functional as F
+
+from copy import deepcopy
+from itertools import chain
+from collections import defaultdict, abc as container_abcs
+
+class MockArgs(object):
+ def __init__(self, initial_data):
+ for key in initial_data:
+ setattr(self, key, initial_data[key])
+
+
+class GlobalOptimManager(object):
+ _instance = None
+
+ def __init__(self):
+ raise RuntimeError('Call get_instance() instead')
+
+ def initialize(self):
+ self.pid2config = {}
+ self.index2config = {}
+ self.optimizer = None
+ self.uses_config_override = False
+
+ @classmethod
+ def get_instance(cls):
+ if cls._instance is None:
+ cls._instance = cls.__new__(cls)
+ cls._instance.initialize()
+ return cls._instance
+
+ def register_parameters(self, params):
+ param_groups = list(params)
+ if not isinstance(param_groups[0], dict):
+ param_groups = [{'params': param_groups}]
+
+ for group_index, group in enumerate(param_groups):
+ for p_index, p in enumerate(group['params']):
+ if id(p) in self.pid2config:
+ self.index2config[(group_index, p_index)] = self.pid2config[id(p)]
+
+ def override_config(self, parameters, key=None, value=None, key_value_dict=None):
+ '''
+ Overrides initial optimizer config for specific parameters.
+
+ The key-value pairs of the optimizer config for the input parameters are overridden.
+ These can be either regular optimizer parameters such as "betas" or "lr", or
+ 8-bit specific parameters such as "optim_bits" or "percentile_clipping".
+
+ Parameters
+ ----------
+ parameters : torch.Tensor or list(torch.Tensors)
+ The input parameters.
+ key : str
+ The hyperparameter to override.
+ value : object
+ The value for the hyperparameter.
+ key_value_dict : dict
+ A dictionary with multiple key-values to override.
+ '''
+ self.uses_config_override = True
+ if isinstance(parameters, torch.nn.Parameter):
+ parameters = [parameters]
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ if key is not None and value is not None:
+ assert key_value_dict is None
+ key_value_dict = {key: value}
+
+ if key_value_dict is not None:
+ for p in parameters:
+ if id(p) in self.pid2config: self.pid2config[id(p)].update(key_value_dict)
+ else: self.pid2config[id(p)] = key_value_dict
+
+
+class Optimizer8bit(torch.optim.Optimizer):
+
+ def __init__(self, params, defaults, optim_bits=32):
+ super(Optimizer8bit, self).__init__(params, defaults)
+ self.checked_if_on_gpu = False
+ self.name2qmap = {}
+
+ self.mng = GlobalOptimManager.get_instance()
+ self.non_castable_tensor_keys = set(
+ ['qmap1', 'qmap2',
+ 'max1', 'max2',
+ 'new_max1', 'new_max2',
+ 'state1', 'state2',
+ 'gnorm_vec', 'absmax1', 'absmax2',
+ 'unorm_vec'])
+
+ if optim_bits == 8: self.fill_qmap()
+
+ def fill_qmap(self):
+ self.name2qmap['dynamic'] = F.create_dynamic_map(signed=True)
+ self.name2qmap['udynamic'] = F.create_dynamic_map(signed=False)
+
+ def __setstate__(self, state):
+ super(Optimizer8bit, self).__setstate__(state)
+
+
+ def load_state_dict(self, state_dict):
+ r"""Loads the optimizer state.
+
+ Args:
+ state_dict (dict): optimizer state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ # deepcopy, to be consistent with module API
+ state_dict = deepcopy(state_dict)
+ # Validate the state_dict
+ groups = self.param_groups
+ saved_groups = state_dict['param_groups']
+
+ if len(groups) != len(saved_groups):
+ raise ValueError("loaded state dict has a different number of "
+ "parameter groups")
+ param_lens = (len(g['params']) for g in groups)
+ saved_lens = (len(g['params']) for g in saved_groups)
+ if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
+ raise ValueError("loaded state dict contains a parameter group "
+ "that doesn't match the size of optimizer's group")
+
+ # Update the state
+ id_map = {old_id: p for old_id, p in
+ zip(chain.from_iterable((g['params'] for g in saved_groups)),
+ chain.from_iterable((g['params'] for g in groups)))}
+
+ def cast(param, value):
+ r"""Make a deep copy of value, casting all tensors to device of param."""
+ if isinstance(value, torch.Tensor):
+ # Floating-point types are a bit special here. They are the only ones
+ # that are assumed to always match the type of params.
+ if param.is_floating_point() and value.dtype != torch.uint8:
+ value = value.to(param.dtype)
+ return value
+ elif isinstance(value, dict):
+ for k, v in value.items():
+ if k in self.non_castable_tensor_keys:
+ value[k] = v.to(param.device)
+ else:
+ value[k] = cast(param, v)
+
+ return value
+ elif isinstance(value, container_abcs.Iterable):
+ return type(value)(cast(param, v) for v in value)
+ else:
+ return value
+
+ # Copy state assigned to params (and cast tensors to appropriate types).
+ # State that is not assigned to params is copied as is (needed for
+ # backward compatibility).
+ state = defaultdict(dict)
+ for k, v in state_dict['state'].items():
+ if k in id_map:
+ param = id_map[k]
+ state[param] = cast(param, v)
+ else:
+ state[k] = v
+
+ # Update parameter groups, setting their 'params' value
+ def update_group(group, new_group):
+ new_group['params'] = group['params']
+ return new_group
+ param_groups = [
+ update_group(g, ng) for g, ng in zip(groups, saved_groups)]
+ self.__setstate__({'state': state, 'param_groups': param_groups})
+
+ def to_gpu(self):
+ self.checked_if_on_gpu = True
+ for gindex, group in enumerate(self.param_groups):
+ for pindex, p in enumerate(group['params']):
+ if p in self.state:
+ values = self.state[p]
+ for k, v in values.items():
+ if isinstance(v, torch.Tensor):
+ self.state[p][k] = v.to(p.device)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ overflows = []
+
+ if not self.checked_if_on_gpu: self.to_gpu() # needed for fairseq pure fp16 training
+ for gindex, group in enumerate(self.param_groups):
+ for pindex, p in enumerate(group['params']):
+ if p.grad is None:
+ continue
+ state = self.state[p]
+ if len(state) == 0:
+ self.init_state(group, p, gindex, pindex)
+
+ self.update_step(group, p, gindex, pindex)
+
+ return loss
+
+ def get_config(self, gindex, pindex, group):
+ config = {}
+ config['betas'] = group['betas']
+ config['eps'] = group['eps']
+ config['weight_decay'] = group['weight_decay']
+ config['lr'] = group['lr']
+ config['optim_bits'] = self.args.optim_bits
+ config['min_8bit_size'] = self.args.min_8bit_size
+ config['percentile_clipping'] = self.args.percentile_clipping
+ config['block_wise'] = self.args.block_wise
+ config['max_unorm'] = self.args.max_unorm
+
+ if (gindex, pindex) in self.mng.index2config:
+ config.update(self.mng.index2config[(gindex, pindex)])
+ return config
+
+ def init_state(self, group, p, gindex, pindex):
+ raise NotImplementedError('The init_state method needs to be overridden')
+
+ def update_step(self, group, p, gindex, pindex):
+ raise NotImplementedError('The update_step method needs to be overridden')
+
+class Optimizer2State(Optimizer8bit):
+ def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=0.0, optim_bits=32, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if isinstance(betas, str):
+ betas = eval(betas)
+ print(betas, 'parsed')
+ for i in range(len(betas)):
+ if not 0.0 <= betas[i] < 1.0:
+ raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ defaults = dict(lr=lr, betas=betas, eps=eps,
+ weight_decay=weight_decay)
+ super(Optimizer2State, self).__init__(params, defaults, optim_bits)
+
+ if args is None:
+ args = {}
+ args['optim_bits'] = optim_bits
+ args['percentile_clipping'] = 100
+ args['min_8bit_size'] = min_8bit_size
+ args['percentile_clipping'] = percentile_clipping
+ args['block_wise'] = block_wise
+ args['max_unorm'] = max_unorm
+
+ self.args = MockArgs(args)
+ else:
+ self.args = args
+
+ self.optimizer_name = optimizer_name
+
+ @torch.no_grad()
+ def init_state(self, group, p, gindex, pindex):
+ config = self.get_config(gindex, pindex, group)
+
+ if config['optim_bits'] == 32:
+ dtype = torch.float32
+ elif config['optim_bits'] == 8:
+ dtype = torch.uint8
+ else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
+
+ if p.numel() < config['min_8bit_size']: dtype = torch.float32
+
+ state = self.state[p]
+ state['step'] = 0
+
+ if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+ state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
+ state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
+ elif dtype == torch.uint8:
+ if state['step'] == 0:
+ if 'dynamic' not in self.name2qmap: self.fill_qmap()
+ self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device)
+ self.name2qmap['udynamic'] = self.name2qmap['udynamic'].to(p.device)
+
+ state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
+ state['qmap1'] = self.name2qmap['dynamic']
+
+ state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
+ state['qmap2'] = self.name2qmap['udynamic']
+
+ if config['block_wise']:
+ n = p.numel()
+ blocks = n//2048
+ blocks += 1 if n % 2048 > 0 else 0
+
+ state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
+ state['absmax2'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
+ else:
+ state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state['max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state['new_max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+
+ if config['percentile_clipping'] < 100:
+ state['gnorm_vec'] = torch.zeros((100,), device=p.device)
+
+ if config['max_unorm'] > 0.0:
+ state['unorm_vec'] = torch.zeros((1,), device=p.device)
+
+ @torch.no_grad()
+ def update_step(self, group, p, gindex, pindex):
+ state = self.state[p]
+ grad = p.grad
+
+ config = self.get_config(gindex, pindex, group)
+
+ state['step'] += 1
+ step = state['step']
+
+ if config['percentile_clipping'] < 100:
+ current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping'])
+ else:
+ gnorm_scale = 1.0
+
+ if state['state1'].dtype == torch.float:
+ F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
+ state['state2'], config['betas'][1], config['weight_decay'], gnorm_scale,
+ state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
+
+ elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
+ F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
+ config['eps'], step, config['lr'],
+ state['qmap1'], state['qmap2'], state['max1'], state['max2'], state['new_max1'], state['new_max2'],
+ config['weight_decay'], gnorm_scale=gnorm_scale,
+ unorm_vec=state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
+
+ # swap maxes
+ state['max1'], state['new_max1'] = state['new_max1'], state['max1']
+ state['max2'], state['new_max2'] = state['new_max2'], state['max2']
+ elif state['state1'].dtype == torch.uint8 and config['block_wise']:
+ F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
+ config['eps'], step, config['lr'],
+ state['qmap1'], state['qmap2'], state['absmax1'], state['absmax2'],
+ config['weight_decay'], gnorm_scale=gnorm_scale)
+
+
+class Optimizer1State(Optimizer8bit):
+ def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.0), eps=1e-8,
+ weight_decay=0.0, optim_bits=32, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ for i in range(len(betas)):
+ if not 0.0 <= betas[i] < 1.0:
+ raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ defaults = dict(lr=lr, betas=betas, eps=eps,
+ weight_decay=weight_decay)
+ super(Optimizer1State, self).__init__(params, defaults, optim_bits)
+
+ if args is None:
+ args = {}
+ args['optim_bits'] = optim_bits
+ args['percentile_clipping'] = 100
+ args['min_8bit_size'] = min_8bit_size
+ args['percentile_clipping'] = percentile_clipping
+ args['block_wise'] = block_wise
+ args['max_unorm'] = max_unorm
+
+ self.args = MockArgs(args)
+ else:
+ self.args = args
+
+ self.optimizer_name = optimizer_name
+
+ @torch.no_grad()
+ def init_state(self, group, p, gindex, pindex):
+ config = self.get_config(gindex, pindex, group)
+
+ if config['optim_bits'] == 32:
+ dtype = torch.float32
+ elif config['optim_bits'] == 8:
+ dtype = torch.uint8
+ else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
+
+ if p.numel() < config['min_8bit_size']: dtype = torch.float32
+
+ state = self.state[p]
+ state['step'] = 0
+
+ if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+ state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
+ elif dtype == torch.uint8:
+ if state['step'] == 0:
+ if 'dynamic' not in self.name2qmap: self.fill_qmap()
+ self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device)
+
+ state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
+ state['qmap1'] = self.name2qmap['dynamic']
+
+ if config['block_wise']:
+ n = p.numel()
+ blocks = n//2048
+ blocks += 1 if n % 2048 > 0 else 0
+
+ state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
+ else:
+ state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+ state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+
+ if config['percentile_clipping'] < 100:
+ state['gnorm_vec'] = torch.zeros((100,), device=p.device)
+
+ if config['max_unorm'] > 0.0:
+ state['unorm_vec'] = torch.zeros((1,), device=p.device)
+
+
+ @torch.no_grad()
+ def update_step(self, group, p, gindex, pindex):
+ state = self.state[p]
+ grad = p.grad
+
+ config = self.get_config(gindex, pindex, group)
+
+ state['step'] += 1
+ step = state['step']
+
+ if config['percentile_clipping'] < 100:
+ current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping'])
+ else:
+ gnorm_scale = 1.0
+
+ if state['state1'].dtype == torch.float:
+ F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
+ None, 0.0, config['weight_decay'], gnorm_scale,
+ state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
+
+ elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
+ F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
+ config['eps'], step, config['lr'], state['qmap1'], None, state['max1'], None, state['new_max1'], None,
+ config['weight_decay'], gnorm_scale,
+ state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
+
+ state['max1'], state['new_max1'] = state['new_max1'], state['max1']
+ elif state['state1'].dtype == torch.uint8 and config['block_wise']:
+ F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
+ config['eps'], step, config['lr'],
+ state['qmap1'], None, state['absmax1'], None,
+ config['weight_decay'], gnorm_scale=gnorm_scale)
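To illustrate the per-parameter override machinery above, a hypothetical setup (not part of the commit): with this code, overrides are recorded by override_config, mapped to parameter positions by register_parameters, and picked up by get_config during step, so both calls happen before the optimizer is constructed and use the same parameter ordering. The model below is a placeholder.

import torch
from bitsandbytes.optim import Adam, GlobalOptimManager

model = torch.nn.Sequential(                 # hypothetical two-layer model
    torch.nn.Linear(2048, 2048),
    torch.nn.Linear(2048, 2048),
).cuda()

mng = GlobalOptimManager.get_instance()
# keep full 32-bit Adam state for the first layer's weight only
mng.override_config(model[0].weight, 'optim_bits', 32)
# map the recorded overrides to (group_index, param_index) positions
mng.register_parameters(model.parameters())

# everything else uses 8-bit state; gradients are clipped against the 95th percentile of
# the recent gradient-norm history kept in state['gnorm_vec']
optimizer = Adam(model.parameters(), lr=1e-3, optim_bits=8, percentile_clipping=95)
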
diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py
new file mode 100644
index 0000000..99b718e
--- /dev/null
+++ b/bitsandbytes/optim/rmsprop.py
@@ -0,0 +1,37 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+from bitsandbytes.optim.optimizer import Optimizer1State
+
+class RMSprop(Optimizer1State):
+ def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, optim_bits=32, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ if alpha == 0:
+ raise NotImplementedError('RMSprop with alpha==0.0 is not supported!')
+ if centered:
+ raise NotImplementedError('Centered RMSprop is not supported!')
+ super(RMSprop, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
+ weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+
+class RMSprop8bit(Optimizer1State):
+ def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ if alpha == 0:
+ raise NotImplementedError('RMSprop with alpha==0.0 is not supported!')
+ if centered:
+ raise NotImplementedError('Centered RMSprop is not supported!')
+ super(RMSprop8bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
+ weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+
+class RMSprop32bit(Optimizer1State):
+ def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+
+ if alpha == 0:
+ raise NotImplementedError('RMSprop with alpha==0.0 is not supported!')
+ if centered:
+ raise NotImplementedError('Centered RMSprop is not supported!')
+ super(RMSprop32bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
+ weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
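A hypothetical RMSprop8bit call (illustrative, not part of the commit): rmsprop.py maps (alpha, momentum) onto the two "betas" slots of Optimizer1State, and alpha==0 and centered=True are rejected. The module below is a placeholder.

import torch
from bitsandbytes.optim import RMSprop8bit

model = torch.nn.Linear(2048, 2048).cuda()   # placeholder module
optimizer = RMSprop8bit(model.parameters(), lr=1e-2, alpha=0.99, eps=1e-8)

model(torch.randn(8, 2048, device='cuda')).sum().backward()
optimizer.step()
optimizer.zero_grad()
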
diff --git a/bitsandbytes/optim/sgd.py b/bitsandbytes/optim/sgd.py
new file mode 100644
index 0000000..926d804
--- /dev/null
+++ b/bitsandbytes/optim/sgd.py
@@ -0,0 +1,32 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from bitsandbytes.optim.optimizer import Optimizer1State
+
+class SGD(Optimizer1State):
+ def __init__(self, params, lr, momentum=0, dampening=0,
+ weight_decay=0, nesterov=False, optim_bits=32, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ if momentum == 0:
+ raise NotImplementedError('SGD without momentum is not supported!')
+ super(SGD, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
+ weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+
+class SGD8bit(Optimizer1State):
+ def __init__(self, params, lr, momentum=0, dampening=0,
+ weight_decay=0, nesterov=False, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ if momentum == 0:
+ raise NotImplementedError('SGD without momentum is not supported!')
+ super(SGD8bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
+ weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+
+class SGD32bit(Optimizer1State):
+ def __init__(self, params, lr, momentum=0, dampening=0,
+ weight_decay=0, nesterov=False, args=None,
+ min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ if momentum == 0:
+ raise NotImplementedError('SGD without momentum is not supported!')
+ super(SGD32bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
+ weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
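Finally, a hypothetical SGD8bit call (illustrative, not part of the commit): sgd.py registers the update under the optimizer name 'momentum', and a nonzero momentum is required. The module below is a placeholder.

import torch
from bitsandbytes.optim import SGD8bit

model = torch.nn.Linear(2048, 2048).cuda()   # placeholder module
optimizer = SGD8bit(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

model(torch.randn(8, 2048, device='cuda')).sum().backward()
optimizer.step()
optimizer.zero_grad()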