Diffstat (limited to 'bitsandbytes/optim')
 bitsandbytes/optim/optimizer.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index 6743c15..25512b1 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -220,6 +220,7 @@ class Optimizer8bit(torch.optim.Optimizer):
         config['percentile_clipping'] = self.args.percentile_clipping
         config['block_wise'] = self.args.block_wise
         config['max_unorm'] = self.args.max_unorm
+        config['skip_zeros'] = self.args.skip_zeros
 
         if (gindex, pindex) in self.mng.index2config:
             config.update(self.mng.index2config[(gindex, pindex)])
@@ -234,7 +235,8 @@ class Optimizer8bit(torch.optim.Optimizer):
 class Optimizer2State(Optimizer8bit):
     def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
             weight_decay=0.0, optim_bits=32, args=None,
-            min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0):
+            min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0,
+            skip_zeros=False):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -259,6 +261,7 @@ class Optimizer2State(Optimizer8bit):
             args['percentile_clipping'] = percentile_clipping
             args['block_wise'] = block_wise
             args['max_unorm'] = max_unorm
+            args['skip_zeros'] = skip_zeros
 
             self.args = MockArgs(args)
         else:
@@ -355,7 +358,8 @@ class Optimizer2State(Optimizer8bit):
 class Optimizer1State(Optimizer8bit):
     def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.0), eps=1e-8,
             weight_decay=0.0, optim_bits=32, args=None,
-            min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0):
+            min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0,
+            skip_zeros=False):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -377,6 +381,7 @@ class Optimizer1State(Optimizer8bit):
             args['percentile_clipping'] = percentile_clipping
             args['block_wise'] = block_wise
             args['max_unorm'] = max_unorm
+            args['skip_zeros'] = skip_zeros
 
             self.args = MockArgs(args)
         else:
@@ -444,7 +449,8 @@ class Optimizer1State(Optimizer8bit):
         if state['state1'].dtype == torch.float:
             F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
                     None, 0.0, config['weight_decay'], gnorm_scale,
-                    state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
+                    state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'],
+                    skip_zeros=False)
 
         elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
             F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
@@ -457,4 +463,4 @@ class Optimizer1State(Optimizer8bit):
             F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
                     config['eps'], step, config['lr'],
                     state['qmap1'], None, state['absmax1'], None,
-                    config['weight_decay'], gnorm_scale=gnorm_scale)
+                    config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=False)
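
For orientation, a minimal usage sketch of the new keyword. Optimizer2State is instantiated
directly here with the 'adam' optimizer name; in practice it is normally reached through a
subclass, and the model below is a hypothetical stand-in.

    import torch
    from bitsandbytes.optim.optimizer import Optimizer2State

    model = torch.nn.Linear(64, 64)  # stand-in model (the real kernels expect CUDA tensors)

    # skip_zeros is new in this diff; the default (False) preserves the
    # previous update behavior.
    opt = Optimizer2State('adam', model.parameters(), lr=1e-3,
                          betas=(0.9, 0.999), eps=1e-8, optim_bits=32,
                          skip_zeros=False)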
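As for what the flag is expected to mean, here is a plain-torch sketch of the assumed
semantics: elements whose gradient is exactly zero keep their parameter (and, in the real
kernels, their optimizer state) untouched, which matters for sparse gradients such as
embedding rows that received no signal in a step. This is an assumption about intent, not
the CUDA implementation, and apply_update is a hypothetical helper. Note that the 32-bit
and blockwise 8-bit calls in Optimizer1State above still pass a hardcoded skip_zeros=False,
so the config value is threaded through but not yet consumed on those paths.

    import torch

    def apply_update(p, grad, update, skip_zeros=False):
        # Hypothetical helper illustrating the assumed skip_zeros masking:
        # where grad == 0, leave the parameter unchanged.
        if skip_zeros:
            update = torch.where(grad == 0, torch.zeros_like(update), update)
        p.data.add_(update)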