summaryrefslogtreecommitdiff
path: root/bitsandbytes/optim
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes/optim')
-rw-r--r--bitsandbytes/optim/__init__.py6
-rw-r--r--bitsandbytes/optim/adagrad.py126
-rw-r--r--bitsandbytes/optim/adam.py198
-rw-r--r--bitsandbytes/optim/adamw.py104
-rw-r--r--bitsandbytes/optim/lamb.py117
-rw-r--r--bitsandbytes/optim/lars.py183
-rw-r--r--bitsandbytes/optim/optimizer.py622
-rw-r--r--bitsandbytes/optim/rmsprop.py121
-rw-r--r--bitsandbytes/optim/sgd.py109
9 files changed, 1185 insertions, 401 deletions
diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py
index 42b5bc0..a76d717 100644
--- a/bitsandbytes/optim/__init__.py
+++ b/bitsandbytes/optim/__init__.py
@@ -1,6 +1,6 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.cextension import COMPILED_WITH_CUDA
diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py
index 4f51250..7e2f566 100644
--- a/bitsandbytes/optim/adagrad.py
+++ b/bitsandbytes/optim/adagrad.py
@@ -1,54 +1,132 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State
+
class Adagrad(Optimizer1State):
- def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
- optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr=1e-2,
+ lr_decay=0,
+ weight_decay=0,
+ initial_accumulator_value=0,
+ eps=1e-10,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if initial_accumulator_value != 0.0:
- raise ValueError('Initial accumulator value != 0.0 not supported!')
+ raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0:
- raise ValueError('Lr Decay != 0.0 not supported!')
- super(Adagrad, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
- weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+ raise ValueError("Lr Decay != 0.0 not supported!")
+ super(Adagrad, self).__init__(
+ "adagrad",
+ params,
+ lr,
+ (0.0, 0.0),
+ eps,
+ weight_decay,
+ optim_bits,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class Adagrad8bit(Optimizer1State):
- def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
- optim_bits=8, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr=1e-2,
+ lr_decay=0,
+ weight_decay=0,
+ initial_accumulator_value=0,
+ eps=1e-10,
+ optim_bits=8,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if initial_accumulator_value != 0.0:
- raise ValueError('Initial accumulator value != 0.0 not supported!')
+ raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0:
- raise ValueError('Lr Decay != 0.0 not supported!')
+ raise ValueError("Lr Decay != 0.0 not supported!")
assert block_wise
- super(Adagrad8bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
- weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+ super(Adagrad8bit, self).__init__(
+ "adagrad",
+ params,
+ lr,
+ (0.0, 0.0),
+ eps,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class Adagrad32bit(Optimizer1State):
- def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
- optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr=1e-2,
+ lr_decay=0,
+ weight_decay=0,
+ initial_accumulator_value=0,
+ eps=1e-10,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if initial_accumulator_value != 0.0:
- raise ValueError('Initial accumulator value != 0.0 not supported!')
+ raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0:
- raise ValueError('Lr Decay != 0.0 not supported!')
- super(Adagrad32bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
- weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+ raise ValueError("Lr Decay != 0.0 not supported!")
+ super(Adagrad32bit, self).__init__(
+ "adagrad",
+ params,
+ lr,
+ (0.0, 0.0),
+ eps,
+ weight_decay,
+ 32,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
index ed1b9f0..3634971 100644
--- a/bitsandbytes/optim/adam.py
+++ b/bitsandbytes/optim/adam.py
@@ -1,6 +1,6 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
@@ -8,29 +8,97 @@ import os
import torch
import torch.distributed as dist
-from bitsandbytes.optim.optimizer import Optimizer2State
+
import bitsandbytes.functional as F
+from bitsandbytes.optim.optimizer import Optimizer2State
+
class Adam(Optimizer2State):
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
- super(Adam, self).__init__('adam', params, lr, betas, eps,
- weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
+ super(Adam, self).__init__(
+ "adam",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ optim_bits,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class Adam8bit(Optimizer2State):
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
- super(Adam8bit, self).__init__('adam', params, lr, betas, eps,
- weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
+ super(Adam8bit, self).__init__(
+ "adam",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class Adam32bit(Optimizer2State):
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
- super(Adam32bit, self).__init__('adam', params, lr, betas, eps,
- weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
+ super(Adam32bit, self).__init__(
+ "adam",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ 32,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
class AnalysisAdam(torch.optim.Optimizer):
@@ -68,11 +136,15 @@ class AnalysisAdam(torch.optim.Optimizer):
eps=1e-8,
weight_decay=0,
amsgrad=False,
- bnb_analysis='dynamic-blockwise',
- savedir=None
+ bnb_analysis="dynamic-blockwise",
+ savedir=None,
):
defaults = dict(
- lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
+ lr=lr,
+ betas=betas,
+ eps=eps,
+ weight_decay=weight_decay,
+ amsgrad=amsgrad,
)
super(AnalysisAdam, self).__init__(params, defaults)
self.analysis = bnb_analysis
@@ -124,9 +196,15 @@ class AnalysisAdam(torch.optim.Optimizer):
state["exp_avg"] = torch.zeros_like(p_data_fp32)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
- state['abserrors'] = torch.zeros((256, 256), device=p_data_fp32.device)
- state['relerrors'] = torch.zeros((256, 256), device=p_data_fp32.device)
- state['counts'] = torch.zeros((256, 256), device=p_data_fp32.device)
+ state["abserrors"] = torch.zeros(
+ (256, 256), device=p_data_fp32.device
+ )
+ state["relerrors"] = torch.zeros(
+ (256, 256), device=p_data_fp32.device
+ )
+ state["counts"] = torch.zeros(
+ (256, 256), device=p_data_fp32.device
+ )
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
@@ -142,10 +220,12 @@ class AnalysisAdam(torch.optim.Optimizer):
beta1, beta2 = group["betas"]
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
- step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
- e = state['abserrors']
- rele = state['relerrors']
- counts = state['counts']
+ step_size = (
+ group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+ )
+ e = state["abserrors"]
+ rele = state["relerrors"]
+ counts = state["counts"]
if group["weight_decay"] != 0:
p_data_fp32.add_(
@@ -156,77 +236,91 @@ class AnalysisAdam(torch.optim.Optimizer):
if amsgrad:
max_exp_avg_sq = state["max_exp_avg_sq"]
-
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
denom = exp_avg_sq.sqrt().add_(group["eps"])
- update_fp32 = exp_avg/denom
+ update_fp32 = exp_avg / denom
- if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000*1000:
+ if (
+ p_data_fp32.numel() <= 8192
+ or p_data_fp32.numel() > 50000 * 1000
+ ):
# embedding layer or too small
- p_data_fp32 += -step_size*update_fp32
+ p_data_fp32 += -step_size * update_fp32
else:
- if self.analysis == 'dynamic-blockwise':
+ if self.analysis == "dynamic-blockwise":
code1 = F.create_dynamic_map(signed=True).to(p.device)
code2 = F.create_dynamic_map(signed=False).to(p.device)
C1, S1 = F.quantize_blockwise(exp_avg, code=code1)
state1 = F.dequantize_blockwise(C1, S1)
C2, S2 = F.quantize_blockwise(exp_avg_sq, code=code2)
state2 = F.dequantize_blockwise(C2, S2)
- elif self.analysis == 'dynamic':
+ elif self.analysis == "dynamic":
code1 = F.create_dynamic_map(signed=True).to(p.device)
code2 = F.create_dynamic_map(signed=False).to(p.device)
C1, S1 = F.quantize(exp_avg, code=code1)
state1 = F.dequantize(C1, S1)
C2, S2 = F.quantize(exp_avg_sq, code=code2)
state2 = F.dequantize(C2, S2)
- elif self.analysis == 'linear':
+ elif self.analysis == "linear":
code1 = F.create_linear_map(signed=True).to(p.device)
code2 = F.create_linear_map(signed=False).to(p.device)
C1, S1 = F.quantize(exp_avg, code=code1)
state1 = F.dequantize(C1, S1)
C2, S2 = F.quantize(exp_avg_sq, code=code2)
state2 = F.dequantize(C2, S2)
- elif self.analysis == 'quantile':
+ elif self.analysis == "quantile":
code1 = F.estimate_quantiles(exp_avg)
code2 = F.estimate_quantiles(exp_avg_sq)
C1 = F.quantize_no_absmax(exp_avg, code=code1)
state1 = F.dequantize_no_absmax(C1, code1)
C2 = F.quantize_no_absmax(exp_avg_sq, code=code2)
state2 = F.dequantize_no_absmax(C2, code2)
- elif self.analysis == 'my-quantization-routine':
+ elif self.analysis == "my-quantization-routine":
pass
# 1. get code
# 2. quantize
# 3. dequantize
# Error will be calculated automatically!
else:
- raise ValueError(f'Invalid analysis value: {self.analysis}!')
+ raise ValueError(
+ f"Invalid analysis value: {self.analysis}!"
+ )
denom = state2.sqrt().add_(group["eps"])
- update_8bit = state1/denom
+ update_8bit = state1 / denom
- abserr = torch.abs(update_8bit-update_fp32)
- relerr = abserr/torch.abs(update_fp32+1e-6)
+ abserr = torch.abs(update_8bit - update_fp32)
+ relerr = abserr / torch.abs(update_fp32 + 1e-6)
C1, C2 = C1.int(), C2.int()
F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr)
F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr)
- F.histogram_scatter_add_2d(counts, C1.int(), C2.int(), torch.ones_like(abserr))
-
- p_data_fp32 += -step_size*update_fp32
+ F.histogram_scatter_add_2d(
+ counts, C1.int(), C2.int(), torch.ones_like(abserr)
+ )
+ p_data_fp32 += -step_size * update_fp32
if not dist.is_initialized() or dist.get_rank() == 0:
- if self.savedir != '' and state['step'] % 100 == 0:
- if not os.path.exists(self.savedir): os.makedirs(self.savedir)
- shapestr = '_'.join([str(dim) for dim in p_data_fp32.shape])
- pathe = os.path.join(self.savedir, f'{p_id}_{shapestr}_abserr.pkl')
- pathrele = os.path.join(self.savedir, f'{p_id}_{shapestr}_relerr.pkl')
- pathcounts = os.path.join(self.savedir, f'{p_id}_{shapestr}_counts.pkl')
+ if self.savedir != "" and state["step"] % 100 == 0:
+ if not os.path.exists(self.savedir):
+ os.makedirs(self.savedir)
+ shapestr = "_".join(
+ [str(dim) for dim in p_data_fp32.shape]
+ )
+ pathe = os.path.join(
+ self.savedir, f"{p_id}_{shapestr}_abserr.pkl"
+ )
+ pathrele = os.path.join(
+ self.savedir, f"{p_id}_{shapestr}_relerr.pkl"
+ )
+ pathcounts = os.path.join(
+ self.savedir, f"{p_id}_{shapestr}_counts.pkl"
+ )
torch.save(e, pathe)
torch.save(rele, pathrele)
torch.save(counts, pathcounts)
@@ -234,6 +328,4 @@ class AnalysisAdam(torch.optim.Optimizer):
if p.data.dtype in {torch.float16, torch.bfloat16}:
p.data.copy_(p_data_fp32)
-
-
return loss
diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py
index c4f0355..d0b3bde 100644
--- a/bitsandbytes/optim/adamw.py
+++ b/bitsandbytes/optim/adamw.py
@@ -1,27 +1,93 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer2State
+
class AdamW(Optimizer2State):
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=1e-2, amsgrad=False, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
- super(AdamW, self).__init__('adam', params, lr, betas, eps,
- weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=1e-2,
+ amsgrad=False,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
+ super(AdamW, self).__init__(
+ "adam",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ optim_bits,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class AdamW8bit(Optimizer2State):
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=1e-2, amsgrad=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
- super(AdamW8bit, self).__init__('adam', params, lr, betas, eps,
- weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=1e-2,
+ amsgrad=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
+ super(AdamW8bit, self).__init__(
+ "adam",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
-class AdamW32bit(Optimizer2State):
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=1e-2, amsgrad=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
- super(AdamW32bit, self).__init__('adam', params, lr, betas, eps,
- weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+class AdamW32bit(Optimizer2State):
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=1e-2,
+ amsgrad=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
+ super(AdamW32bit, self).__init__(
+ "adam",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ 32,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
diff --git a/bitsandbytes/optim/lamb.py b/bitsandbytes/optim/lamb.py
index 58cc13d..8f365f7 100644
--- a/bitsandbytes/optim/lamb.py
+++ b/bitsandbytes/optim/lamb.py
@@ -1,28 +1,105 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer2State
+
class LAMB(Optimizer2State):
- def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False, adam_w_mode=True, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
- super(LAMB, self).__init__('lamb', params, lr, betas, eps,
- weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0)
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ bias_correction=True,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
+ adam_w_mode=True,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=False,
+ max_unorm=1.0,
+ ):
+ super(LAMB, self).__init__(
+ "lamb",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ optim_bits,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ max_unorm=1.0,
+ )
-class LAMB8bit(Optimizer2State):
- def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False, adam_w_mode=True, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
- super(LAMB8bit, self).__init__('lamb', params, lr, betas, eps,
- weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0)
-class LAMB32bit(Optimizer2State):
- def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False, adam_w_mode=True, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
- super(LAMB32bit, self).__init__('lamb', params, lr, betas, eps,
- weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0)
+class LAMB8bit(Optimizer2State):
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ bias_correction=True,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
+ adam_w_mode=True,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=False,
+ max_unorm=1.0,
+ ):
+ super(LAMB8bit, self).__init__(
+ "lamb",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ max_unorm=1.0,
+ )
+class LAMB32bit(Optimizer2State):
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ bias_correction=True,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
+ adam_w_mode=True,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=False,
+ max_unorm=1.0,
+ ):
+ super(LAMB32bit, self).__init__(
+ "lamb",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ 32,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ max_unorm=1.0,
+ )
diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py
index 912520d..8a89fb0 100644
--- a/bitsandbytes/optim/lars.py
+++ b/bitsandbytes/optim/lars.py
@@ -1,60 +1,154 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
-
from torch.optim import Optimizer
+
from bitsandbytes.optim.optimizer import Optimizer1State
+
class LARS(Optimizer1State):
- def __init__(self, params, lr, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
+ def __init__(
+ self,
+ params,
+ lr,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ max_unorm=0.02,
+ ):
if momentum == 0:
- raise NotImplementedError(f'LARS without momentum is not supported!')
- super(LARS, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
- weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
+ raise NotImplementedError(
+ f"LARS without momentum is not supported!"
+ )
+ super(LARS, self).__init__(
+ "lars",
+ params,
+ lr,
+ (momentum, dampening),
+ 0.0,
+ weight_decay,
+ optim_bits,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ max_unorm=max_unorm,
+ block_wise=False,
+ )
+
class LARS8bit(Optimizer1State):
- def __init__(self, params, lr, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
+ def __init__(
+ self,
+ params,
+ lr,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ max_unorm=0.02,
+ ):
if momentum == 0:
- raise NotImplementedError(f'LARS without momentum is not supported!')
- super(LARS8bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
- weight_decay, 8, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
+ raise NotImplementedError(
+ f"LARS without momentum is not supported!"
+ )
+ super(LARS8bit, self).__init__(
+ "lars",
+ params,
+ lr,
+ (momentum, dampening),
+ 0.0,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ max_unorm=max_unorm,
+ block_wise=False,
+ )
+
class LARS32bit(Optimizer1State):
- def __init__(self, params, lr, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
+ def __init__(
+ self,
+ params,
+ lr,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ max_unorm=0.02,
+ ):
if momentum == 0:
- raise NotImplementedError(f'LARS without momentum is not supported!')
- super(LARS32bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
- weight_decay, 32, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
+ raise NotImplementedError(
+ f"LARS without momentum is not supported!"
+ )
+ super(LARS32bit, self).__init__(
+ "lars",
+ params,
+ lr,
+ (momentum, dampening),
+ 0.0,
+ weight_decay,
+ 32,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ max_unorm=max_unorm,
+ block_wise=False,
+ )
class PytorchLARS(Optimizer):
- def __init__(self, params, lr=0.01, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, max_unorm=0.02):
+ def __init__(
+ self,
+ params,
+ lr=0.01,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ max_unorm=0.02,
+ ):
if lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
-
- defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
- weight_decay=weight_decay, nesterov=nesterov, max_unorm=max_unorm)
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
+
+ defaults = dict(
+ lr=lr,
+ momentum=momentum,
+ dampening=dampening,
+ weight_decay=weight_decay,
+ nesterov=nesterov,
+ max_unorm=max_unorm,
+ )
if nesterov and (momentum <= 0 or dampening != 0):
- raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+ raise ValueError(
+ "Nesterov momentum requires a momentum and zero dampening"
+ )
super(PytorchLARS, self).__init__(params, defaults)
def __setstate__(self, state):
super(PytorchLARS, self).__setstate__(state)
for group in self.param_groups:
- group.setdefault('nesterov', False)
+ group.setdefault("nesterov", False)
@torch.no_grad()
def step(self, closure=None):
@@ -73,15 +167,16 @@ class PytorchLARS(Optimizer):
params_with_grad = []
d_p_list = []
momentum_buffer_list = []
- weight_decay = group['weight_decay']
- momentum = group['momentum']
- dampening = group['dampening']
- nesterov = group['nesterov']
- max_unorm = group['max_unorm']
- lr = group['lr']
+ weight_decay = group["weight_decay"]
+ momentum = group["momentum"]
+ dampening = group["dampening"]
+ nesterov = group["nesterov"]
+ max_unorm = group["max_unorm"]
+ lr = group["lr"]
- for p in group['params']:
- if p.grad is None: continue
+ for p in group["params"]:
+ if p.grad is None:
+ continue
state = self.state[p]
d_p = p.grad
@@ -89,16 +184,16 @@ class PytorchLARS(Optimizer):
d_p = d_p.add(param, alpha=weight_decay)
if momentum != 0:
- buf = state.get('momentum_buffer', None)
+ buf = state.get("momentum_buffer", None)
if buf is None:
buf = torch.clone(d_p).detach()
- state['momentum_buffer']= buf
+ state["momentum_buffer"] = buf
else:
buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
if nesterov:
- update = d_p + buf*momentum
+ update = d_p + buf * momentum
else:
update = buf
@@ -107,9 +202,9 @@ class PytorchLARS(Optimizer):
assert p.dtype == torch.float32
pnorm = torch.norm(p.detach())
unorm = torch.norm(update)
- if unorm > max_unorm*pnorm:
- update_scale = max_unorm*pnorm/unorm
+ if unorm > max_unorm * pnorm:
+ update_scale = max_unorm * pnorm / unorm
- p.add_(update, alpha=-lr*update_scale)
+ p.add_(update, alpha=-lr * update_scale)
return loss
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index 5a5bb1e..4fb30cd 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -1,13 +1,16 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+from collections import abc as container_abcs
+from collections import defaultdict
+from copy import deepcopy
+from itertools import chain
+
import torch
+
import bitsandbytes.functional as F
-from copy import deepcopy
-from itertools import chain
-from collections import defaultdict, abc as container_abcs
class MockArgs(object):
def __init__(self, initial_data):
@@ -19,7 +22,7 @@ class GlobalOptimManager(object):
_instance = None
def __init__(self):
- raise RuntimeError('Call get_instance() instead')
+ raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.pid2config = {}
@@ -38,15 +41,19 @@ class GlobalOptimManager(object):
def register_parameters(self, params):
param_groups = list(params)
if not isinstance(param_groups[0], dict):
- param_groups = [{'params': param_groups}]
+ param_groups = [{"params": param_groups}]
for group_index, group in enumerate(param_groups):
- for p_index, p in enumerate(group['params']):
+ for p_index, p in enumerate(group["params"]):
if id(p) in self.pid2config:
- self.index2config[(group_index, p_index)] = self.pid2config[id(p)]
+ self.index2config[(group_index, p_index)] = self.pid2config[
+ id(p)
+ ]
- def override_config(self, parameters, key=None, value=None, key_value_dict=None):
- '''
+ def override_config(
+ self, parameters, key=None, value=None, key_value_dict=None
+ ):
+ """
Overrides initial optimizer config for specific parameters.
The key-values of the optimizer config for the input parameters are overidden
@@ -63,7 +70,7 @@ class GlobalOptimManager(object):
The value for the hyperparamters.
key_value_dict : dict
A dictionary with multiple key-values to override.
- '''
+ """
self.uses_config_override = True
if isinstance(parameters, torch.nn.Parameter):
parameters = [parameters]
@@ -75,16 +82,16 @@ class GlobalOptimManager(object):
if key_value_dict is not None:
for p in parameters:
- if id(p) in self.pid2config:self.pid2config[id(p)].update(key_value_dict)
- else: self.pid2config[id(p)] = key_value_dict
+ if id(p) in self.pid2config:
+ self.pid2config[id(p)].update(key_value_dict)
+ else:
+ self.pid2config[id(p)] = key_value_dict
def register_module_override(self, module, param_name, config):
self.module_weight_config_triple.append((module, param_name, config))
-
class Optimizer8bit(torch.optim.Optimizer):
-
def __init__(self, params, defaults, optim_bits=32):
super(Optimizer8bit, self).__init__(params, defaults)
self.initialized = False
@@ -92,23 +99,32 @@ class Optimizer8bit(torch.optim.Optimizer):
self.mng = GlobalOptimManager.get_instance()
self.non_castable_tensor_keys = set(
- ['qmap1', 'qmap2',
- 'max1', 'max2',
- 'new_max1', 'new_max2',
- 'state1', 'state2',
- 'gnorm_vec', 'absmax1', 'absmax2',
- 'unorm_vec'])
-
- if optim_bits == 8: self.fill_qmap()
+ [
+ "qmap1",
+ "qmap2",
+ "max1",
+ "max2",
+ "new_max1",
+ "new_max2",
+ "state1",
+ "state2",
+ "gnorm_vec",
+ "absmax1",
+ "absmax2",
+ "unorm_vec",
+ ]
+ )
+
+ if optim_bits == 8:
+ self.fill_qmap()
def fill_qmap(self):
- self.name2qmap['dynamic'] = F.create_dynamic_map(signed=True)
- self.name2qmap['udynamic'] = F.create_dynamic_map(signed=False)
+ self.name2qmap["dynamic"] = F.create_dynamic_map(signed=True)
+ self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False)
def __setstate__(self, state):
super(Optimizer8bit, self).__setstate__(state)
-
def load_state_dict(self, state_dict):
r"""Loads the optimizer state.
@@ -120,21 +136,29 @@ class Optimizer8bit(torch.optim.Optimizer):
state_dict = deepcopy(state_dict)
# Validate the state_dict
groups = self.param_groups
- saved_groups = state_dict['param_groups']
+ saved_groups = state_dict["param_groups"]
if len(groups) != len(saved_groups):
- raise ValueError("loaded state dict has a different number of "
- "parameter groups")
- param_lens = (len(g['params']) for g in groups)
- saved_lens = (len(g['params']) for g in saved_groups)
+ raise ValueError(
+ "loaded state dict has a different number of "
+ "parameter groups"
+ )
+ param_lens = (len(g["params"]) for g in groups)
+ saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
- raise ValueError("loaded state dict contains a parameter group "
- "that doesn't match the size of optimizer's group")
+ raise ValueError(
+ "loaded state dict contains a parameter group "
+ "that doesn't match the size of optimizer's group"
+ )
# Update the state
- id_map = {old_id: p for old_id, p in
- zip(chain.from_iterable((g['params'] for g in saved_groups)),
- chain.from_iterable((g['params'] for g in groups)))}
+ id_map = {
+ old_id: p
+ for old_id, p in zip(
+ chain.from_iterable((g["params"] for g in saved_groups)),
+ chain.from_iterable((g["params"] for g in groups)),
+ )
+ }
def cast(param, value):
r"""Make a deep copy of value, casting all tensors to device of param."""
@@ -161,7 +185,7 @@ class Optimizer8bit(torch.optim.Optimizer):
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
state = defaultdict(dict)
- for k, v in state_dict['state'].items():
+ for k, v in state_dict["state"].items():
if k in id_map:
param = id_map[k]
state[param] = cast(param, v)
@@ -170,15 +194,17 @@ class Optimizer8bit(torch.optim.Optimizer):
# Update parameter groups, setting their 'params' value
def update_group(group, new_group):
- new_group['params'] = group['params']
+ new_group["params"] = group["params"]
return new_group
+
param_groups = [
- update_group(g, ng) for g, ng in zip(groups, saved_groups)]
- self.__setstate__({'state': state, 'param_groups': param_groups})
+ update_group(g, ng) for g, ng in zip(groups, saved_groups)
+ ]
+ self.__setstate__({"state": state, "param_groups": param_groups})
def to_gpu(self):
for gindex, group in enumerate(self.param_groups):
- for pindex, p in enumerate(group['params']):
+ for pindex, p in enumerate(group["params"]):
if p in self.state:
values = self.state[p]
for k, v in values.items():
@@ -189,17 +215,23 @@ class Optimizer8bit(torch.optim.Optimizer):
for module, attr, config in self.mng.module_weight_config_triple:
pmodule = getattr(module, attr)
assert pmodule is not None
- assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter)
+ assert isinstance(pmodule, torch.Tensor) or isinstance(
+ pmodule, torch.Parameter
+ )
found = False
for gindex, group in enumerate(self.param_groups):
- if found: break
- for pindex, p in enumerate(group['params']):
- if found: break
+ if found:
+ break
+ for pindex, p in enumerate(group["params"]):
+ if found:
+ break
if id(p) == id(pmodule):
# found the matching parameter
# init override
self.mng.pid2config[id(p)] = config
- self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)]
+ self.mng.index2config[
+ (gindex, pindex)
+ ] = self.mng.pid2config[id(p)]
found = True
@torch.no_grad()
@@ -219,11 +251,11 @@ class Optimizer8bit(torch.optim.Optimizer):
if not self.initialized:
self.check_overrides()
- self.to_gpu() # needed for fairseq pure fp16 training
+ self.to_gpu() # needed for fairseq pure fp16 training
self.initialized = True
for gindex, group in enumerate(self.param_groups):
- for pindex, p in enumerate(group['params']):
+ for pindex, p in enumerate(group["params"]):
if p.grad is None:
continue
state = self.state[p]
@@ -236,58 +268,76 @@ class Optimizer8bit(torch.optim.Optimizer):
def get_config(self, gindex, pindex, group):
config = {}
- config['betas'] = group['betas']
- config['eps'] = group['eps']
- config['weight_decay'] = group['weight_decay']
- config['lr'] = group['lr']
- config['optim_bits'] = self.args.optim_bits
- config['min_8bit_size'] = self.args.min_8bit_size
- config['percentile_clipping'] = self.args.percentile_clipping
- config['block_wise'] = self.args.block_wise
- config['max_unorm'] = self.args.max_unorm
- config['skip_zeros'] = self.args.skip_zeros
+ config["betas"] = group["betas"]
+ config["eps"] = group["eps"]
+ config["weight_decay"] = group["weight_decay"]
+ config["lr"] = group["lr"]
+ config["optim_bits"] = self.args.optim_bits
+ config["min_8bit_size"] = self.args.min_8bit_size
+ config["percentile_clipping"] = self.args.percentile_clipping
+ config["block_wise"] = self.args.block_wise
+ config["max_unorm"] = self.args.max_unorm
+ config["skip_zeros"] = self.args.skip_zeros
if (gindex, pindex) in self.mng.index2config:
config.update(self.mng.index2config[(gindex, pindex)])
return config
def init_state(self, group, p, gindex, pindex):
- raise NotImplementedError(f'init_state method needs to be overidden')
+ raise NotImplementedError(f"init_state method needs to be overidden")
def update_step(self, group, p, gindex, pindex):
- raise NotImplementedError(f'The update_step method needs to be overidden')
+ raise NotImplementedError(
+ f"The update_step method needs to be overidden"
+ )
+
class Optimizer2State(Optimizer8bit):
- def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0.0, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0,
- skip_zeros=False):
+ def __init__(
+ self,
+ optimizer_name,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0.0,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ max_unorm=0.0,
+ skip_zeros=False,
+ ):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if isinstance(betas, str):
# format: '(beta1, beta2)'
- betas = betas.replace('(', '').replace(')', '').strip().split(',')
+ betas = betas.replace("(", "").replace(")", "").strip().split(",")
betas = [float(b) for b in betas]
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
- raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
+ raise ValueError(
+ f"Invalid beta parameter at index {i}: {betas[i]}"
+ )
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
- defaults = dict(lr=lr, betas=betas, eps=eps,
- weight_decay=weight_decay)
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer2State, self).__init__(params, defaults, optim_bits)
if args is None:
args = {}
- args['optim_bits'] = optim_bits
- args['percentile_clipping'] = 100
- args['min_8bit_size'] = min_8bit_size
- args['percentile_clipping'] = percentile_clipping
- args['block_wise'] = block_wise
- args['max_unorm'] = max_unorm
- args['skip_zeros'] = skip_zeros
+ args["optim_bits"] = optim_bits
+ args["percentile_clipping"] = 100
+ args["min_8bit_size"] = min_8bit_size
+ args["percentile_clipping"] = percentile_clipping
+ args["block_wise"] = block_wise
+ args["max_unorm"] = max_unorm
+ args["skip_zeros"] = skip_zeros
self.args = MockArgs(args)
else:
@@ -299,50 +349,93 @@ class Optimizer2State(Optimizer8bit):
def init_state(self, group, p, gindex, pindex):
config = self.get_config(gindex, pindex, group)
- if config['optim_bits'] == 32:
+ if config["optim_bits"] == 32:
dtype = torch.float32
- elif config['optim_bits'] == 8:
+ elif config["optim_bits"] == 8:
dtype = torch.uint8
- else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
+ else:
+ raise NotImplementedError(
+ f'Amount of optimizer bits not supported: {config["optim_bits"]}'
+ )
- if p.numel() < config['min_8bit_size']: dtype = torch.float32
+ if p.numel() < config["min_8bit_size"]:
+ dtype = torch.float32
state = self.state[p]
- state['step'] = 0
-
- if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
- state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
- state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
+ state["step"] = 0
+
+ if dtype == torch.float32 or (
+ dtype == torch.uint8 and p.numel() < 4096
+ ):
+ state["state1"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.float32,
+ device=p.device,
+ )
+ state["state2"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.float32,
+ device=p.device,
+ )
elif dtype == torch.uint8:
- if state['step'] == 0:
- if 'dynamic' not in self.name2qmap: self.fill_qmap()
- self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device)
- self.name2qmap['udynamic'] = self.name2qmap['udynamic'].to(p.device)
-
- state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
- state['qmap1'] = self.name2qmap['dynamic']
-
- state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
- state['qmap2'] = self.name2qmap['udynamic']
-
- if config['block_wise']:
+ if state["step"] == 0:
+ if "dynamic" not in self.name2qmap:
+ self.fill_qmap()
+ self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
+ p.device
+ )
+ self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(
+ p.device
+ )
+
+ state["state1"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.uint8,
+ device=p.device,
+ )
+ state["qmap1"] = self.name2qmap["dynamic"]
+
+ state["state2"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.uint8,
+ device=p.device,
+ )
+ state["qmap2"] = self.name2qmap["udynamic"]
+
+ if config["block_wise"]:
n = p.numel()
- blocks = n//2048
+ blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0
- state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
- state['absmax2'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
+ state["absmax1"] = torch.zeros(
+ (blocks,), dtype=torch.float32, device=p.device
+ )
+ state["absmax2"] = torch.zeros(
+ (blocks,), dtype=torch.float32, device=p.device
+ )
else:
- state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
- state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
- state['max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
- state['new_max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
-
- if config['percentile_clipping'] < 100:
- state['gnorm_vec'] = torch.zeros((100,), device=p.device)
-
- if config['max_unorm'] > 0.0:
- state['unorm_vec'] = torch.zeros((1,), device=p.device)
+ state["max1"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
+ state["new_max1"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
+ state["max2"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
+ state["new_max2"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
+
+ if config["percentile_clipping"] < 100:
+ state["gnorm_vec"] = torch.zeros((100,), device=p.device)
+
+ if config["max_unorm"] > 0.0:
+ state["unorm_vec"] = torch.zeros((1,), device=p.device)
@torch.no_grad()
def update_step(self, group, p, gindex, pindex):
@@ -351,63 +444,128 @@ class Optimizer2State(Optimizer8bit):
config = self.get_config(gindex, pindex, group)
- state['step'] += 1
- step = state['step']
+ state["step"] += 1
+ step = state["step"]
- if config['percentile_clipping'] < 100:
- current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping'])
+ if config["percentile_clipping"] < 100:
+ current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
+ grad, state["gnorm_vec"], step, config["percentile_clipping"]
+ )
else:
gnorm_scale = 1.0
- if state['state1'].dtype == torch.float:
- F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
- state['state2'], config['betas'][1], config['weight_decay'], gnorm_scale,
- state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'], skip_zeros=config['skip_zeros'])
-
- elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
- F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
- config['eps'], step, config['lr'],
- state['qmap1'], state['qmap2'], state['max1'], state['max2'], state['new_max1'], state['new_max2'],
- config['weight_decay'], gnorm_scale=gnorm_scale,
- unorm_vec=state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
+ if state["state1"].dtype == torch.float:
+ F.optimizer_update_32bit(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ config["betas"][0],
+ config["eps"],
+ step,
+ config["lr"],
+ state["state2"],
+ config["betas"][1],
+ config["weight_decay"],
+ gnorm_scale,
+ state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
+ max_unorm=config["max_unorm"],
+ skip_zeros=config["skip_zeros"],
+ )
+
+ elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
+ F.optimizer_update_8bit(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ state["state2"],
+ config["betas"][0],
+ config["betas"][1],
+ config["eps"],
+ step,
+ config["lr"],
+ state["qmap1"],
+ state["qmap2"],
+ state["max1"],
+ state["max2"],
+ state["new_max1"],
+ state["new_max2"],
+ config["weight_decay"],
+ gnorm_scale=gnorm_scale,
+ unorm_vec=state["unorm_vec"]
+ if config["max_unorm"] > 0.0
+ else None,
+ max_unorm=config["max_unorm"],
+ )
# swap maxes
- state['max1'], state['new_max1'] = state['new_max1'], state['max1']
- state['max2'], state['new_max2'] = state['new_max2'], state['max2']
- elif state['state1'].dtype == torch.uint8 and config['block_wise']:
- F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
- config['eps'], step, config['lr'],
- state['qmap1'], state['qmap2'], state['absmax1'], state['absmax2'],
- config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros'])
+ state["max1"], state["new_max1"] = state["new_max1"], state["max1"]
+ state["max2"], state["new_max2"] = state["new_max2"], state["max2"]
+ elif state["state1"].dtype == torch.uint8 and config["block_wise"]:
+ F.optimizer_update_8bit_blockwise(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ state["state2"],
+ config["betas"][0],
+ config["betas"][1],
+ config["eps"],
+ step,
+ config["lr"],
+ state["qmap1"],
+ state["qmap2"],
+ state["absmax1"],
+ state["absmax2"],
+ config["weight_decay"],
+ gnorm_scale=gnorm_scale,
+ skip_zeros=config["skip_zeros"],
+ )
class Optimizer1State(Optimizer8bit):
- def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.0), eps=1e-8,
- weight_decay=0.0, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0,
- skip_zeros=False):
+ def __init__(
+ self,
+ optimizer_name,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.0),
+ eps=1e-8,
+ weight_decay=0.0,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ max_unorm=0.0,
+ skip_zeros=False,
+ ):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
- raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
+ raise ValueError(
+ f"Invalid beta parameter at index {i}: {betas[i]}"
+ )
if not 0.0 <= weight_decay:
- raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
- defaults = dict(lr=lr, betas=betas, eps=eps,
- weight_decay=weight_decay)
+ raise ValueError(
+ "Invalid weight_decay value: {}".format(weight_decay)
+ )
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer1State, self).__init__(params, defaults, optim_bits)
if args is None:
args = {}
- args['optim_bits'] = optim_bits
- args['percentile_clipping'] = 100
- args['min_8bit_size'] = min_8bit_size
- args['percentile_clipping'] = percentile_clipping
- args['block_wise'] = block_wise
- args['max_unorm'] = max_unorm
- args['skip_zeros'] = skip_zeros
+ args["optim_bits"] = optim_bits
+ args["percentile_clipping"] = 100
+ args["min_8bit_size"] = min_8bit_size
+ args["percentile_clipping"] = percentile_clipping
+ args["block_wise"] = block_wise
+ args["max_unorm"] = max_unorm
+ args["skip_zeros"] = skip_zeros
self.args = MockArgs(args)
else:
@@ -419,43 +577,67 @@ class Optimizer1State(Optimizer8bit):
def init_state(self, group, p, gindex, pindex):
config = self.get_config(gindex, pindex, group)
- if config['optim_bits'] == 32:
+ if config["optim_bits"] == 32:
dtype = torch.float32
- elif config['optim_bits'] == 8:
+ elif config["optim_bits"] == 8:
dtype = torch.uint8
- else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
+ else:
+ raise NotImplementedError(
+ f'Amount of optimizer bits not supported: {config["optim_bits"]}'
+ )
- if p.numel() < config['min_8bit_size']: dtype = torch.float32
+ if p.numel() < config["min_8bit_size"]:
+ dtype = torch.float32
state = self.state[p]
- state['step'] = 0
-
- if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
- state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
+ state["step"] = 0
+
+ if dtype == torch.float32 or (
+ dtype == torch.uint8 and p.numel() < 4096
+ ):
+ state["state1"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.float32,
+ device=p.device,
+ )
elif dtype == torch.uint8:
- if state['step'] == 0:
- if 'dynamic' not in self.name2qmap: self.fill_qmap()
- self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device)
-
- state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
- state['qmap1'] = self.name2qmap['dynamic']
-
- if config['block_wise']:
+ if state["step"] == 0:
+ if "dynamic" not in self.name2qmap:
+ self.fill_qmap()
+ self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
+ p.device
+ )
+
+ state["state1"] = torch.zeros_like(
+ p,
+ memory_format=torch.preserve_format,
+ dtype=torch.uint8,
+ device=p.device,
+ )
+ state["qmap1"] = self.name2qmap["dynamic"]
+
+ if config["block_wise"]:
n = p.numel()
- blocks = n//2048
+ blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0
- state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
+ state["absmax1"] = torch.zeros(
+ (blocks,), dtype=torch.float32, device=p.device
+ )
else:
- state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
- state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
-
- if config['percentile_clipping'] < 100:
- state['gnorm_vec'] = torch.zeros((100,), device=p.device)
+ state["max1"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
+ state["new_max1"] = torch.zeros(
+ (1,), dtype=torch.float32, device=p.device
+ )
- if config['max_unorm'] > 0.0:
- state['unorm_vec'] = torch.zeros((1,), device=p.device)
+ if config["percentile_clipping"] < 100:
+ state["gnorm_vec"] = torch.zeros((100,), device=p.device)
+ if config["max_unorm"] > 0.0:
+ state["unorm_vec"] = torch.zeros((1,), device=p.device)
@torch.no_grad()
def update_step(self, group, p, gindex, pindex):
@@ -464,29 +646,77 @@ class Optimizer1State(Optimizer8bit):
config = self.get_config(gindex, pindex, group)
- state['step'] += 1
- step = state['step']
+ state["step"] += 1
+ step = state["step"]
- if config['percentile_clipping'] < 100:
- current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping'])
+ if config["percentile_clipping"] < 100:
+ current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
+ grad, state["gnorm_vec"], step, config["percentile_clipping"]
+ )
else:
gnorm_scale = 1.0
- if state['state1'].dtype == torch.float:
- F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
- None, 0.0, config['weight_decay'], gnorm_scale,
- state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'],
- skip_zeros=config['skip_zeros'])
-
- elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
- F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
- config['eps'], step, config['lr'], state['qmap1'], None, state['max1'], None, state['new_max1'], None,
- config['weight_decay'], gnorm_scale,
- state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
-
- state['max1'], state['new_max1'] = state['new_max1'], state['max1']
- elif state['state1'].dtype == torch.uint8 and config['block_wise']:
- F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
- config['eps'], step, config['lr'],
- state['qmap1'], None, state['absmax1'], None,
- config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros'])
+ if state["state1"].dtype == torch.float:
+ F.optimizer_update_32bit(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ config["betas"][0],
+ config["eps"],
+ step,
+ config["lr"],
+ None,
+ 0.0,
+ config["weight_decay"],
+ gnorm_scale,
+ state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
+ max_unorm=config["max_unorm"],
+ skip_zeros=config["skip_zeros"],
+ )
+
+ elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
+ F.optimizer_update_8bit(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ None,
+ config["betas"][0],
+ config["betas"][1],
+ config["eps"],
+ step,
+ config["lr"],
+ state["qmap1"],
+ None,
+ state["max1"],
+ None,
+ state["new_max1"],
+ None,
+ config["weight_decay"],
+ gnorm_scale,
+ state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
+ max_unorm=config["max_unorm"],
+ )
+
+ state["max1"], state["new_max1"] = state["new_max1"], state["max1"]
+ elif state["state1"].dtype == torch.uint8 and config["block_wise"]:
+ F.optimizer_update_8bit_blockwise(
+ self.optimizer_name,
+ grad,
+ p,
+ state["state1"],
+ None,
+ config["betas"][0],
+ config["betas"][1],
+ config["eps"],
+ step,
+ config["lr"],
+ state["qmap1"],
+ None,
+ state["absmax1"],
+ None,
+ config["weight_decay"],
+ gnorm_scale=gnorm_scale,
+ skip_zeros=config["skip_zeros"],
+ )
diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py
index 0f1ffaa..7ddb12c 100644
--- a/bitsandbytes/optim/rmsprop.py
+++ b/bitsandbytes/optim/rmsprop.py
@@ -1,36 +1,115 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State
+
class RMSprop(Optimizer1State):
- def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr=1e-2,
+ alpha=0.99,
+ eps=1e-8,
+ weight_decay=0,
+ momentum=0,
+ centered=False,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if alpha == 0:
- raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
+ raise NotImplementedError(
+ f"RMSprop with alpha==0.0 is not supported!"
+ )
if centered:
- raise NotImplementedError(f'Centered RMSprop is not supported!')
- super(RMSprop, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
- weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+ raise NotImplementedError(f"Centered RMSprop is not supported!")
+ super(RMSprop, self).__init__(
+ "rmsprop",
+ params,
+ lr,
+ (alpha, momentum),
+ eps,
+ weight_decay,
+ optim_bits,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class RMSprop8bit(Optimizer1State):
- def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr=1e-2,
+ alpha=0.99,
+ eps=1e-8,
+ weight_decay=0,
+ momentum=0,
+ centered=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if alpha == 0:
- raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
+ raise NotImplementedError(
+ f"RMSprop with alpha==0.0 is not supported!"
+ )
if centered:
- raise NotImplementedError(f'Centered RMSprop is not supported!')
- super(RMSprop8bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
- weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+ raise NotImplementedError(f"Centered RMSprop is not supported!")
+ super(RMSprop8bit, self).__init__(
+ "rmsprop",
+ params,
+ lr,
+ (alpha, momentum),
+ eps,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class RMSprop32bit(Optimizer1State):
- def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr=1e-2,
+ alpha=0.99,
+ eps=1e-8,
+ weight_decay=0,
+ momentum=0,
+ centered=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if alpha == 0:
- raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
+ raise NotImplementedError(
+ f"RMSprop with alpha==0.0 is not supported!"
+ )
if centered:
- raise NotImplementedError(f'Centered RMSprop is not supported!')
- super(RMSprop32bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
- weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+ raise NotImplementedError(f"Centered RMSprop is not supported!")
+ super(RMSprop32bit, self).__init__(
+ "rmsprop",
+ params,
+ lr,
+ (alpha, momentum),
+ eps,
+ weight_decay,
+ 32,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
diff --git a/bitsandbytes/optim/sgd.py b/bitsandbytes/optim/sgd.py
index 0529879..f7b8934 100644
--- a/bitsandbytes/optim/sgd.py
+++ b/bitsandbytes/optim/sgd.py
@@ -1,32 +1,99 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State
+
class SGD(Optimizer1State):
- def __init__(self, params, lr, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if momentum == 0:
- raise NotImplementedError(f'SGD without momentum is not supported!')
- super(SGD, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
- weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+ raise NotImplementedError(f"SGD without momentum is not supported!")
+ super(SGD, self).__init__(
+ "momentum",
+ params,
+ lr,
+ (momentum, dampening),
+ 0.0,
+ weight_decay,
+ optim_bits,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class SGD8bit(Optimizer1State):
- def __init__(self, params, lr, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if momentum == 0:
- raise NotImplementedError(f'SGD without momentum is not supported!')
- super(SGD8bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
- weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+ raise NotImplementedError(f"SGD without momentum is not supported!")
+ super(SGD8bit, self).__init__(
+ "momentum",
+ params,
+ lr,
+ (momentum, dampening),
+ 0.0,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class SGD32bit(Optimizer1State):
- def __init__(self, params, lr, momentum=0, dampening=0,
- weight_decay=0, nesterov=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+ def __init__(
+ self,
+ params,
+ lr,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
if momentum == 0:
- raise NotImplementedError(f'SGD without momentum is not supported!')
- super(SGD32bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
- weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+ raise NotImplementedError(f"SGD without momentum is not supported!")
+ super(SGD32bit, self).__init__(
+ "momentum",
+ params,
+ lr,
+ (momentum, dampening),
+ 0.0,
+ weight_decay,
+ 32,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )