From bfa0e33294f2b1dc25e65a33be2397f989824298 Mon Sep 17 00:00:00 2001 From: Titus von Koeller Date: Mon, 1 Aug 2022 03:31:48 -0700 Subject: ran black and isort for coherent code formatting --- bitsandbytes/optim/__init__.py | 6 +- bitsandbytes/optim/adagrad.py | 114 ++++++-- bitsandbytes/optim/adam.py | 179 +++++++++---- bitsandbytes/optim/adamw.py | 104 ++++++-- bitsandbytes/optim/lamb.py | 117 +++++++-- bitsandbytes/optim/lars.py | 167 +++++++++--- bitsandbytes/optim/optimizer.py | 565 +++++++++++++++++++++++++++------------- bitsandbytes/optim/rmsprop.py | 115 ++++++-- bitsandbytes/optim/sgd.py | 109 ++++++-- 9 files changed, 1094 insertions(+), 382 deletions(-) (limited to 'bitsandbytes/optim') diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py index 42b5bc0..a76d717 100644 --- a/bitsandbytes/optim/__init__.py +++ b/bitsandbytes/optim/__init__.py @@ -1,6 +1,6 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from bitsandbytes.cextension import COMPILED_WITH_CUDA diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py index 4f51250..43e3973 100644 --- a/bitsandbytes/optim/adagrad.py +++ b/bitsandbytes/optim/adagrad.py @@ -1,12 +1,25 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from bitsandbytes.optim.optimizer import Optimizer1State + class Adagrad(Optimizer1State): - def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10, - optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-2, + lr_decay=0, + weight_decay=0, + initial_accumulator_value=0, + eps=1e-10, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= weight_decay: @@ -14,15 +27,39 @@ class Adagrad(Optimizer1State): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if initial_accumulator_value != 0.0: - raise ValueError('Initial accumulator value != 0.0 not supported!') + raise ValueError("Initial accumulator value != 0.0 not supported!") if lr_decay != 0.0: - raise ValueError('Lr Decay != 0.0 not supported!') - super(Adagrad, self).__init__('adagrad', params, lr, (0.0, 0.0), eps, - weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise) + raise ValueError("Lr Decay != 0.0 not supported!") + super(Adagrad, self).__init__( + "adagrad", + params, + lr, + (0.0, 0.0), + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class Adagrad8bit(Optimizer1State): - def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10, - optim_bits=8, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-2, + lr_decay=0, + weight_decay=0, + initial_accumulator_value=0, + eps=1e-10, + optim_bits=8, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= weight_decay: @@ -30,16 +67,40 @@ class Adagrad8bit(Optimizer1State): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if initial_accumulator_value != 0.0: - raise ValueError('Initial accumulator value != 0.0 not supported!') + raise ValueError("Initial accumulator value != 0.0 not supported!") if lr_decay != 0.0: - raise ValueError('Lr Decay != 0.0 not supported!') + raise ValueError("Lr Decay != 0.0 not supported!") assert block_wise - super(Adagrad8bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps, - weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise) + super(Adagrad8bit, self).__init__( + "adagrad", + params, + lr, + (0.0, 0.0), + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class Adagrad32bit(Optimizer1State): - def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10, - optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-2, + lr_decay=0, + weight_decay=0, + initial_accumulator_value=0, + eps=1e-10, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= weight_decay: @@ -47,8 +108,19 @@ class Adagrad32bit(Optimizer1State): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if initial_accumulator_value != 0.0: - raise ValueError('Initial accumulator value != 0.0 not supported!') + raise ValueError("Initial accumulator value != 0.0 not supported!") if lr_decay != 0.0: - raise ValueError('Lr Decay != 0.0 not supported!') - super(Adagrad32bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps, - weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise) + raise ValueError("Lr Decay != 0.0 not supported!") + super(Adagrad32bit, self).__init__( + "adagrad", + params, + lr, + (0.0, 0.0), + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py index ed1b9f0..5cfaa28 100644 --- a/bitsandbytes/optim/adam.py +++ b/bitsandbytes/optim/adam.py @@ -1,6 +1,6 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import math @@ -8,29 +8,97 @@ import os import torch import torch.distributed as dist -from bitsandbytes.optim.optimizer import Optimizer2State + import bitsandbytes.functional as F +from bitsandbytes.optim.optimizer import Optimizer2State + class Adam(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super(Adam, self).__init__('adam', params, lr, betas, eps, - weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super(Adam, self).__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class Adam8bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super(Adam8bit, self).__init__('adam', params, lr, betas, eps, - weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super(Adam8bit, self).__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class Adam32bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super(Adam32bit, self).__init__('adam', params, lr, betas, eps, - weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super(Adam32bit, self).__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) class AnalysisAdam(torch.optim.Optimizer): @@ -68,8 +136,8 @@ class AnalysisAdam(torch.optim.Optimizer): eps=1e-8, weight_decay=0, amsgrad=False, - bnb_analysis='dynamic-blockwise', - savedir=None + bnb_analysis="dynamic-blockwise", + savedir=None, ): defaults = dict( lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad @@ -124,9 +192,13 @@ class AnalysisAdam(torch.optim.Optimizer): state["exp_avg"] = torch.zeros_like(p_data_fp32) # Exponential moving average of squared gradient values state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) - state['abserrors'] = torch.zeros((256, 256), device=p_data_fp32.device) - state['relerrors'] = torch.zeros((256, 256), device=p_data_fp32.device) - state['counts'] = torch.zeros((256, 256), device=p_data_fp32.device) + state["abserrors"] = torch.zeros( + (256, 256), device=p_data_fp32.device + ) + state["relerrors"] = torch.zeros( + (256, 256), device=p_data_fp32.device + ) + state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device) if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. values state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32) @@ -143,9 +215,9 @@ class AnalysisAdam(torch.optim.Optimizer): bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 - e = state['abserrors'] - rele = state['relerrors'] - counts = state['counts'] + e = state["abserrors"] + rele = state["relerrors"] + counts = state["counts"] if group["weight_decay"] != 0: p_data_fp32.add_( @@ -156,77 +228,84 @@ class AnalysisAdam(torch.optim.Optimizer): if amsgrad: max_exp_avg_sq = state["max_exp_avg_sq"] - # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) denom = exp_avg_sq.sqrt().add_(group["eps"]) - update_fp32 = exp_avg/denom + update_fp32 = exp_avg / denom - if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000*1000: + if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000: # embedding layer or too small - p_data_fp32 += -step_size*update_fp32 + p_data_fp32 += -step_size * update_fp32 else: - if self.analysis == 'dynamic-blockwise': + if self.analysis == "dynamic-blockwise": code1 = F.create_dynamic_map(signed=True).to(p.device) code2 = F.create_dynamic_map(signed=False).to(p.device) C1, S1 = F.quantize_blockwise(exp_avg, code=code1) state1 = F.dequantize_blockwise(C1, S1) C2, S2 = F.quantize_blockwise(exp_avg_sq, code=code2) state2 = F.dequantize_blockwise(C2, S2) - elif self.analysis == 'dynamic': + elif self.analysis == "dynamic": code1 = F.create_dynamic_map(signed=True).to(p.device) code2 = F.create_dynamic_map(signed=False).to(p.device) C1, S1 = F.quantize(exp_avg, code=code1) state1 = F.dequantize(C1, S1) C2, S2 = F.quantize(exp_avg_sq, code=code2) state2 = F.dequantize(C2, S2) - elif self.analysis == 'linear': + elif self.analysis == "linear": code1 = F.create_linear_map(signed=True).to(p.device) code2 = F.create_linear_map(signed=False).to(p.device) C1, S1 = F.quantize(exp_avg, code=code1) state1 = F.dequantize(C1, S1) C2, S2 = F.quantize(exp_avg_sq, code=code2) state2 = F.dequantize(C2, S2) - elif self.analysis == 'quantile': + elif self.analysis == "quantile": code1 = F.estimate_quantiles(exp_avg) code2 = F.estimate_quantiles(exp_avg_sq) C1 = F.quantize_no_absmax(exp_avg, code=code1) state1 = F.dequantize_no_absmax(C1, code1) C2 = F.quantize_no_absmax(exp_avg_sq, code=code2) state2 = F.dequantize_no_absmax(C2, code2) - elif self.analysis == 'my-quantization-routine': + elif self.analysis == "my-quantization-routine": pass # 1. get code # 2. quantize # 3. dequantize # Error will be calculated automatically! else: - raise ValueError(f'Invalid analysis value: {self.analysis}!') + raise ValueError(f"Invalid analysis value: {self.analysis}!") denom = state2.sqrt().add_(group["eps"]) - update_8bit = state1/denom + update_8bit = state1 / denom - abserr = torch.abs(update_8bit-update_fp32) - relerr = abserr/torch.abs(update_fp32+1e-6) + abserr = torch.abs(update_8bit - update_fp32) + relerr = abserr / torch.abs(update_fp32 + 1e-6) C1, C2 = C1.int(), C2.int() F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr) F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr) - F.histogram_scatter_add_2d(counts, C1.int(), C2.int(), torch.ones_like(abserr)) - - p_data_fp32 += -step_size*update_fp32 + F.histogram_scatter_add_2d( + counts, C1.int(), C2.int(), torch.ones_like(abserr) + ) + p_data_fp32 += -step_size * update_fp32 if not dist.is_initialized() or dist.get_rank() == 0: - if self.savedir != '' and state['step'] % 100 == 0: - if not os.path.exists(self.savedir): os.makedirs(self.savedir) - shapestr = '_'.join([str(dim) for dim in p_data_fp32.shape]) - pathe = os.path.join(self.savedir, f'{p_id}_{shapestr}_abserr.pkl') - pathrele = os.path.join(self.savedir, f'{p_id}_{shapestr}_relerr.pkl') - pathcounts = os.path.join(self.savedir, f'{p_id}_{shapestr}_counts.pkl') + if self.savedir != "" and state["step"] % 100 == 0: + if not os.path.exists(self.savedir): + os.makedirs(self.savedir) + shapestr = "_".join([str(dim) for dim in p_data_fp32.shape]) + pathe = os.path.join( + self.savedir, f"{p_id}_{shapestr}_abserr.pkl" + ) + pathrele = os.path.join( + self.savedir, f"{p_id}_{shapestr}_relerr.pkl" + ) + pathcounts = os.path.join( + self.savedir, f"{p_id}_{shapestr}_counts.pkl" + ) torch.save(e, pathe) torch.save(rele, pathrele) torch.save(counts, pathcounts) @@ -234,6 +313,4 @@ class AnalysisAdam(torch.optim.Optimizer): if p.data.dtype in {torch.float16, torch.bfloat16}: p.data.copy_(p_data_fp32) - - return loss diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py index c4f0355..d0b3bde 100644 --- a/bitsandbytes/optim/adamw.py +++ b/bitsandbytes/optim/adamw.py @@ -1,27 +1,93 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from bitsandbytes.optim.optimizer import Optimizer2State + class AdamW(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=1e-2, amsgrad=False, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super(AdamW, self).__init__('adam', params, lr, betas, eps, - weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super(AdamW, self).__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class AdamW8bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=1e-2, amsgrad=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super(AdamW8bit, self).__init__('adam', params, lr, betas, eps, - weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super(AdamW8bit, self).__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) -class AdamW32bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=1e-2, amsgrad=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super(AdamW32bit, self).__init__('adam', params, lr, betas, eps, - weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise) +class AdamW32bit(Optimizer2State): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super(AdamW32bit, self).__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) diff --git a/bitsandbytes/optim/lamb.py b/bitsandbytes/optim/lamb.py index 58cc13d..8f365f7 100644 --- a/bitsandbytes/optim/lamb.py +++ b/bitsandbytes/optim/lamb.py @@ -1,28 +1,105 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from bitsandbytes.optim.optimizer import Optimizer2State + class LAMB(Optimizer2State): - def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False, adam_w_mode=True, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0): - super(LAMB, self).__init__('lamb', params, lr, betas, eps, - weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0) + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + adam_w_mode=True, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=False, + max_unorm=1.0, + ): + super(LAMB, self).__init__( + "lamb", + params, + lr, + betas, + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + max_unorm=1.0, + ) -class LAMB8bit(Optimizer2State): - def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False, adam_w_mode=True, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0): - super(LAMB8bit, self).__init__('lamb', params, lr, betas, eps, - weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0) -class LAMB32bit(Optimizer2State): - def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False, adam_w_mode=True, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0): - super(LAMB32bit, self).__init__('lamb', params, lr, betas, eps, - weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0) +class LAMB8bit(Optimizer2State): + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + adam_w_mode=True, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=False, + max_unorm=1.0, + ): + super(LAMB8bit, self).__init__( + "lamb", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + max_unorm=1.0, + ) +class LAMB32bit(Optimizer2State): + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + adam_w_mode=True, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=False, + max_unorm=1.0, + ): + super(LAMB32bit, self).__init__( + "lamb", + params, + lr, + betas, + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + max_unorm=1.0, + ) diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py index 912520d..c6cf5c6 100644 --- a/bitsandbytes/optim/lars.py +++ b/bitsandbytes/optim/lars.py @@ -1,43 +1,121 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import torch - from torch.optim import Optimizer + from bitsandbytes.optim.optimizer import Optimizer1State + class LARS(Optimizer1State): - def __init__(self, params, lr, momentum=0, dampening=0, - weight_decay=0, nesterov=False, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + max_unorm=0.02, + ): if momentum == 0: - raise NotImplementedError(f'LARS without momentum is not supported!') - super(LARS, self).__init__('lars', params, lr, (momentum, dampening), 0.0, - weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False) + raise NotImplementedError(f"LARS without momentum is not supported!") + super(LARS, self).__init__( + "lars", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + max_unorm=max_unorm, + block_wise=False, + ) + class LARS8bit(Optimizer1State): - def __init__(self, params, lr, momentum=0, dampening=0, - weight_decay=0, nesterov=False, args=None, - min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + max_unorm=0.02, + ): if momentum == 0: - raise NotImplementedError(f'LARS without momentum is not supported!') - super(LARS8bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0, - weight_decay, 8, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False) + raise NotImplementedError(f"LARS without momentum is not supported!") + super(LARS8bit, self).__init__( + "lars", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + max_unorm=max_unorm, + block_wise=False, + ) + class LARS32bit(Optimizer1State): - def __init__(self, params, lr, momentum=0, dampening=0, - weight_decay=0, nesterov=False, args=None, - min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + max_unorm=0.02, + ): if momentum == 0: - raise NotImplementedError(f'LARS without momentum is not supported!') - super(LARS32bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0, - weight_decay, 32, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False) + raise NotImplementedError(f"LARS without momentum is not supported!") + super(LARS32bit, self).__init__( + "lars", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + max_unorm=max_unorm, + block_wise=False, + ) class PytorchLARS(Optimizer): - def __init__(self, params, lr=0.01, momentum=0, dampening=0, - weight_decay=0, nesterov=False, max_unorm=0.02): + def __init__( + self, + params, + lr=0.01, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + max_unorm=0.02, + ): if lr < 0.0: raise ValueError("Invalid learning rate: {}".format(lr)) if momentum < 0.0: @@ -45,8 +123,14 @@ class PytorchLARS(Optimizer): if weight_decay < 0.0: raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - defaults = dict(lr=lr, momentum=momentum, dampening=dampening, - weight_decay=weight_decay, nesterov=nesterov, max_unorm=max_unorm) + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + max_unorm=max_unorm, + ) if nesterov and (momentum <= 0 or dampening != 0): raise ValueError("Nesterov momentum requires a momentum and zero dampening") super(PytorchLARS, self).__init__(params, defaults) @@ -54,7 +138,7 @@ class PytorchLARS(Optimizer): def __setstate__(self, state): super(PytorchLARS, self).__setstate__(state) for group in self.param_groups: - group.setdefault('nesterov', False) + group.setdefault("nesterov", False) @torch.no_grad() def step(self, closure=None): @@ -73,15 +157,16 @@ class PytorchLARS(Optimizer): params_with_grad = [] d_p_list = [] momentum_buffer_list = [] - weight_decay = group['weight_decay'] - momentum = group['momentum'] - dampening = group['dampening'] - nesterov = group['nesterov'] - max_unorm = group['max_unorm'] - lr = group['lr'] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + dampening = group["dampening"] + nesterov = group["nesterov"] + max_unorm = group["max_unorm"] + lr = group["lr"] - for p in group['params']: - if p.grad is None: continue + for p in group["params"]: + if p.grad is None: + continue state = self.state[p] d_p = p.grad @@ -89,16 +174,16 @@ class PytorchLARS(Optimizer): d_p = d_p.add(param, alpha=weight_decay) if momentum != 0: - buf = state.get('momentum_buffer', None) + buf = state.get("momentum_buffer", None) if buf is None: buf = torch.clone(d_p).detach() - state['momentum_buffer']= buf + state["momentum_buffer"] = buf else: buf.mul_(momentum).add_(d_p, alpha=1 - dampening) if nesterov: - update = d_p + buf*momentum + update = d_p + buf * momentum else: update = buf @@ -107,9 +192,9 @@ class PytorchLARS(Optimizer): assert p.dtype == torch.float32 pnorm = torch.norm(p.detach()) unorm = torch.norm(update) - if unorm > max_unorm*pnorm: - update_scale = max_unorm*pnorm/unorm + if unorm > max_unorm * pnorm: + update_scale = max_unorm * pnorm / unorm - p.add_(update, alpha=-lr*update_scale) + p.add_(update, alpha=-lr * update_scale) return loss diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index 5a5bb1e..b942e34 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -1,13 +1,16 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from collections import abc as container_abcs +from collections import defaultdict +from copy import deepcopy +from itertools import chain + import torch + import bitsandbytes.functional as F -from copy import deepcopy -from itertools import chain -from collections import defaultdict, abc as container_abcs class MockArgs(object): def __init__(self, initial_data): @@ -19,7 +22,7 @@ class GlobalOptimManager(object): _instance = None def __init__(self): - raise RuntimeError('Call get_instance() instead') + raise RuntimeError("Call get_instance() instead") def initialize(self): self.pid2config = {} @@ -38,15 +41,15 @@ class GlobalOptimManager(object): def register_parameters(self, params): param_groups = list(params) if not isinstance(param_groups[0], dict): - param_groups = [{'params': param_groups}] + param_groups = [{"params": param_groups}] for group_index, group in enumerate(param_groups): - for p_index, p in enumerate(group['params']): + for p_index, p in enumerate(group["params"]): if id(p) in self.pid2config: self.index2config[(group_index, p_index)] = self.pid2config[id(p)] def override_config(self, parameters, key=None, value=None, key_value_dict=None): - ''' + """ Overrides initial optimizer config for specific parameters. The key-values of the optimizer config for the input parameters are overidden @@ -63,7 +66,7 @@ class GlobalOptimManager(object): The value for the hyperparamters. key_value_dict : dict A dictionary with multiple key-values to override. - ''' + """ self.uses_config_override = True if isinstance(parameters, torch.nn.Parameter): parameters = [parameters] @@ -75,16 +78,16 @@ class GlobalOptimManager(object): if key_value_dict is not None: for p in parameters: - if id(p) in self.pid2config:self.pid2config[id(p)].update(key_value_dict) - else: self.pid2config[id(p)] = key_value_dict + if id(p) in self.pid2config: + self.pid2config[id(p)].update(key_value_dict) + else: + self.pid2config[id(p)] = key_value_dict def register_module_override(self, module, param_name, config): self.module_weight_config_triple.append((module, param_name, config)) - class Optimizer8bit(torch.optim.Optimizer): - def __init__(self, params, defaults, optim_bits=32): super(Optimizer8bit, self).__init__(params, defaults) self.initialized = False @@ -92,23 +95,32 @@ class Optimizer8bit(torch.optim.Optimizer): self.mng = GlobalOptimManager.get_instance() self.non_castable_tensor_keys = set( - ['qmap1', 'qmap2', - 'max1', 'max2', - 'new_max1', 'new_max2', - 'state1', 'state2', - 'gnorm_vec', 'absmax1', 'absmax2', - 'unorm_vec']) - - if optim_bits == 8: self.fill_qmap() + [ + "qmap1", + "qmap2", + "max1", + "max2", + "new_max1", + "new_max2", + "state1", + "state2", + "gnorm_vec", + "absmax1", + "absmax2", + "unorm_vec", + ] + ) + + if optim_bits == 8: + self.fill_qmap() def fill_qmap(self): - self.name2qmap['dynamic'] = F.create_dynamic_map(signed=True) - self.name2qmap['udynamic'] = F.create_dynamic_map(signed=False) + self.name2qmap["dynamic"] = F.create_dynamic_map(signed=True) + self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False) def __setstate__(self, state): super(Optimizer8bit, self).__setstate__(state) - def load_state_dict(self, state_dict): r"""Loads the optimizer state. @@ -120,21 +132,28 @@ class Optimizer8bit(torch.optim.Optimizer): state_dict = deepcopy(state_dict) # Validate the state_dict groups = self.param_groups - saved_groups = state_dict['param_groups'] + saved_groups = state_dict["param_groups"] if len(groups) != len(saved_groups): - raise ValueError("loaded state dict has a different number of " - "parameter groups") - param_lens = (len(g['params']) for g in groups) - saved_lens = (len(g['params']) for g in saved_groups) + raise ValueError( + "loaded state dict has a different number of " "parameter groups" + ) + param_lens = (len(g["params"]) for g in groups) + saved_lens = (len(g["params"]) for g in saved_groups) if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): - raise ValueError("loaded state dict contains a parameter group " - "that doesn't match the size of optimizer's group") + raise ValueError( + "loaded state dict contains a parameter group " + "that doesn't match the size of optimizer's group" + ) # Update the state - id_map = {old_id: p for old_id, p in - zip(chain.from_iterable((g['params'] for g in saved_groups)), - chain.from_iterable((g['params'] for g in groups)))} + id_map = { + old_id: p + for old_id, p in zip( + chain.from_iterable((g["params"] for g in saved_groups)), + chain.from_iterable((g["params"] for g in groups)), + ) + } def cast(param, value): r"""Make a deep copy of value, casting all tensors to device of param.""" @@ -161,7 +180,7 @@ class Optimizer8bit(torch.optim.Optimizer): # State that is not assigned to params is copied as is (needed for # backward compatibility). state = defaultdict(dict) - for k, v in state_dict['state'].items(): + for k, v in state_dict["state"].items(): if k in id_map: param = id_map[k] state[param] = cast(param, v) @@ -170,15 +189,15 @@ class Optimizer8bit(torch.optim.Optimizer): # Update parameter groups, setting their 'params' value def update_group(group, new_group): - new_group['params'] = group['params'] + new_group["params"] = group["params"] return new_group - param_groups = [ - update_group(g, ng) for g, ng in zip(groups, saved_groups)] - self.__setstate__({'state': state, 'param_groups': param_groups}) + + param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] + self.__setstate__({"state": state, "param_groups": param_groups}) def to_gpu(self): for gindex, group in enumerate(self.param_groups): - for pindex, p in enumerate(group['params']): + for pindex, p in enumerate(group["params"]): if p in self.state: values = self.state[p] for k, v in values.items(): @@ -189,17 +208,23 @@ class Optimizer8bit(torch.optim.Optimizer): for module, attr, config in self.mng.module_weight_config_triple: pmodule = getattr(module, attr) assert pmodule is not None - assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter) + assert isinstance(pmodule, torch.Tensor) or isinstance( + pmodule, torch.Parameter + ) found = False for gindex, group in enumerate(self.param_groups): - if found: break - for pindex, p in enumerate(group['params']): - if found: break + if found: + break + for pindex, p in enumerate(group["params"]): + if found: + break if id(p) == id(pmodule): # found the matching parameter # init override self.mng.pid2config[id(p)] = config - self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)] + self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[ + id(p) + ] found = True @torch.no_grad() @@ -219,11 +244,11 @@ class Optimizer8bit(torch.optim.Optimizer): if not self.initialized: self.check_overrides() - self.to_gpu() # needed for fairseq pure fp16 training + self.to_gpu() # needed for fairseq pure fp16 training self.initialized = True for gindex, group in enumerate(self.param_groups): - for pindex, p in enumerate(group['params']): + for pindex, p in enumerate(group["params"]): if p.grad is None: continue state = self.state[p] @@ -236,58 +261,70 @@ class Optimizer8bit(torch.optim.Optimizer): def get_config(self, gindex, pindex, group): config = {} - config['betas'] = group['betas'] - config['eps'] = group['eps'] - config['weight_decay'] = group['weight_decay'] - config['lr'] = group['lr'] - config['optim_bits'] = self.args.optim_bits - config['min_8bit_size'] = self.args.min_8bit_size - config['percentile_clipping'] = self.args.percentile_clipping - config['block_wise'] = self.args.block_wise - config['max_unorm'] = self.args.max_unorm - config['skip_zeros'] = self.args.skip_zeros + config["betas"] = group["betas"] + config["eps"] = group["eps"] + config["weight_decay"] = group["weight_decay"] + config["lr"] = group["lr"] + config["optim_bits"] = self.args.optim_bits + config["min_8bit_size"] = self.args.min_8bit_size + config["percentile_clipping"] = self.args.percentile_clipping + config["block_wise"] = self.args.block_wise + config["max_unorm"] = self.args.max_unorm + config["skip_zeros"] = self.args.skip_zeros if (gindex, pindex) in self.mng.index2config: config.update(self.mng.index2config[(gindex, pindex)]) return config def init_state(self, group, p, gindex, pindex): - raise NotImplementedError(f'init_state method needs to be overidden') + raise NotImplementedError(f"init_state method needs to be overidden") def update_step(self, group, p, gindex, pindex): - raise NotImplementedError(f'The update_step method needs to be overidden') + raise NotImplementedError(f"The update_step method needs to be overidden") + class Optimizer2State(Optimizer8bit): - def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0.0, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0, - skip_zeros=False): + def __init__( + self, + optimizer_name, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0.0, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + max_unorm=0.0, + skip_zeros=False, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if isinstance(betas, str): # format: '(beta1, beta2)' - betas = betas.replace('(', '').replace(')', '').strip().split(',') + betas = betas.replace("(", "").replace(")", "").strip().split(",") betas = [float(b) for b in betas] for i in range(len(betas)): if not 0.0 <= betas[i] < 1.0: raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") if not 0.0 <= weight_decay: raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super(Optimizer2State, self).__init__(params, defaults, optim_bits) if args is None: args = {} - args['optim_bits'] = optim_bits - args['percentile_clipping'] = 100 - args['min_8bit_size'] = min_8bit_size - args['percentile_clipping'] = percentile_clipping - args['block_wise'] = block_wise - args['max_unorm'] = max_unorm - args['skip_zeros'] = skip_zeros + args["optim_bits"] = optim_bits + args["percentile_clipping"] = 100 + args["min_8bit_size"] = min_8bit_size + args["percentile_clipping"] = percentile_clipping + args["block_wise"] = block_wise + args["max_unorm"] = max_unorm + args["skip_zeros"] = skip_zeros self.args = MockArgs(args) else: @@ -299,50 +336,83 @@ class Optimizer2State(Optimizer8bit): def init_state(self, group, p, gindex, pindex): config = self.get_config(gindex, pindex, group) - if config['optim_bits'] == 32: + if config["optim_bits"] == 32: dtype = torch.float32 - elif config['optim_bits'] == 8: + elif config["optim_bits"] == 8: dtype = torch.uint8 - else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}') + else: + raise NotImplementedError( + f'Amount of optimizer bits not supported: {config["optim_bits"]}' + ) - if p.numel() < config['min_8bit_size']: dtype = torch.float32 + if p.numel() < config["min_8bit_size"]: + dtype = torch.float32 state = self.state[p] - state['step'] = 0 + state["step"] = 0 if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): - state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device) - state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device) + state["state1"] = torch.zeros_like( + p, + memory_format=torch.preserve_format, + dtype=torch.float32, + device=p.device, + ) + state["state2"] = torch.zeros_like( + p, + memory_format=torch.preserve_format, + dtype=torch.float32, + device=p.device, + ) elif dtype == torch.uint8: - if state['step'] == 0: - if 'dynamic' not in self.name2qmap: self.fill_qmap() - self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device) - self.name2qmap['udynamic'] = self.name2qmap['udynamic'].to(p.device) - - state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device) - state['qmap1'] = self.name2qmap['dynamic'] - - state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device) - state['qmap2'] = self.name2qmap['udynamic'] - - if config['block_wise']: + if state["step"] == 0: + if "dynamic" not in self.name2qmap: + self.fill_qmap() + self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device) + self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device) + + state["state1"] = torch.zeros_like( + p, + memory_format=torch.preserve_format, + dtype=torch.uint8, + device=p.device, + ) + state["qmap1"] = self.name2qmap["dynamic"] + + state["state2"] = torch.zeros_like( + p, + memory_format=torch.preserve_format, + dtype=torch.uint8, + device=p.device, + ) + state["qmap2"] = self.name2qmap["udynamic"] + + if config["block_wise"]: n = p.numel() - blocks = n//2048 + blocks = n // 2048 blocks += 1 if n % 2048 > 0 else 0 - state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) - state['absmax2'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) + state["absmax1"] = torch.zeros( + (blocks,), dtype=torch.float32, device=p.device + ) + state["absmax2"] = torch.zeros( + (blocks,), dtype=torch.float32, device=p.device + ) else: - state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device) - state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device) - state['max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device) - state['new_max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["new_max1"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) + state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["new_max2"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) - if config['percentile_clipping'] < 100: - state['gnorm_vec'] = torch.zeros((100,), device=p.device) + if config["percentile_clipping"] < 100: + state["gnorm_vec"] = torch.zeros((100,), device=p.device) - if config['max_unorm'] > 0.0: - state['unorm_vec'] = torch.zeros((1,), device=p.device) + if config["max_unorm"] > 0.0: + state["unorm_vec"] = torch.zeros((1,), device=p.device) @torch.no_grad() def update_step(self, group, p, gindex, pindex): @@ -351,41 +421,101 @@ class Optimizer2State(Optimizer8bit): config = self.get_config(gindex, pindex, group) - state['step'] += 1 - step = state['step'] + state["step"] += 1 + step = state["step"] - if config['percentile_clipping'] < 100: - current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping']) + if config["percentile_clipping"] < 100: + current_gnorm, clip_value, gnorm_scale = F.percentile_clipping( + grad, state["gnorm_vec"], step, config["percentile_clipping"] + ) else: gnorm_scale = 1.0 - if state['state1'].dtype == torch.float: - F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'], - state['state2'], config['betas'][1], config['weight_decay'], gnorm_scale, - state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'], skip_zeros=config['skip_zeros']) - - elif state['state1'].dtype == torch.uint8 and not config['block_wise']: - F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1], - config['eps'], step, config['lr'], - state['qmap1'], state['qmap2'], state['max1'], state['max2'], state['new_max1'], state['new_max2'], - config['weight_decay'], gnorm_scale=gnorm_scale, - unorm_vec=state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm']) + if state["state1"].dtype == torch.float: + F.optimizer_update_32bit( + self.optimizer_name, + grad, + p, + state["state1"], + config["betas"][0], + config["eps"], + step, + config["lr"], + state["state2"], + config["betas"][1], + config["weight_decay"], + gnorm_scale, + state["unorm_vec"] if config["max_unorm"] > 0.0 else None, + max_unorm=config["max_unorm"], + skip_zeros=config["skip_zeros"], + ) + + elif state["state1"].dtype == torch.uint8 and not config["block_wise"]: + F.optimizer_update_8bit( + self.optimizer_name, + grad, + p, + state["state1"], + state["state2"], + config["betas"][0], + config["betas"][1], + config["eps"], + step, + config["lr"], + state["qmap1"], + state["qmap2"], + state["max1"], + state["max2"], + state["new_max1"], + state["new_max2"], + config["weight_decay"], + gnorm_scale=gnorm_scale, + unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None, + max_unorm=config["max_unorm"], + ) # swap maxes - state['max1'], state['new_max1'] = state['new_max1'], state['max1'] - state['max2'], state['new_max2'] = state['new_max2'], state['max2'] - elif state['state1'].dtype == torch.uint8 and config['block_wise']: - F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1], - config['eps'], step, config['lr'], - state['qmap1'], state['qmap2'], state['absmax1'], state['absmax2'], - config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros']) + state["max1"], state["new_max1"] = state["new_max1"], state["max1"] + state["max2"], state["new_max2"] = state["new_max2"], state["max2"] + elif state["state1"].dtype == torch.uint8 and config["block_wise"]: + F.optimizer_update_8bit_blockwise( + self.optimizer_name, + grad, + p, + state["state1"], + state["state2"], + config["betas"][0], + config["betas"][1], + config["eps"], + step, + config["lr"], + state["qmap1"], + state["qmap2"], + state["absmax1"], + state["absmax2"], + config["weight_decay"], + gnorm_scale=gnorm_scale, + skip_zeros=config["skip_zeros"], + ) class Optimizer1State(Optimizer8bit): - def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.0), eps=1e-8, - weight_decay=0.0, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0, - skip_zeros=False): + def __init__( + self, + optimizer_name, + params, + lr=1e-3, + betas=(0.9, 0.0), + eps=1e-8, + weight_decay=0.0, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + max_unorm=0.0, + skip_zeros=False, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -395,19 +525,18 @@ class Optimizer1State(Optimizer8bit): raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") if not 0.0 <= weight_decay: raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super(Optimizer1State, self).__init__(params, defaults, optim_bits) if args is None: args = {} - args['optim_bits'] = optim_bits - args['percentile_clipping'] = 100 - args['min_8bit_size'] = min_8bit_size - args['percentile_clipping'] = percentile_clipping - args['block_wise'] = block_wise - args['max_unorm'] = max_unorm - args['skip_zeros'] = skip_zeros + args["optim_bits"] = optim_bits + args["percentile_clipping"] = 100 + args["min_8bit_size"] = min_8bit_size + args["percentile_clipping"] = percentile_clipping + args["block_wise"] = block_wise + args["max_unorm"] = max_unorm + args["skip_zeros"] = skip_zeros self.args = MockArgs(args) else: @@ -419,43 +548,61 @@ class Optimizer1State(Optimizer8bit): def init_state(self, group, p, gindex, pindex): config = self.get_config(gindex, pindex, group) - if config['optim_bits'] == 32: + if config["optim_bits"] == 32: dtype = torch.float32 - elif config['optim_bits'] == 8: + elif config["optim_bits"] == 8: dtype = torch.uint8 - else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}') + else: + raise NotImplementedError( + f'Amount of optimizer bits not supported: {config["optim_bits"]}' + ) - if p.numel() < config['min_8bit_size']: dtype = torch.float32 + if p.numel() < config["min_8bit_size"]: + dtype = torch.float32 state = self.state[p] - state['step'] = 0 + state["step"] = 0 if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): - state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device) + state["state1"] = torch.zeros_like( + p, + memory_format=torch.preserve_format, + dtype=torch.float32, + device=p.device, + ) elif dtype == torch.uint8: - if state['step'] == 0: - if 'dynamic' not in self.name2qmap: self.fill_qmap() - self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device) - - state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device) - state['qmap1'] = self.name2qmap['dynamic'] - - if config['block_wise']: + if state["step"] == 0: + if "dynamic" not in self.name2qmap: + self.fill_qmap() + self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device) + + state["state1"] = torch.zeros_like( + p, + memory_format=torch.preserve_format, + dtype=torch.uint8, + device=p.device, + ) + state["qmap1"] = self.name2qmap["dynamic"] + + if config["block_wise"]: n = p.numel() - blocks = n//2048 + blocks = n // 2048 blocks += 1 if n % 2048 > 0 else 0 - state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) + state["absmax1"] = torch.zeros( + (blocks,), dtype=torch.float32, device=p.device + ) else: - state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device) - state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device) - - if config['percentile_clipping'] < 100: - state['gnorm_vec'] = torch.zeros((100,), device=p.device) + state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["new_max1"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) - if config['max_unorm'] > 0.0: - state['unorm_vec'] = torch.zeros((1,), device=p.device) + if config["percentile_clipping"] < 100: + state["gnorm_vec"] = torch.zeros((100,), device=p.device) + if config["max_unorm"] > 0.0: + state["unorm_vec"] = torch.zeros((1,), device=p.device) @torch.no_grad() def update_step(self, group, p, gindex, pindex): @@ -464,29 +611,77 @@ class Optimizer1State(Optimizer8bit): config = self.get_config(gindex, pindex, group) - state['step'] += 1 - step = state['step'] + state["step"] += 1 + step = state["step"] - if config['percentile_clipping'] < 100: - current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping']) + if config["percentile_clipping"] < 100: + current_gnorm, clip_value, gnorm_scale = F.percentile_clipping( + grad, state["gnorm_vec"], step, config["percentile_clipping"] + ) else: gnorm_scale = 1.0 - if state['state1'].dtype == torch.float: - F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'], - None, 0.0, config['weight_decay'], gnorm_scale, - state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'], - skip_zeros=config['skip_zeros']) - - elif state['state1'].dtype == torch.uint8 and not config['block_wise']: - F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1], - config['eps'], step, config['lr'], state['qmap1'], None, state['max1'], None, state['new_max1'], None, - config['weight_decay'], gnorm_scale, - state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm']) - - state['max1'], state['new_max1'] = state['new_max1'], state['max1'] - elif state['state1'].dtype == torch.uint8 and config['block_wise']: - F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1], - config['eps'], step, config['lr'], - state['qmap1'], None, state['absmax1'], None, - config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros']) + if state["state1"].dtype == torch.float: + F.optimizer_update_32bit( + self.optimizer_name, + grad, + p, + state["state1"], + config["betas"][0], + config["eps"], + step, + config["lr"], + None, + 0.0, + config["weight_decay"], + gnorm_scale, + state["unorm_vec"] if config["max_unorm"] > 0.0 else None, + max_unorm=config["max_unorm"], + skip_zeros=config["skip_zeros"], + ) + + elif state["state1"].dtype == torch.uint8 and not config["block_wise"]: + F.optimizer_update_8bit( + self.optimizer_name, + grad, + p, + state["state1"], + None, + config["betas"][0], + config["betas"][1], + config["eps"], + step, + config["lr"], + state["qmap1"], + None, + state["max1"], + None, + state["new_max1"], + None, + config["weight_decay"], + gnorm_scale, + state["unorm_vec"] if config["max_unorm"] > 0.0 else None, + max_unorm=config["max_unorm"], + ) + + state["max1"], state["new_max1"] = state["new_max1"], state["max1"] + elif state["state1"].dtype == torch.uint8 and config["block_wise"]: + F.optimizer_update_8bit_blockwise( + self.optimizer_name, + grad, + p, + state["state1"], + None, + config["betas"][0], + config["betas"][1], + config["eps"], + step, + config["lr"], + state["qmap1"], + None, + state["absmax1"], + None, + config["weight_decay"], + gnorm_scale=gnorm_scale, + skip_zeros=config["skip_zeros"], + ) diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py index 0f1ffaa..679f783 100644 --- a/bitsandbytes/optim/rmsprop.py +++ b/bitsandbytes/optim/rmsprop.py @@ -1,36 +1,109 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from bitsandbytes.optim.optimizer import Optimizer1State + class RMSprop(Optimizer1State): - def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-2, + alpha=0.99, + eps=1e-8, + weight_decay=0, + momentum=0, + centered=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if alpha == 0: - raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!') + raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!") if centered: - raise NotImplementedError(f'Centered RMSprop is not supported!') - super(RMSprop, self).__init__('rmsprop', params, lr, (alpha, momentum), eps, - weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise) + raise NotImplementedError(f"Centered RMSprop is not supported!") + super(RMSprop, self).__init__( + "rmsprop", + params, + lr, + (alpha, momentum), + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class RMSprop8bit(Optimizer1State): - def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-2, + alpha=0.99, + eps=1e-8, + weight_decay=0, + momentum=0, + centered=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if alpha == 0: - raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!') + raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!") if centered: - raise NotImplementedError(f'Centered RMSprop is not supported!') - super(RMSprop8bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps, - weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise) + raise NotImplementedError(f"Centered RMSprop is not supported!") + super(RMSprop8bit, self).__init__( + "rmsprop", + params, + lr, + (alpha, momentum), + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class RMSprop32bit(Optimizer1State): - def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-2, + alpha=0.99, + eps=1e-8, + weight_decay=0, + momentum=0, + centered=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if alpha == 0: - raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!') + raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!") if centered: - raise NotImplementedError(f'Centered RMSprop is not supported!') - super(RMSprop32bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps, - weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise) + raise NotImplementedError(f"Centered RMSprop is not supported!") + super(RMSprop32bit, self).__init__( + "rmsprop", + params, + lr, + (alpha, momentum), + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) diff --git a/bitsandbytes/optim/sgd.py b/bitsandbytes/optim/sgd.py index 0529879..f7b8934 100644 --- a/bitsandbytes/optim/sgd.py +++ b/bitsandbytes/optim/sgd.py @@ -1,32 +1,99 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from bitsandbytes.optim.optimizer import Optimizer1State + class SGD(Optimizer1State): - def __init__(self, params, lr, momentum=0, dampening=0, - weight_decay=0, nesterov=False, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if momentum == 0: - raise NotImplementedError(f'SGD without momentum is not supported!') - super(SGD, self).__init__('momentum', params, lr, (momentum, dampening), 0.0, - weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise) + raise NotImplementedError(f"SGD without momentum is not supported!") + super(SGD, self).__init__( + "momentum", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class SGD8bit(Optimizer1State): - def __init__(self, params, lr, momentum=0, dampening=0, - weight_decay=0, nesterov=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if momentum == 0: - raise NotImplementedError(f'SGD without momentum is not supported!') - super(SGD8bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0, - weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise) + raise NotImplementedError(f"SGD without momentum is not supported!") + super(SGD8bit, self).__init__( + "momentum", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class SGD32bit(Optimizer1State): - def __init__(self, params, lr, momentum=0, dampening=0, - weight_decay=0, nesterov=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if momentum == 0: - raise NotImplementedError(f'SGD without momentum is not supported!') - super(SGD32bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0, - weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise) + raise NotImplementedError(f"SGD without momentum is not supported!") + super(SGD32bit, self).__init__( + "momentum", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) -- cgit v1.2.3