path: root/bitsandbytes/optim/lars.py
author    Tim Dettmers <tim.dettmers@gmail.com>    2021-10-05 19:16:20 -0700
committer Tim Dettmers <tim.dettmers@gmail.com>    2021-10-05 19:16:20 -0700
commit    7439924891496025edf60c9da6a782f362a50c70 (patch)
tree      90476984d2c267f89232577a2ea40eb172387475    /bitsandbytes/optim/lars.py
Initial commit
Diffstat (limited to 'bitsandbytes/optim/lars.py')
-rw-r--r--    bitsandbytes/optim/lars.py    115
1 file changed, 115 insertions, 0 deletions
diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py
new file mode 100644
index 0000000..40dede7
--- /dev/null
+++ b/bitsandbytes/optim/lars.py
@@ -0,0 +1,115 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+
+from torch.optim import Optimizer
+from bitsandbytes.optim.optimizer import Optimizer1State
+
+class LARS(Optimizer1State):
+    def __init__(self, params, lr, momentum=0, dampening=0,
+                 weight_decay=0, nesterov=False, optim_bits=32, args=None,
+                 min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
+        if momentum == 0:
+            raise NotImplementedError('LARS without momentum is not supported!')
+        super(LARS, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
+                weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
+
+class LARS8bit(Optimizer1State):
+    def __init__(self, params, lr, momentum=0, dampening=0,
+                 weight_decay=0, nesterov=False, args=None,
+                 min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
+        if momentum == 0:
+            raise NotImplementedError('LARS without momentum is not supported!')
+        super(LARS8bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
+                weight_decay, 8, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
+
+class LARS32bit(Optimizer1State):
+    def __init__(self, params, lr, momentum=0, dampening=0,
+                 weight_decay=0, nesterov=False, args=None,
+                 min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
+        if momentum == 0:
+            raise NotImplementedError('LARS without momentum is not supported!')
+        super(LARS32bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
+                weight_decay, 32, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
+
+
+class PytorchLARS(Optimizer):
+    def __init__(self, params, lr=0.01, momentum=0, dampening=0,
+                 weight_decay=0, nesterov=False, max_unorm=0.02):
+        if lr < 0.0:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if momentum < 0.0:
+            raise ValueError("Invalid momentum value: {}".format(momentum))
+        if weight_decay < 0.0:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+
+        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
+                        weight_decay=weight_decay, nesterov=nesterov, max_unorm=max_unorm)
+        if nesterov and (momentum <= 0 or dampening != 0):
+            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+        super(PytorchLARS, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(PytorchLARS, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('nesterov', False)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Args:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            params_with_grad = []
+            d_p_list = []
+            momentum_buffer_list = []
+            weight_decay = group['weight_decay']
+            momentum = group['momentum']
+            dampening = group['dampening']
+            nesterov = group['nesterov']
+            max_unorm = group['max_unorm']
+            lr = group['lr']
+
+            for p in group['params']:
+                if p.grad is None: continue
+
+                state = self.state[p]
+                d_p = p.grad
+                # L2 regularization: add weight_decay * p to the gradient
+                if weight_decay != 0:
+                    d_p = d_p.add(p, alpha=weight_decay)
+
+                if momentum != 0:
+                    buf = state.get('momentum_buffer', None)
+
+                    if buf is None:
+                        # initialize the momentum buffer with the current gradient
+                        buf = torch.clone(d_p).detach()
+                        state['momentum_buffer'] = buf
+                    else:
+                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
+
+                    if nesterov:
+                        update = d_p + buf*momentum
+                    else:
+                        update = buf
+                else:
+                    # without momentum the update is just the (weight-decayed) gradient
+                    update = d_p
+
+                # LARS trust ratio: rescale the update so its norm does not
+                # exceed max_unorm times the parameter norm
+                update_scale = 1.0
+                if max_unorm > 0.0:
+                    assert p.dtype == torch.float32
+                    pnorm = torch.norm(p.detach())
+                    unorm = torch.norm(update)
+                    if unorm > max_unorm*pnorm:
+                        update_scale = max_unorm*pnorm/unorm
+
+                p.add_(update, alpha=-lr*update_scale)
+
+        return loss
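
For context, a minimal usage sketch of the pure-PyTorch variant added in this commit; the import path follows the file above, while the model, tensors, and hyperparameter values are illustrative assumptions, not part of the commit:

import torch
from bitsandbytes.optim.lars import PytorchLARS  # path as added in this commit

model = torch.nn.Linear(64, 64)                   # placeholder model (assumption)
opt = PytorchLARS(model.parameters(), lr=0.01, momentum=0.9, max_unorm=0.02)

x = torch.randn(8, 64)                            # dummy batch (assumption)
loss = model(x).pow(2).mean()                     # dummy loss (assumption)
loss.backward()
opt.step()                                        # momentum step with LARS trust-ratio clipping
opt.zero_grad()

Here max_unorm caps each parameter's update norm at max_unorm times the parameter norm, which is the trust-ratio clipping implemented in step() above.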