From bb34fd50a1fec74e62beb6e23d51f0142c7d0ab6 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 20 Oct 2021 18:37:44 -0700 Subject: Initial plumbing for skip_zeros. --- bitsandbytes/optim/optimizer.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) (limited to 'bitsandbytes/optim') diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index 6743c15..25512b1 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -220,6 +220,7 @@ class Optimizer8bit(torch.optim.Optimizer): config['percentile_clipping'] = self.args.percentile_clipping config['block_wise'] = self.args.block_wise config['max_unorm'] = self.args.max_unorm + config['skip_zeros'] = self.args.skip_zeros if (gindex, pindex) in self.mng.index2config: config.update(self.mng.index2config[(gindex, pindex)]) @@ -234,7 +235,8 @@ class Optimizer8bit(torch.optim.Optimizer): class Optimizer2State(Optimizer8bit): def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0): + min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0, + skip_zeros=False): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -259,6 +261,7 @@ class Optimizer2State(Optimizer8bit): args['percentile_clipping'] = percentile_clipping args['block_wise'] = block_wise args['max_unorm'] = max_unorm + args['skip_zeros'] = skip_zeros self.args = MockArgs(args) else: @@ -355,7 +358,8 @@ class Optimizer2State(Optimizer8bit): class Optimizer1State(Optimizer8bit): def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.0), eps=1e-8, weight_decay=0.0, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0): + min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0, + skip_zeros=False): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -377,6 +381,7 @@ class Optimizer1State(Optimizer8bit): args['percentile_clipping'] = percentile_clipping args['block_wise'] = block_wise args['max_unorm'] = max_unorm + args['skip_zeros'] = skip_zeros self.args = MockArgs(args) else: @@ -444,7 +449,8 @@ class Optimizer1State(Optimizer8bit): if state['state1'].dtype == torch.float: F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'], None, 0.0, config['weight_decay'], gnorm_scale, - state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm']) + state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'], + skip_zeros=False) elif state['state1'].dtype == torch.uint8 and not config['block_wise']: F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1], @@ -457,4 +463,4 @@ class Optimizer1State(Optimizer8bit): F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1], config['eps'], step, config['lr'], state['qmap1'], None, state['absmax1'], None, - config['weight_decay'], gnorm_scale=gnorm_scale) + config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=False) -- cgit v1.2.3