From 7439924891496025edf60c9da6a782f362a50c70 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Tue, 5 Oct 2021 19:16:20 -0700 Subject: Initial commit --- bitsandbytes/optim/__init__.py | 10 + bitsandbytes/optim/adam.py | 28 +++ bitsandbytes/optim/lamb.py | 29 +++ bitsandbytes/optim/lars.py | 115 ++++++++++ bitsandbytes/optim/optimizer.py | 460 ++++++++++++++++++++++++++++++++++++++++ bitsandbytes/optim/rmsprop.py | 37 ++++ bitsandbytes/optim/sgd.py | 32 +++ 7 files changed, 711 insertions(+) create mode 100644 bitsandbytes/optim/__init__.py create mode 100644 bitsandbytes/optim/adam.py create mode 100644 bitsandbytes/optim/lamb.py create mode 100644 bitsandbytes/optim/lars.py create mode 100644 bitsandbytes/optim/optimizer.py create mode 100644 bitsandbytes/optim/rmsprop.py create mode 100644 bitsandbytes/optim/sgd.py (limited to 'bitsandbytes/optim') diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py new file mode 100644 index 0000000..92c83b1 --- /dev/null +++ b/bitsandbytes/optim/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from .adam import Adam, Adam8bit, Adam32bit +from .sgd import SGD, SGD8bit, SGD32bit +from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS +from .lamb import LAMB, LAMB8bit, LAMB32bit +from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit +from .optimizer import GlobalOptimManager diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py new file mode 100644 index 0000000..99a6d10 --- /dev/null +++ b/bitsandbytes/optim/adam.py @@ -0,0 +1,28 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from bitsandbytes.optim.optimizer import Optimizer2State + +class Adam(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0, amsgrad=False, optim_bits=32, args=None, + min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super(Adam, self).__init__('adam', params, lr, betas, eps, + weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise) + +class Adam8bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0, amsgrad=False, args=None, + min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super(Adam8bit, self).__init__('adam', params, lr, betas, eps, + weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise) + +class Adam32bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0, amsgrad=False, args=None, + min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super(Adam32bit, self).__init__('adam', params, lr, betas, eps, + weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise) + + diff --git a/bitsandbytes/optim/lamb.py b/bitsandbytes/optim/lamb.py new file mode 100644 index 0000000..b8d4b1e --- /dev/null +++ b/bitsandbytes/optim/lamb.py @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import apex +from bitsandbytes.optim.optimizer import Optimizer2State + +class LAMB(Optimizer2State): + def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0, amsgrad=False, adam_w_mode=True, optim_bits=32, args=None, + min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0): + super(LAMB, self).__init__('lamb', params, lr, betas, eps, + weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0) + +class LAMB8bit(Optimizer2State): + def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0, amsgrad=False, adam_w_mode=True, args=None, + min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0): + super(LAMB8bit, self).__init__('lamb', params, lr, betas, eps, + weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0) + +class LAMB32bit(Optimizer2State): + def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0, amsgrad=False, adam_w_mode=True, args=None, + min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0): + super(LAMB32bit, self).__init__('lamb', params, lr, betas, eps, + weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0) + + diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py new file mode 100644 index 0000000..40dede7 --- /dev/null +++ b/bitsandbytes/optim/lars.py @@ -0,0 +1,115 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import torch + +from torch.optim import Optimizer +from bitsandbytes.optim.optimizer import Optimizer1State + +class LARS(Optimizer1State): + def __init__(self, params, lr, momentum=0, dampening=0, + weight_decay=0, nesterov=False, optim_bits=32, args=None, + min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02): + if momentum == 0: + raise NotImplementError(f'LARS without momentum is not supported!') + super(LARS, self).__init__('lars', params, lr, (momentum, dampening), 0.0, + weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False) + +class LARS8bit(Optimizer1State): + def __init__(self, params, lr, momentum=0, dampening=0, + weight_decay=0, nesterov=False, args=None, + min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02): + if momentum == 0: + raise NotImplementError(f'LARS without momentum is not supported!') + super(LARS8bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0, + weight_decay, 8, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False) + +class LARS32bit(Optimizer1State): + def __init__(self, params, lr, momentum=0, dampening=0, + weight_decay=0, nesterov=False, args=None, + min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02): + if momentum == 0: + raise NotImplementError(f'LARS without momentum is not supported!') + super(LARS32bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0, + weight_decay, 32, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False) + + +class PytorchLARS(Optimizer): + def __init__(self, params, lr=0.01, momentum=0, dampening=0, + weight_decay=0, nesterov=False, max_unorm=0.02): + if lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if momentum < 0.0: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if weight_decay < 0.0: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + defaults = dict(lr=lr, momentum=momentum, dampening=dampening, + weight_decay=weight_decay, nesterov=nesterov, max_unorm=max_unorm) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super(PytorchLARS, self).__init__(params, defaults) + + def __setstate__(self, state): + super(PytorchLARS, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('nesterov', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + d_p_list = [] + momentum_buffer_list = [] + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + nesterov = group['nesterov'] + max_unorm = group['max_unorm'] + lr = group['lr'] + + for p in group['params']: + if p.grad is None: continue + + state = self.state[p] + d_p = p.grad + if weight_decay != 0: + d_p = d_p.add(param, alpha=weight_decay) + + if momentum != 0: + buf = state.get('momentum_buffer', None) + + if buf is None: + buf = torch.clone(d_p).detach() + state['momentum_buffer']= buf + else: + buf.mul_(momentum).add_(d_p, alpha=1 - dampening) + + if nesterov: + update = d_p + buf*momentum + else: + update = buf + + update_scale = 1.0 + if max_unorm > 0.0: + assert p.dtype == torch.float32 + pnorm = torch.norm(p.detach()) + unorm = torch.norm(update) + if unorm > max_unorm*pnorm: + update_scale = max_unorm*pnorm/unorm + + p.add_(update, alpha=-lr*update_scale) + + return loss diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py new file mode 100644 index 0000000..6743c15 --- /dev/null +++ b/bitsandbytes/optim/optimizer.py @@ -0,0 +1,460 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import torch +import bitsandbytes.functional as F + +from copy import deepcopy +from itertools import chain +from collections import defaultdict, abc as container_abcs + +class MockArgs(object): + def __init__(self, initial_data): + for key in initial_data: + setattr(self, key, initial_data[key]) + + +class GlobalOptimManager(object): + _instance = None + + def __init__(self): + raise RuntimeError('Call get_instance() instead') + + def initialize(self): + self.pid2config = {} + self.index2config = {} + self.optimizer = None + self.uses_config_override = False + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.initialize() + return cls._instance + + def register_parameters(self, params): + param_groups = list(params) + if not isinstance(param_groups[0], dict): + param_groups = [{'params': param_groups}] + + for group_index, group in enumerate(param_groups): + for p_index, p in enumerate(group['params']): + if id(p) in self.pid2config: + self.index2config[(group_index, p_index)] = self.pid2config[id(p)] + + def override_config(self, parameters, key=None, value=None, key_value_dict=None): + ''' + Overrides initial optimizer config for specific parameters. + + The key-values of the optimizer config for the input parameters are overidden + This can be both, optimizer parameters like "betas", or "lr" or it can be + 8-bit specific paramters like "optim_bits", "percentile_clipping". + + Parameters + ---------- + parameters : torch.Tensor or list(torch.Tensors) + The input parameters. + key : str + The hyperparamter to override. + value : object + The value for the hyperparamters. + key_value_dict : dict + A dictionary with multiple key-values to override. + ''' + self.uses_config_override = True + if isinstance(parameters, torch.nn.Parameter): + parameters = [parameters] + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + if key is not None and value is not None: + assert key_value_dict is None + key_value_dict = {key: value} + + if key_value_dict is not None: + for p in parameters: + if id(p) in self.pid2config:self.pid2config[id(p)].update(key_value_dict) + else: self.pid2config[id(p)] = key_value_dict + + +class Optimizer8bit(torch.optim.Optimizer): + + def __init__(self, params, defaults, optim_bits=32): + super(Optimizer8bit, self).__init__(params, defaults) + self.checked_if_on_gpu = False + self.name2qmap = {} + + self.mng = GlobalOptimManager.get_instance() + self.non_castable_tensor_keys = set( + ['qmap1', 'qmap2', + 'max1', 'max2', + 'new_max1', 'new_max2', + 'state1', 'state2', + 'gnorm_vec', 'absmax1', 'absmax2', + 'unorm_vec']) + + if optim_bits == 8: self.fill_qmap() + + def fill_qmap(self): + self.name2qmap['dynamic'] = F.create_dynamic_map(signed=True) + self.name2qmap['udynamic'] = F.create_dynamic_map(signed=False) + + def __setstate__(self, state): + super(Optimizer8bit, self).__setstate__(state) + + + def load_state_dict(self, state_dict): + r"""Loads the optimizer state. + + Args: + state_dict (dict): optimizer state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # deepcopy, to be consistent with module API + state_dict = deepcopy(state_dict) + # Validate the state_dict + groups = self.param_groups + saved_groups = state_dict['param_groups'] + + if len(groups) != len(saved_groups): + raise ValueError("loaded state dict has a different number of " + "parameter groups") + param_lens = (len(g['params']) for g in groups) + saved_lens = (len(g['params']) for g in saved_groups) + if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): + raise ValueError("loaded state dict contains a parameter group " + "that doesn't match the size of optimizer's group") + + # Update the state + id_map = {old_id: p for old_id, p in + zip(chain.from_iterable((g['params'] for g in saved_groups)), + chain.from_iterable((g['params'] for g in groups)))} + + def cast(param, value): + r"""Make a deep copy of value, casting all tensors to device of param.""" + if isinstance(value, torch.Tensor): + # Floating-point types are a bit special here. They are the only ones + # that are assumed to always match the type of params. + if param.is_floating_point() and value.dtype != torch.uint8: + value = value.to(param.dtype) + return value + elif isinstance(value, dict): + for k, v in value.items(): + if k in self.non_castable_tensor_keys: + value[k] = v.to(param.device) + else: + value[k] = cast(param, v) + + return value + elif isinstance(value, container_abcs.Iterable): + return type(value)(cast(param, v) for v in value) + else: + return value + + # Copy state assigned to params (and cast tensors to appropriate types). + # State that is not assigned to params is copied as is (needed for + # backward compatibility). + state = defaultdict(dict) + for k, v in state_dict['state'].items(): + if k in id_map: + param = id_map[k] + state[param] = cast(param, v) + else: + state[k] = v + + # Update parameter groups, setting their 'params' value + def update_group(group, new_group): + new_group['params'] = group['params'] + return new_group + param_groups = [ + update_group(g, ng) for g, ng in zip(groups, saved_groups)] + self.__setstate__({'state': state, 'param_groups': param_groups}) + + def to_gpu(self): + self.checked_if_on_gpu = True + for gindex, group in enumerate(self.param_groups): + for pindex, p in enumerate(group['params']): + if p in self.state: + values = self.state[p] + for k, v in values.items(): + if isinstance(v, torch.Tensor): + self.state[p][k] = v.to(p.device) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + overflows = [] + + if not self.checked_if_on_gpu: self.to_gpu() # needed for fairseq pure fp16 training + for gindex, group in enumerate(self.param_groups): + for pindex, p in enumerate(group['params']): + if p.grad is None: + continue + state = self.state[p] + if len(state) == 0: + self.init_state(group, p, gindex, pindex) + + self.update_step(group, p, gindex, pindex) + + return loss + + def get_config(self, gindex, pindex, group): + config = {} + config['betas'] = group['betas'] + config['eps'] = group['eps'] + config['weight_decay'] = group['weight_decay'] + config['lr'] = group['lr'] + config['optim_bits'] = self.args.optim_bits + config['min_8bit_size'] = self.args.min_8bit_size + config['percentile_clipping'] = self.args.percentile_clipping + config['block_wise'] = self.args.block_wise + config['max_unorm'] = self.args.max_unorm + + if (gindex, pindex) in self.mng.index2config: + config.update(self.mng.index2config[(gindex, pindex)]) + return config + + def init_state(self, group, p, gindex, pindex): + raise NotImplementedError(f'init_state method needs to be overidden') + + def update_step(self, group, p, gindex, pindex): + raise NotImplementedError(f'The update_step method needs to be overidden') + +class Optimizer2State(Optimizer8bit): + def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0.0, optim_bits=32, args=None, + min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if isinstance(betas, str): + betas = eval(betas) + print(betas, 'parsed') + for i in range(len(betas)): + if not 0.0 <= betas[i] < 1.0: + raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay) + super(Optimizer2State, self).__init__(params, defaults, optim_bits) + + if args is None: + args = {} + args['optim_bits'] = optim_bits + args['percentile_clipping'] = 100 + args['min_8bit_size'] = min_8bit_size + args['percentile_clipping'] = percentile_clipping + args['block_wise'] = block_wise + args['max_unorm'] = max_unorm + + self.args = MockArgs(args) + else: + self.args = args + + self.optimizer_name = optimizer_name + + @torch.no_grad() + def init_state(self, group, p, gindex, pindex): + config = self.get_config(gindex, pindex, group) + + if config['optim_bits'] == 32: + dtype = torch.float32 + elif config['optim_bits'] == 8: + dtype = torch.uint8 + else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}') + + if p.numel() < config['min_8bit_size']: dtype = torch.float32 + + state = self.state[p] + state['step'] = 0 + + if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): + state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device) + state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device) + elif dtype == torch.uint8: + if state['step'] == 0: + if 'dynamic' not in self.name2qmap: self.fill_qmap() + self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device) + self.name2qmap['udynamic'] = self.name2qmap['udynamic'].to(p.device) + + state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device) + state['qmap1'] = self.name2qmap['dynamic'] + + state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device) + state['qmap2'] = self.name2qmap['udynamic'] + + if config['block_wise']: + n = p.numel() + blocks = n//2048 + blocks += 1 if n % 2048 > 0 else 0 + + state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) + state['absmax2'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) + else: + state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state['max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state['new_max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device) + + if config['percentile_clipping'] < 100: + state['gnorm_vec'] = torch.zeros((100,), device=p.device) + + if config['max_unorm'] > 0.0: + state['unorm_vec'] = torch.zeros((1,), device=p.device) + + @torch.no_grad() + def update_step(self, group, p, gindex, pindex): + state = self.state[p] + grad = p.grad + + config = self.get_config(gindex, pindex, group) + + state['step'] += 1 + step = state['step'] + + if config['percentile_clipping'] < 100: + current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping']) + else: + gnorm_scale = 1.0 + + if state['state1'].dtype == torch.float: + F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'], + state['state2'], config['betas'][1], config['weight_decay'], gnorm_scale, + state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm']) + + elif state['state1'].dtype == torch.uint8 and not config['block_wise']: + F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1], + config['eps'], step, config['lr'], + state['qmap1'], state['qmap2'], state['max1'], state['max2'], state['new_max1'], state['new_max2'], + config['weight_decay'], gnorm_scale=gnorm_scale, + unorm_vec=state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm']) + + # swap maxes + state['max1'], state['new_max1'] = state['new_max1'], state['max1'] + state['max2'], state['new_max2'] = state['new_max2'], state['max2'] + elif state['state1'].dtype == torch.uint8 and config['block_wise']: + F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1], + config['eps'], step, config['lr'], + state['qmap1'], state['qmap2'], state['absmax1'], state['absmax2'], + config['weight_decay'], gnorm_scale=gnorm_scale) + + +class Optimizer1State(Optimizer8bit): + def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.0), eps=1e-8, + weight_decay=0.0, optim_bits=32, args=None, + min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + for i in range(len(betas)): + if not 0.0 <= betas[i] < 1.0: + raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay) + super(Optimizer1State, self).__init__(params, defaults, optim_bits) + + if args is None: + args = {} + args['optim_bits'] = optim_bits + args['percentile_clipping'] = 100 + args['min_8bit_size'] = min_8bit_size + args['percentile_clipping'] = percentile_clipping + args['block_wise'] = block_wise + args['max_unorm'] = max_unorm + + self.args = MockArgs(args) + else: + self.args = args + + self.optimizer_name = optimizer_name + + @torch.no_grad() + def init_state(self, group, p, gindex, pindex): + config = self.get_config(gindex, pindex, group) + + if config['optim_bits'] == 32: + dtype = torch.float32 + elif config['optim_bits'] == 8: + dtype = torch.uint8 + else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}') + + if p.numel() < config['min_8bit_size']: dtype = torch.float32 + + state = self.state[p] + state['step'] = 0 + + if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): + state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device) + elif dtype == torch.uint8: + if state['step'] == 0: + if 'dynamic' not in self.name2qmap: self.fill_qmap() + self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device) + + state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device) + state['qmap1'] = self.name2qmap['dynamic'] + + if config['block_wise']: + n = p.numel() + blocks = n//2048 + blocks += 1 if n % 2048 > 0 else 0 + + state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) + else: + state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device) + + if config['percentile_clipping'] < 100: + state['gnorm_vec'] = torch.zeros((100,), device=p.device) + + if config['max_unorm'] > 0.0: + state['unorm_vec'] = torch.zeros((1,), device=p.device) + + + @torch.no_grad() + def update_step(self, group, p, gindex, pindex): + state = self.state[p] + grad = p.grad + + config = self.get_config(gindex, pindex, group) + + state['step'] += 1 + step = state['step'] + + if config['percentile_clipping'] < 100: + current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping']) + else: + gnorm_scale = 1.0 + + if state['state1'].dtype == torch.float: + F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'], + None, 0.0, config['weight_decay'], gnorm_scale, + state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm']) + + elif state['state1'].dtype == torch.uint8 and not config['block_wise']: + F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1], + config['eps'], step, config['lr'], state['qmap1'], None, state['max1'], None, state['new_max1'], None, + config['weight_decay'], gnorm_scale, + state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm']) + + state['max1'], state['new_max1'] = state['new_max1'], state['max1'] + elif state['state1'].dtype == torch.uint8 and config['block_wise']: + F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1], + config['eps'], step, config['lr'], + state['qmap1'], None, state['absmax1'], None, + config['weight_decay'], gnorm_scale=gnorm_scale) diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py new file mode 100644 index 0000000..99b718e --- /dev/null +++ b/bitsandbytes/optim/rmsprop.py @@ -0,0 +1,37 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import torch +from bitsandbytes.optim.optimizer import Optimizer1State + +class RMSprop(Optimizer1State): + def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, optim_bits=32, args=None, + min_8bit_size=4096, percentile_clipping=100, block_wise=True): + if alpha == 0: + raise NotImplementError(f'RMSprop with alpha==0.0 is not supported!') + if centered: + raise NotImplementError(f'Centered RMSprop is not supported!') + super(RMSprop, self).__init__('rmsprop', params, lr, (alpha, momentum), eps, + weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise) + +class RMSprop8bit(Optimizer1State): + def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None, + min_8bit_size=4096, percentile_clipping=100, block_wise=True): + if alpha == 0: + raise NotImplementError(f'RMSprop with alpha==0.0 is not supported!') + if centered: + raise NotImplementError(f'Centered RMSprop is not supported!') + super(RMSprop8bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps, + weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise) + +class RMSprop32bit(Optimizer1State): + def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None, + min_8bit_size=4096, percentile_clipping=100, block_wise=True): + + if alpha == 0: + raise NotImplementError(f'RMSprop with alpha==0.0 is not supported!') + if centered: + raise NotImplementError(f'Centered RMSprop is not supported!') + super(RMSprop32bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps, + weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise) diff --git a/bitsandbytes/optim/sgd.py b/bitsandbytes/optim/sgd.py new file mode 100644 index 0000000..926d804 --- /dev/null +++ b/bitsandbytes/optim/sgd.py @@ -0,0 +1,32 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from bitsandbytes.optim.optimizer import Optimizer1State + +class SGD(Optimizer1State): + def __init__(self, params, lr, momentum=0, dampening=0, + weight_decay=0, nesterov=False, optim_bits=32, args=None, + min_8bit_size=4096, percentile_clipping=100, block_wise=True): + if momentum == 0: + raise NotImplementError(f'SGD without momentum is not supported!') + super(SGD, self).__init__('momentum', params, lr, (momentum, dampening), 0.0, + weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise) + +class SGD8bit(Optimizer1State): + def __init__(self, params, lr, momentum=0, dampening=0, + weight_decay=0, nesterov=False, args=None, + min_8bit_size=4096, percentile_clipping=100, block_wise=True): + if momentum == 0: + raise NotImplementError(f'SGD without momentum is not supported!') + super(SGD8bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0, + weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise) + +class SGD32bit(Optimizer1State): + def __init__(self, params, lr, momentum=0, dampening=0, + weight_decay=0, nesterov=False, args=None, + min_8bit_size=4096, percentile_clipping=100, block_wise=True): + if momentum == 0: + raise NotImplementError(f'SGD without momentum is not supported!') + super(SGD32bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0, + weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise) -- cgit v1.2.3