diff options
Diffstat (limited to 'bitsandbytes/optim')
-rw-r--r-- | bitsandbytes/optim/adagrad.py | 12 | ||||
-rw-r--r-- | bitsandbytes/optim/adam.py | 27 | ||||
-rw-r--r-- | bitsandbytes/optim/lars.py | 20 | ||||
-rw-r--r-- | bitsandbytes/optim/optimizer.py | 77 | ||||
-rw-r--r-- | bitsandbytes/optim/rmsprop.py | 12 |
5 files changed, 110 insertions, 38 deletions
diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py index 43e3973..7e2f566 100644 --- a/bitsandbytes/optim/adagrad.py +++ b/bitsandbytes/optim/adagrad.py @@ -23,7 +23,9 @@ class Adagrad(Optimizer1State): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if initial_accumulator_value != 0.0: @@ -63,7 +65,9 @@ class Adagrad8bit(Optimizer1State): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if initial_accumulator_value != 0.0: @@ -104,7 +108,9 @@ class Adagrad32bit(Optimizer1State): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if initial_accumulator_value != 0.0: diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py index 5cfaa28..3634971 100644 --- a/bitsandbytes/optim/adam.py +++ b/bitsandbytes/optim/adam.py @@ -140,7 +140,11 @@ class AnalysisAdam(torch.optim.Optimizer): savedir=None, ): defaults = dict( - lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, ) super(AnalysisAdam, self).__init__(params, defaults) self.analysis = bnb_analysis @@ -198,7 +202,9 @@ class AnalysisAdam(torch.optim.Optimizer): state["relerrors"] = torch.zeros( (256, 256), device=p_data_fp32.device ) - state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device) + state["counts"] = torch.zeros( + (256, 256), device=p_data_fp32.device + ) if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. values state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32) @@ -214,7 +220,9 @@ class AnalysisAdam(torch.optim.Optimizer): beta1, beta2 = group["betas"] bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] - step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 + step_size = ( + group["lr"] * math.sqrt(bias_correction2) / bias_correction1 + ) e = state["abserrors"] rele = state["relerrors"] counts = state["counts"] @@ -235,7 +243,10 @@ class AnalysisAdam(torch.optim.Optimizer): denom = exp_avg_sq.sqrt().add_(group["eps"]) update_fp32 = exp_avg / denom - if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000: + if ( + p_data_fp32.numel() <= 8192 + or p_data_fp32.numel() > 50000 * 1000 + ): # embedding layer or too small p_data_fp32 += -step_size * update_fp32 else: @@ -274,7 +285,9 @@ class AnalysisAdam(torch.optim.Optimizer): # 3. dequantize # Error will be calculated automatically! else: - raise ValueError(f"Invalid analysis value: {self.analysis}!") + raise ValueError( + f"Invalid analysis value: {self.analysis}!" + ) denom = state2.sqrt().add_(group["eps"]) update_8bit = state1 / denom @@ -296,7 +309,9 @@ class AnalysisAdam(torch.optim.Optimizer): if self.savedir != "" and state["step"] % 100 == 0: if not os.path.exists(self.savedir): os.makedirs(self.savedir) - shapestr = "_".join([str(dim) for dim in p_data_fp32.shape]) + shapestr = "_".join( + [str(dim) for dim in p_data_fp32.shape] + ) pathe = os.path.join( self.savedir, f"{p_id}_{shapestr}_abserr.pkl" ) diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py index c6cf5c6..8a89fb0 100644 --- a/bitsandbytes/optim/lars.py +++ b/bitsandbytes/optim/lars.py @@ -24,7 +24,9 @@ class LARS(Optimizer1State): max_unorm=0.02, ): if momentum == 0: - raise NotImplementedError(f"LARS without momentum is not supported!") + raise NotImplementedError( + f"LARS without momentum is not supported!" + ) super(LARS, self).__init__( "lars", params, @@ -56,7 +58,9 @@ class LARS8bit(Optimizer1State): max_unorm=0.02, ): if momentum == 0: - raise NotImplementedError(f"LARS without momentum is not supported!") + raise NotImplementedError( + f"LARS without momentum is not supported!" + ) super(LARS8bit, self).__init__( "lars", params, @@ -88,7 +92,9 @@ class LARS32bit(Optimizer1State): max_unorm=0.02, ): if momentum == 0: - raise NotImplementedError(f"LARS without momentum is not supported!") + raise NotImplementedError( + f"LARS without momentum is not supported!" + ) super(LARS32bit, self).__init__( "lars", params, @@ -121,7 +127,9 @@ class PytorchLARS(Optimizer): if momentum < 0.0: raise ValueError("Invalid momentum value: {}".format(momentum)) if weight_decay < 0.0: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) defaults = dict( lr=lr, @@ -132,7 +140,9 @@ class PytorchLARS(Optimizer): max_unorm=max_unorm, ) if nesterov and (momentum <= 0 or dampening != 0): - raise ValueError("Nesterov momentum requires a momentum and zero dampening") + raise ValueError( + "Nesterov momentum requires a momentum and zero dampening" + ) super(PytorchLARS, self).__init__(params, defaults) def __setstate__(self, state): diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index b942e34..4fb30cd 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -46,9 +46,13 @@ class GlobalOptimManager(object): for group_index, group in enumerate(param_groups): for p_index, p in enumerate(group["params"]): if id(p) in self.pid2config: - self.index2config[(group_index, p_index)] = self.pid2config[id(p)] + self.index2config[(group_index, p_index)] = self.pid2config[ + id(p) + ] - def override_config(self, parameters, key=None, value=None, key_value_dict=None): + def override_config( + self, parameters, key=None, value=None, key_value_dict=None + ): """ Overrides initial optimizer config for specific parameters. @@ -136,7 +140,8 @@ class Optimizer8bit(torch.optim.Optimizer): if len(groups) != len(saved_groups): raise ValueError( - "loaded state dict has a different number of " "parameter groups" + "loaded state dict has a different number of " + "parameter groups" ) param_lens = (len(g["params"]) for g in groups) saved_lens = (len(g["params"]) for g in saved_groups) @@ -192,7 +197,9 @@ class Optimizer8bit(torch.optim.Optimizer): new_group["params"] = group["params"] return new_group - param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] + param_groups = [ + update_group(g, ng) for g, ng in zip(groups, saved_groups) + ] self.__setstate__({"state": state, "param_groups": param_groups}) def to_gpu(self): @@ -222,9 +229,9 @@ class Optimizer8bit(torch.optim.Optimizer): # found the matching parameter # init override self.mng.pid2config[id(p)] = config - self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[ - id(p) - ] + self.mng.index2config[ + (gindex, pindex) + ] = self.mng.pid2config[id(p)] found = True @torch.no_grad() @@ -280,7 +287,9 @@ class Optimizer8bit(torch.optim.Optimizer): raise NotImplementedError(f"init_state method needs to be overidden") def update_step(self, group, p, gindex, pindex): - raise NotImplementedError(f"The update_step method needs to be overidden") + raise NotImplementedError( + f"The update_step method needs to be overidden" + ) class Optimizer2State(Optimizer8bit): @@ -310,9 +319,13 @@ class Optimizer2State(Optimizer8bit): betas = [float(b) for b in betas] for i in range(len(betas)): if not 0.0 <= betas[i] < 1.0: - raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") + raise ValueError( + f"Invalid beta parameter at index {i}: {betas[i]}" + ) if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super(Optimizer2State, self).__init__(params, defaults, optim_bits) @@ -351,7 +364,9 @@ class Optimizer2State(Optimizer8bit): state = self.state[p] state["step"] = 0 - if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): + if dtype == torch.float32 or ( + dtype == torch.uint8 and p.numel() < 4096 + ): state["state1"] = torch.zeros_like( p, memory_format=torch.preserve_format, @@ -368,8 +383,12 @@ class Optimizer2State(Optimizer8bit): if state["step"] == 0: if "dynamic" not in self.name2qmap: self.fill_qmap() - self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device) - self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device) + self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to( + p.device + ) + self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to( + p.device + ) state["state1"] = torch.zeros_like( p, @@ -399,11 +418,15 @@ class Optimizer2State(Optimizer8bit): (blocks,), dtype=torch.float32, device=p.device ) else: - state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["max1"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) state["new_max1"] = torch.zeros( (1,), dtype=torch.float32, device=p.device ) - state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["max2"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) state["new_max2"] = torch.zeros( (1,), dtype=torch.float32, device=p.device ) @@ -470,7 +493,9 @@ class Optimizer2State(Optimizer8bit): state["new_max2"], config["weight_decay"], gnorm_scale=gnorm_scale, - unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None, + unorm_vec=state["unorm_vec"] + if config["max_unorm"] > 0.0 + else None, max_unorm=config["max_unorm"], ) @@ -522,9 +547,13 @@ class Optimizer1State(Optimizer8bit): raise ValueError("Invalid epsilon value: {}".format(eps)) for i in range(len(betas)): if not 0.0 <= betas[i] < 1.0: - raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") + raise ValueError( + f"Invalid beta parameter at index {i}: {betas[i]}" + ) if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super(Optimizer1State, self).__init__(params, defaults, optim_bits) @@ -563,7 +592,9 @@ class Optimizer1State(Optimizer8bit): state = self.state[p] state["step"] = 0 - if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): + if dtype == torch.float32 or ( + dtype == torch.uint8 and p.numel() < 4096 + ): state["state1"] = torch.zeros_like( p, memory_format=torch.preserve_format, @@ -574,7 +605,9 @@ class Optimizer1State(Optimizer8bit): if state["step"] == 0: if "dynamic" not in self.name2qmap: self.fill_qmap() - self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device) + self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to( + p.device + ) state["state1"] = torch.zeros_like( p, @@ -593,7 +626,9 @@ class Optimizer1State(Optimizer8bit): (blocks,), dtype=torch.float32, device=p.device ) else: - state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["max1"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) state["new_max1"] = torch.zeros( (1,), dtype=torch.float32, device=p.device ) diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py index 679f783..7ddb12c 100644 --- a/bitsandbytes/optim/rmsprop.py +++ b/bitsandbytes/optim/rmsprop.py @@ -22,7 +22,9 @@ class RMSprop(Optimizer1State): block_wise=True, ): if alpha == 0: - raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!") + raise NotImplementedError( + f"RMSprop with alpha==0.0 is not supported!" + ) if centered: raise NotImplementedError(f"Centered RMSprop is not supported!") super(RMSprop, self).__init__( @@ -56,7 +58,9 @@ class RMSprop8bit(Optimizer1State): block_wise=True, ): if alpha == 0: - raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!") + raise NotImplementedError( + f"RMSprop with alpha==0.0 is not supported!" + ) if centered: raise NotImplementedError(f"Centered RMSprop is not supported!") super(RMSprop8bit, self).__init__( @@ -91,7 +95,9 @@ class RMSprop32bit(Optimizer1State): ): if alpha == 0: - raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!") + raise NotImplementedError( + f"RMSprop with alpha==0.0 is not supported!" + ) if centered: raise NotImplementedError(f"Centered RMSprop is not supported!") super(RMSprop32bit, self).__init__( |