Diffstat (limited to 'bitsandbytes/optim/lars.py')
-rw-r--r--  bitsandbytes/optim/lars.py  167
1 file changed, 126 insertions(+), 41 deletions(-)
diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py
index 912520d..c6cf5c6 100644
--- a/bitsandbytes/optim/lars.py
+++ b/bitsandbytes/optim/lars.py
@@ -1,43 +1,121 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
-
from torch.optim import Optimizer
+
from bitsandbytes.optim.optimizer import Optimizer1State
+
class LARS(Optimizer1State):
- def __init__(self, params, lr, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
+ def __init__(
+ self,
+ params,
+ lr,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ max_unorm=0.02,
+ ):
if momentum == 0:
- raise NotImplementedError(f'LARS without momentum is not supported!')
- super(LARS, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
- weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
+ raise NotImplementedError(f"LARS without momentum is not supported!")
+ super(LARS, self).__init__(
+ "lars",
+ params,
+ lr,
+ (momentum, dampening),
+ 0.0,
+ weight_decay,
+ optim_bits,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ max_unorm=max_unorm,
+ block_wise=False,
+ )
+
class LARS8bit(Optimizer1State):
- def __init__(self, params, lr, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
+ def __init__(
+ self,
+ params,
+ lr,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ max_unorm=0.02,
+ ):
if momentum == 0:
- raise NotImplementedError(f'LARS without momentum is not supported!')
- super(LARS8bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
- weight_decay, 8, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
+ raise NotImplementedError(f"LARS without momentum is not supported!")
+ super(LARS8bit, self).__init__(
+ "lars",
+ params,
+ lr,
+ (momentum, dampening),
+ 0.0,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ max_unorm=max_unorm,
+ block_wise=False,
+ )
+
class LARS32bit(Optimizer1State):
- def __init__(self, params, lr, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
+ def __init__(
+ self,
+ params,
+ lr,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ max_unorm=0.02,
+ ):
if momentum == 0:
- raise NotImplementedError(f'LARS without momentum is not supported!')
- super(LARS32bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
- weight_decay, 32, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
+ raise NotImplementedError(f"LARS without momentum is not supported!")
+ super(LARS32bit, self).__init__(
+ "lars",
+ params,
+ lr,
+ (momentum, dampening),
+ 0.0,
+ weight_decay,
+ 32,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ max_unorm=max_unorm,
+ block_wise=False,
+ )
class PytorchLARS(Optimizer):
- def __init__(self, params, lr=0.01, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, max_unorm=0.02):
+ def __init__(
+ self,
+ params,
+ lr=0.01,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ max_unorm=0.02,
+ ):
if lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
@@ -45,8 +123,14 @@ class PytorchLARS(Optimizer):
if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
- defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
- weight_decay=weight_decay, nesterov=nesterov, max_unorm=max_unorm)
+ defaults = dict(
+ lr=lr,
+ momentum=momentum,
+ dampening=dampening,
+ weight_decay=weight_decay,
+ nesterov=nesterov,
+ max_unorm=max_unorm,
+ )
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super(PytorchLARS, self).__init__(params, defaults)
@@ -54,7 +138,7 @@ class PytorchLARS(Optimizer):
def __setstate__(self, state):
super(PytorchLARS, self).__setstate__(state)
for group in self.param_groups:
- group.setdefault('nesterov', False)
+ group.setdefault("nesterov", False)
@torch.no_grad()
def step(self, closure=None):
@@ -73,15 +157,16 @@ class PytorchLARS(Optimizer):
params_with_grad = []
d_p_list = []
momentum_buffer_list = []
- weight_decay = group['weight_decay']
- momentum = group['momentum']
- dampening = group['dampening']
- nesterov = group['nesterov']
- max_unorm = group['max_unorm']
- lr = group['lr']
+ weight_decay = group["weight_decay"]
+ momentum = group["momentum"]
+ dampening = group["dampening"]
+ nesterov = group["nesterov"]
+ max_unorm = group["max_unorm"]
+ lr = group["lr"]
- for p in group['params']:
- if p.grad is None: continue
+ for p in group["params"]:
+ if p.grad is None:
+ continue
state = self.state[p]
d_p = p.grad
@@ -89,16 +174,16 @@ class PytorchLARS(Optimizer):
d_p = d_p.add(param, alpha=weight_decay)
if momentum != 0:
- buf = state.get('momentum_buffer', None)
+ buf = state.get("momentum_buffer", None)
if buf is None:
buf = torch.clone(d_p).detach()
- state['momentum_buffer']= buf
+ state["momentum_buffer"] = buf
else:
buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
if nesterov:
- update = d_p + buf*momentum
+ update = d_p + buf * momentum
else:
update = buf
@@ -107,9 +192,9 @@ class PytorchLARS(Optimizer):
assert p.dtype == torch.float32
pnorm = torch.norm(p.detach())
unorm = torch.norm(update)
- if unorm > max_unorm*pnorm:
- update_scale = max_unorm*pnorm/unorm
+ if unorm > max_unorm * pnorm:
+ update_scale = max_unorm * pnorm / unorm
- p.add_(update, alpha=-lr*update_scale)
+ p.add_(update, alpha=-lr * update_scale)
return loss
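
Below is a minimal usage sketch of the optimizers touched by this diff. It is not part of the change itself; it assumes LARS8bit is exported from bitsandbytes.optim (as the other optimizers in this package are) and that a CUDA device is available. Note that, per the constructor above, momentum must be non-zero or NotImplementedError is raised.

    # Hypothetical example, not from the diff: train a small layer with LARS8bit.
    import torch
    import bitsandbytes as bnb

    model = torch.nn.Linear(512, 512).cuda()

    # momentum=0 is rejected by the constructor, so pass a non-zero value.
    optimizer = bnb.optim.LARS8bit(
        model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4
    )

    x = torch.randn(8, 512, device="cuda")
    loss = model(x).pow(2).mean()
    loss.backward()
    optimizer.step()       # LARS trust-ratio clipping (max_unorm) is applied internally
    optimizer.zero_grad()

PytorchLARS from the same file follows the standard torch.optim.Optimizer interface and can be swapped in with the same loop (e.g. PytorchLARS(model.parameters(), lr=0.01, momentum=0.9)) when a pure-PyTorch reference implementation is preferred over the 8-bit/32-bit variants.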