summaryrefslogtreecommitdiff
path: root/bitsandbytes
diff options
context:
space:
mode:
authorTim Dettmers <dettmers@cs.washington.edu>2021-10-20 19:15:47 -0700
committerTim Dettmers <dettmers@cs.washington.edu>2021-10-20 19:15:47 -0700
commita6eae2e7f2bf03f268fcb6b055201ff6827684c4 (patch)
treed2f72792251c9feaef1cf9dcddc3c79e6312a93a /bitsandbytes
parentbb34fd50a1fec74e62beb6e23d51f0142c7d0ab6 (diff)
Added skip_zeros; tests are passing.
Diffstat (limited to 'bitsandbytes')
-rw-r--r--bitsandbytes/functional.py4
-rw-r--r--bitsandbytes/optim/optimizer.py8
2 files changed, 6 insertions, 6 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 48ab40c..9fe1345 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -486,13 +486,13 @@ def optimizer_update_8bit_blockwise(optimizer_name: str, g: Tensor, p: Tensor, s
str2optimizer8bit_blockwise[optimizer_name][0](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2),
- get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale),
+ get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale),
ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
str2optimizer8bit_blockwise[optimizer_name][1](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2),
- get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale),
+ get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale),
ct.c_bool(skip_zeros), ct.c_int32(g.numel()))
else:
raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}')
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index 25512b1..4b70b5c 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -336,7 +336,7 @@ class Optimizer2State(Optimizer8bit):
if state['state1'].dtype == torch.float:
F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
state['state2'], config['betas'][1], config['weight_decay'], gnorm_scale,
- state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
+ state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'], skip_zeros=config['skip_zeros'])
elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
@@ -352,7 +352,7 @@ class Optimizer2State(Optimizer8bit):
F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
config['eps'], step, config['lr'],
state['qmap1'], state['qmap2'], state['absmax1'], state['absmax2'],
- config['weight_decay'], gnorm_scale=gnorm_scale)
+ config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros'])
class Optimizer1State(Optimizer8bit):
@@ -450,7 +450,7 @@ class Optimizer1State(Optimizer8bit):
F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
None, 0.0, config['weight_decay'], gnorm_scale,
state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'],
- skip_zeros=False)
+ skip_zeros=config['skip_zeros'])
elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
@@ -463,4 +463,4 @@ class Optimizer1State(Optimizer8bit):
F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
config['eps'], step, config['lr'],
state['qmap1'], None, state['absmax1'], None,
- config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=False)
+ config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros'])