From bb34fd50a1fec74e62beb6e23d51f0142c7d0ab6 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 20 Oct 2021 18:37:44 -0700 Subject: Initial plumbing for skip_zeros. --- bitsandbytes/functional.py | 25 +++++++++++++++++++------ bitsandbytes/optim/optimizer.py | 14 ++++++++++---- 2 files changed, 29 insertions(+), 10 deletions(-) (limited to 'bitsandbytes') diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 65c697d..48ab40c 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -337,7 +337,7 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten beta1: float, eps: float, step: int, lr: float, state2: Tensor=None, beta2: float=0.0, weight_decay: float=0.0, gnorm_scale: float=1.0, - unorm_vec: Tensor=None, max_unorm: float=0.0) -> None: + unorm_vec: Tensor=None, max_unorm: float=0.0, skip_zeros=False) -> None: ''' Performs an inplace optimizer update with one or two optimizer states. @@ -369,6 +369,12 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten Optimizer beta2. gnorm_scale : float The factor to rescale the gradient to the max clip value. + unorm_vec : torch.Tensor + The tensor for the update norm. + max_unorm : float + The maximum update norm relative to the weight norm. + skip_zeros : bool + Whether to skip zero-valued gradients or not (default: False). ''' param_norm = 0.0 @@ -381,11 +387,11 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten if g.dtype == torch.float32 and state1.dtype == torch.float32: str2optimizer32bit[optimizer_name][0](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm), ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay), - ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_int32(g.numel())) + ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel())) elif g.dtype == torch.float16 and state1.dtype == torch.float32: str2optimizer32bit[optimizer_name][1](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm), ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay), - ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_int32(g.numel())) + ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel())) else: raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}') @@ -439,6 +445,10 @@ def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Ten Max value for the next Adam update of the second state. gnorm_scale : float The factor to rescale the gradient to the max clip value. + unorm_vec : torch.Tensor + The tensor for the update norm. + max_unorm : float + The maximum update norm relative to the weight norm. ''' param_norm = 0.0 @@ -468,19 +478,22 @@ def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Ten def optimizer_update_8bit_blockwise(optimizer_name: str, g: Tensor, p: Tensor, state1: Tensor, state2: Tensor, beta1: float, beta2: float, eps: float, step: int, lr: float, qmap1: Tensor, qmap2: Tensor, - absmax1: Tensor, absmax2: Tensor, weight_decay: float=0.0, gnorm_scale: float=1.0) -> None: + absmax1: Tensor, absmax2: Tensor, weight_decay: float=0.0, gnorm_scale: float=1.0, + skip_zeros=False) -> None: if g.dtype == torch.float32 and state1.dtype == torch.uint8: str2optimizer8bit_blockwise[optimizer_name][0](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2), - get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), ct.c_int32(g.numel())) + get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), ct.c_int32(g.numel())) elif g.dtype == torch.float16 and state1.dtype == torch.uint8: str2optimizer8bit_blockwise[optimizer_name][1](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2), - get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), ct.c_int32(g.numel())) + get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), ct.c_int32(g.numel())) else: raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}') diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index 6743c15..25512b1 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -220,6 +220,7 @@ class Optimizer8bit(torch.optim.Optimizer): config['percentile_clipping'] = self.args.percentile_clipping config['block_wise'] = self.args.block_wise config['max_unorm'] = self.args.max_unorm + config['skip_zeros'] = self.args.skip_zeros if (gindex, pindex) in self.mng.index2config: config.update(self.mng.index2config[(gindex, pindex)]) @@ -234,7 +235,8 @@ class Optimizer8bit(torch.optim.Optimizer): class Optimizer2State(Optimizer8bit): def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0): + min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0, + skip_zeros=False): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -259,6 +261,7 @@ class Optimizer2State(Optimizer8bit): args['percentile_clipping'] = percentile_clipping args['block_wise'] = block_wise args['max_unorm'] = max_unorm + args['skip_zeros'] = skip_zeros self.args = MockArgs(args) else: @@ -355,7 +358,8 @@ class Optimizer2State(Optimizer8bit): class Optimizer1State(Optimizer8bit): def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.0), eps=1e-8, weight_decay=0.0, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0): + min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0, + skip_zeros=False): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -377,6 +381,7 @@ class Optimizer1State(Optimizer8bit): args['percentile_clipping'] = percentile_clipping args['block_wise'] = block_wise args['max_unorm'] = max_unorm + args['skip_zeros'] = skip_zeros self.args = MockArgs(args) else: @@ -444,7 +449,8 @@ class Optimizer1State(Optimizer8bit): if state['state1'].dtype == torch.float: F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'], None, 0.0, config['weight_decay'], gnorm_scale, - state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm']) + state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'], + skip_zeros=False) elif state['state1'].dtype == torch.uint8 and not config['block_wise']: F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1], @@ -457,4 +463,4 @@ class Optimizer1State(Optimizer8bit): F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1], config['eps'], step, config['lr'], state['qmap1'], None, state['absmax1'], None, - config['weight_decay'], gnorm_scale=gnorm_scale) + config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=False) -- cgit v1.2.3