From bb34fd50a1fec74e62beb6e23d51f0142c7d0ab6 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Wed, 20 Oct 2021 18:37:44 -0700 Subject: Initial plumbing for skip_zeros. --- bitsandbytes/functional.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) (limited to 'bitsandbytes/functional.py') diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 65c697d..48ab40c 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -337,7 +337,7 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten beta1: float, eps: float, step: int, lr: float, state2: Tensor=None, beta2: float=0.0, weight_decay: float=0.0, gnorm_scale: float=1.0, - unorm_vec: Tensor=None, max_unorm: float=0.0) -> None: + unorm_vec: Tensor=None, max_unorm: float=0.0, skip_zeros=False) -> None: ''' Performs an inplace optimizer update with one or two optimizer states. @@ -369,6 +369,12 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten Optimizer beta2. gnorm_scale : float The factor to rescale the gradient to the max clip value. + unorm_vec : torch.Tensor + The tensor for the update norm. + max_unorm : float + The maximum update norm relative to the weight norm. + skip_zeros : bool + Whether to skip zero-valued gradients or not (default: False). ''' param_norm = 0.0 @@ -381,11 +387,11 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten if g.dtype == torch.float32 and state1.dtype == torch.float32: str2optimizer32bit[optimizer_name][0](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm), ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay), - ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_int32(g.numel())) + ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel())) elif g.dtype == torch.float16 and state1.dtype == torch.float32: str2optimizer32bit[optimizer_name][1](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm), ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay), - ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_int32(g.numel())) + ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel())) else: raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}') @@ -439,6 +445,10 @@ def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Ten Max value for the next Adam update of the second state. gnorm_scale : float The factor to rescale the gradient to the max clip value. + unorm_vec : torch.Tensor + The tensor for the update norm. + max_unorm : float + The maximum update norm relative to the weight norm. ''' param_norm = 0.0 @@ -468,19 +478,22 @@ def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Ten def optimizer_update_8bit_blockwise(optimizer_name: str, g: Tensor, p: Tensor, state1: Tensor, state2: Tensor, beta1: float, beta2: float, eps: float, step: int, lr: float, qmap1: Tensor, qmap2: Tensor, - absmax1: Tensor, absmax2: Tensor, weight_decay: float=0.0, gnorm_scale: float=1.0) -> None: + absmax1: Tensor, absmax2: Tensor, weight_decay: float=0.0, gnorm_scale: float=1.0, + skip_zeros=False) -> None: if g.dtype == torch.float32 and state1.dtype == torch.uint8: str2optimizer8bit_blockwise[optimizer_name][0](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2), - get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), ct.c_int32(g.numel())) + get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), ct.c_int32(g.numel())) elif g.dtype == torch.float16 and state1.dtype == torch.uint8: str2optimizer8bit_blockwise[optimizer_name][1](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2), - get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), ct.c_int32(g.numel())) + get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), ct.c_int32(g.numel())) else: raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}') -- cgit v1.2.3