path: root/bitsandbytes/optim/optimizer.py
Diffstat (limited to 'bitsandbytes/optim/optimizer.py')
-rw-r--r--  bitsandbytes/optim/optimizer.py | 31
1 file changed, 28 insertions(+), 3 deletions(-)
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index cfbd72e..5a5bb1e 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -26,6 +26,7 @@ class GlobalOptimManager(object):
self.index2config = {}
self.optimizer = None
self.uses_config_override = False
+ self.module_weight_config_triple = []
@classmethod
def get_instance(cls):
@@ -77,12 +78,16 @@ class GlobalOptimManager(object):
if id(p) in self.pid2config:self.pid2config[id(p)].update(key_value_dict)
else: self.pid2config[id(p)] = key_value_dict
+ def register_module_override(self, module, param_name, config):
+ self.module_weight_config_triple.append((module, param_name, config))
+
+
class Optimizer8bit(torch.optim.Optimizer):
def __init__(self, params, defaults, optim_bits=32):
super(Optimizer8bit, self).__init__(params, defaults)
- self.checked_if_on_gpu = False
+ self.initialized = False
self.name2qmap = {}
self.mng = GlobalOptimManager.get_instance()
@@ -172,7 +177,6 @@ class Optimizer8bit(torch.optim.Optimizer):
self.__setstate__({'state': state, 'param_groups': param_groups})
def to_gpu(self):
- self.checked_if_on_gpu = True
for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group['params']):
if p in self.state:
@@ -181,6 +185,23 @@ class Optimizer8bit(torch.optim.Optimizer):
if isinstance(v, torch.Tensor):
self.state[p][k] = v.to(p.device)
+ def check_overrides(self):
+ for module, attr, config in self.mng.module_weight_config_triple:
+ pmodule = getattr(module, attr)
+ assert pmodule is not None
+ assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.nn.Parameter)
+ found = False
+ for gindex, group in enumerate(self.param_groups):
+ if found: break
+ for pindex, p in enumerate(group['params']):
+ if found: break
+ if id(p) == id(pmodule):
+ # found the matching parameter
+ # init override
+ self.mng.pid2config[id(p)] = config
+ self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)]
+ found = True
+
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
@@ -196,7 +217,11 @@ class Optimizer8bit(torch.optim.Optimizer):
overflows = []
- if not self.checked_if_on_gpu: self.to_gpu() # needed for fairseq pure fp16 training
+ if not self.initialized:
+ self.check_overrides()
+ self.to_gpu() # needed for fairseq pure fp16 training
+ self.initialized = True
+
for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group['params']):
if p.grad is None:
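
For context, a minimal usage sketch of the override hook added in this commit, assuming a toy model and the 'optim_bits' config key used elsewhere in the library (the model, layer choice, and hyperparameters below are illustrative, not part of this commit): an override is registered on a (module, parameter-name) pair before the first optimizer step; on the first call to step(), check_overrides() matches the registered tensor against the optimizer's param_groups and installs the per-parameter config, then to_gpu() runs and initialized is set.

import torch
import bitsandbytes as bnb

# Toy model; suppose we want the embedding's optimizer state kept in 32-bit
# while the rest of the model uses the 8-bit path.
model = torch.nn.Sequential(
    torch.nn.Embedding(1024, 64),
    torch.nn.Linear(64, 16),
).cuda()

# Register the override before the first step. The (module, 'weight', config)
# triple is stored in GlobalOptimManager.module_weight_config_triple.
mng = bnb.optim.GlobalOptimManager.get_instance()
mng.register_module_override(model[0], 'weight', {'optim_bits': 32})

opt = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)

# First call to step(): the optimizer runs check_overrides(), which matches
# id(model[0].weight) against its param_groups and records the config in
# pid2config/index2config, then calls to_gpu() and sets initialized = True.
x = torch.randint(0, 1024, (8,), device='cuda')
model(x).sum().backward()
opt.step()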