diff options
Diffstat (limited to 'bitsandbytes/optim/adam.py')
-rw-r--r-- | bitsandbytes/optim/adam.py | 179 |
1 files changed, 128 insertions, 51 deletions
diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py index ed1b9f0..5cfaa28 100644 --- a/bitsandbytes/optim/adam.py +++ b/bitsandbytes/optim/adam.py @@ -1,6 +1,6 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import math @@ -8,29 +8,97 @@ import os import torch import torch.distributed as dist -from bitsandbytes.optim.optimizer import Optimizer2State + import bitsandbytes.functional as F +from bitsandbytes.optim.optimizer import Optimizer2State + class Adam(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super(Adam, self).__init__('adam', params, lr, betas, eps, - weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super(Adam, self).__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class Adam8bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super(Adam8bit, self).__init__('adam', params, lr, betas, eps, - weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super(Adam8bit, self).__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class Adam32bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super(Adam32bit, self).__init__('adam', params, lr, betas, eps, - weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super(Adam32bit, self).__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) class AnalysisAdam(torch.optim.Optimizer): @@ -68,8 +136,8 @@ class AnalysisAdam(torch.optim.Optimizer): eps=1e-8, weight_decay=0, amsgrad=False, - bnb_analysis='dynamic-blockwise', - savedir=None + bnb_analysis="dynamic-blockwise", + savedir=None, ): defaults = dict( lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad @@ -124,9 +192,13 @@ class AnalysisAdam(torch.optim.Optimizer): state["exp_avg"] = torch.zeros_like(p_data_fp32) # Exponential moving average of squared gradient values state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) - state['abserrors'] = torch.zeros((256, 256), device=p_data_fp32.device) - state['relerrors'] = torch.zeros((256, 256), device=p_data_fp32.device) - state['counts'] = torch.zeros((256, 256), device=p_data_fp32.device) + state["abserrors"] = torch.zeros( + (256, 256), device=p_data_fp32.device + ) + state["relerrors"] = torch.zeros( + (256, 256), device=p_data_fp32.device + ) + state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device) if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. values state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32) @@ -143,9 +215,9 @@ class AnalysisAdam(torch.optim.Optimizer): bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 - e = state['abserrors'] - rele = state['relerrors'] - counts = state['counts'] + e = state["abserrors"] + rele = state["relerrors"] + counts = state["counts"] if group["weight_decay"] != 0: p_data_fp32.add_( @@ -156,77 +228,84 @@ class AnalysisAdam(torch.optim.Optimizer): if amsgrad: max_exp_avg_sq = state["max_exp_avg_sq"] - # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) denom = exp_avg_sq.sqrt().add_(group["eps"]) - update_fp32 = exp_avg/denom + update_fp32 = exp_avg / denom - if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000*1000: + if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000: # embedding layer or too small - p_data_fp32 += -step_size*update_fp32 + p_data_fp32 += -step_size * update_fp32 else: - if self.analysis == 'dynamic-blockwise': + if self.analysis == "dynamic-blockwise": code1 = F.create_dynamic_map(signed=True).to(p.device) code2 = F.create_dynamic_map(signed=False).to(p.device) C1, S1 = F.quantize_blockwise(exp_avg, code=code1) state1 = F.dequantize_blockwise(C1, S1) C2, S2 = F.quantize_blockwise(exp_avg_sq, code=code2) state2 = F.dequantize_blockwise(C2, S2) - elif self.analysis == 'dynamic': + elif self.analysis == "dynamic": code1 = F.create_dynamic_map(signed=True).to(p.device) code2 = F.create_dynamic_map(signed=False).to(p.device) C1, S1 = F.quantize(exp_avg, code=code1) state1 = F.dequantize(C1, S1) C2, S2 = F.quantize(exp_avg_sq, code=code2) state2 = F.dequantize(C2, S2) - elif self.analysis == 'linear': + elif self.analysis == "linear": code1 = F.create_linear_map(signed=True).to(p.device) code2 = F.create_linear_map(signed=False).to(p.device) C1, S1 = F.quantize(exp_avg, code=code1) state1 = F.dequantize(C1, S1) C2, S2 = F.quantize(exp_avg_sq, code=code2) state2 = F.dequantize(C2, S2) - elif self.analysis == 'quantile': + elif self.analysis == "quantile": code1 = F.estimate_quantiles(exp_avg) code2 = F.estimate_quantiles(exp_avg_sq) C1 = F.quantize_no_absmax(exp_avg, code=code1) state1 = F.dequantize_no_absmax(C1, code1) C2 = F.quantize_no_absmax(exp_avg_sq, code=code2) state2 = F.dequantize_no_absmax(C2, code2) - elif self.analysis == 'my-quantization-routine': + elif self.analysis == "my-quantization-routine": pass # 1. get code # 2. quantize # 3. dequantize # Error will be calculated automatically! else: - raise ValueError(f'Invalid analysis value: {self.analysis}!') + raise ValueError(f"Invalid analysis value: {self.analysis}!") denom = state2.sqrt().add_(group["eps"]) - update_8bit = state1/denom + update_8bit = state1 / denom - abserr = torch.abs(update_8bit-update_fp32) - relerr = abserr/torch.abs(update_fp32+1e-6) + abserr = torch.abs(update_8bit - update_fp32) + relerr = abserr / torch.abs(update_fp32 + 1e-6) C1, C2 = C1.int(), C2.int() F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr) F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr) - F.histogram_scatter_add_2d(counts, C1.int(), C2.int(), torch.ones_like(abserr)) - - p_data_fp32 += -step_size*update_fp32 + F.histogram_scatter_add_2d( + counts, C1.int(), C2.int(), torch.ones_like(abserr) + ) + p_data_fp32 += -step_size * update_fp32 if not dist.is_initialized() or dist.get_rank() == 0: - if self.savedir != '' and state['step'] % 100 == 0: - if not os.path.exists(self.savedir): os.makedirs(self.savedir) - shapestr = '_'.join([str(dim) for dim in p_data_fp32.shape]) - pathe = os.path.join(self.savedir, f'{p_id}_{shapestr}_abserr.pkl') - pathrele = os.path.join(self.savedir, f'{p_id}_{shapestr}_relerr.pkl') - pathcounts = os.path.join(self.savedir, f'{p_id}_{shapestr}_counts.pkl') + if self.savedir != "" and state["step"] % 100 == 0: + if not os.path.exists(self.savedir): + os.makedirs(self.savedir) + shapestr = "_".join([str(dim) for dim in p_data_fp32.shape]) + pathe = os.path.join( + self.savedir, f"{p_id}_{shapestr}_abserr.pkl" + ) + pathrele = os.path.join( + self.savedir, f"{p_id}_{shapestr}_relerr.pkl" + ) + pathcounts = os.path.join( + self.savedir, f"{p_id}_{shapestr}_counts.pkl" + ) torch.save(e, pathe) torch.save(rele, pathrele) torch.save(counts, pathcounts) @@ -234,6 +313,4 @@ class AnalysisAdam(torch.optim.Optimizer): if p.data.dtype in {torch.float16, torch.bfloat16}: p.data.copy_(p_data_fp32) - - return loss |