From a6eae2e7f2bf03f268fcb6b055201ff6827684c4 Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Wed, 20 Oct 2021 19:15:47 -0700
Subject: Added skip_zeros; tests are passing.

---
 bitsandbytes/functional.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 48ab40c..9fe1345 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -486,13 +486,13 @@ def optimizer_update_8bit_blockwise(optimizer_name: str, g: Tensor, p: Tensor, s
         str2optimizer8bit_blockwise[optimizer_name][0](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
                     ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
                     ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2),
-                    get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale),
+                    get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros),
                     ct.c_int32(g.numel()))
     elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
         str2optimizer8bit_blockwise[optimizer_name][1](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2),
                     ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps),
                     ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2),
-                    get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale),
+                    get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros),
                     ct.c_int32(g.numel()))
     else:
         raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}')
--
cgit v1.2.3
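
Note: below is a minimal, self-contained sketch of what the new skip_zeros flag is for, using a toy dense SGD update in place of the 8-bit blockwise CUDA kernel this diff actually calls. The toy optimizer, its weight-decay handling, and the exact skip semantics are illustrative assumptions inferred from the parameter names in the diff, not the kernel's implementation.

    import torch

    def sgd_step(p, g, lr=0.01, weight_decay=0.0, skip_zeros=False):
        # Toy dense SGD update standing in for the 8-bit blockwise kernel
        # (assumed semantics, for illustration only). With skip_zeros=True,
        # elements whose gradient is exactly zero are left untouched, so a
        # zero can mean "no gradient this step" (e.g. untouched sparse
        # embedding rows) rather than "gradient of zero" -- otherwise
        # weight decay would still shrink those parameters.
        update = -lr * (g + weight_decay * p)
        if skip_zeros:
            update = torch.where(g == 0, torch.zeros_like(update), update)
        p.add_(update)
        return p

    p = torch.ones(4)
    g = torch.tensor([0.0, 1.0, 0.0, -2.0])
    sgd_step(p, g, lr=0.1, weight_decay=0.1, skip_zeros=True)
    # -> tensor([1.0000, 0.8900, 1.0000, 1.1900])
    # Entries with zero gradient are unchanged; without skip_zeros they
    # would have decayed to 0.99.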