summaryrefslogtreecommitdiff
path: root/bitsandbytes/optim/lars.py
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes/optim/lars.py')
-rw-r--r--bitsandbytes/optim/lars.py20
1 files changed, 15 insertions, 5 deletions
diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py
index c6cf5c6..8a89fb0 100644
--- a/bitsandbytes/optim/lars.py
+++ b/bitsandbytes/optim/lars.py
@@ -24,7 +24,9 @@ class LARS(Optimizer1State):
max_unorm=0.02,
):
if momentum == 0:
- raise NotImplementedError(f"LARS without momentum is not supported!")
+ raise NotImplementedError(
+ f"LARS without momentum is not supported!"
+ )
super(LARS, self).__init__(
"lars",
params,
@@ -56,7 +58,9 @@ class LARS8bit(Optimizer1State):
max_unorm=0.02,
):
if momentum == 0:
- raise NotImplementedError(f"LARS without momentum is not supported!")
+ raise NotImplementedError(
+ f"LARS without momentum is not supported!"
+ )
super(LARS8bit, self).__init__(
"lars",
params,
@@ -88,7 +92,9 @@ class LARS32bit(Optimizer1State):
max_unorm=0.02,
):
if momentum == 0:
- raise NotImplementedError(f"LARS without momentum is not supported!")
+ raise NotImplementedError(
+ f"LARS without momentum is not supported!"
+ )
super(LARS32bit, self).__init__(
"lars",
params,
@@ -121,7 +127,9 @@ class PytorchLARS(Optimizer):
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
defaults = dict(
lr=lr,
@@ -132,7 +140,9 @@ class PytorchLARS(Optimizer):
max_unorm=max_unorm,
)
if nesterov and (momentum <= 0 or dampening != 0):
- raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+ raise ValueError(
+ "Nesterov momentum requires a momentum and zero dampening"
+ )
super(PytorchLARS, self).__init__(params, defaults)
def __setstate__(self, state):