Diffstat (limited to 'bitsandbytes/optim/adam.py')
-rw-r--r--  bitsandbytes/optim/adam.py  |  179
1 file changed, 128 insertions(+), 51 deletions(-)
diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
index ed1b9f0..5cfaa28 100644
--- a/bitsandbytes/optim/adam.py
+++ b/bitsandbytes/optim/adam.py
@@ -1,6 +1,6 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
@@ -8,29 +8,97 @@ import os
import torch
import torch.distributed as dist
-from bitsandbytes.optim.optimizer import Optimizer2State
+
import bitsandbytes.functional as F
+from bitsandbytes.optim.optimizer import Optimizer2State
+
class Adam(Optimizer2State):
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False, optim_bits=32, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
- super(Adam, self).__init__('adam', params, lr, betas, eps,
- weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
+ optim_bits=32,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
+ super(Adam, self).__init__(
+ "adam",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ optim_bits,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class Adam8bit(Optimizer2State):
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
- super(Adam8bit, self).__init__('adam', params, lr, betas, eps,
- weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
+ super(Adam8bit, self).__init__(
+ "adam",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ 8,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
+
class Adam32bit(Optimizer2State):
- def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0, amsgrad=False, args=None,
- min_8bit_size=4096, percentile_clipping=100, block_wise=True):
- super(Adam32bit, self).__init__('adam', params, lr, betas, eps,
- weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
+ args=None,
+ min_8bit_size=4096,
+ percentile_clipping=100,
+ block_wise=True,
+ ):
+ super(Adam32bit, self).__init__(
+ "adam",
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ 32,
+ args,
+ min_8bit_size,
+ percentile_clipping,
+ block_wise,
+ )
class AnalysisAdam(torch.optim.Optimizer):
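For context, a minimal usage sketch of the thin wrappers reformatted above. This is illustrative only: the toy model, learning rate, and batch are made up, and the 8-bit state path assumes a CUDA device.

    import torch
    import bitsandbytes as bnb

    model = torch.nn.Linear(4096, 4096).cuda()  # toy model, illustrative only

    # Adam8bit fixes optim_bits=8 and Adam32bit fixes optim_bits=32;
    # Adam exposes optim_bits directly, so these two are equivalent:
    opt = bnb.optim.Adam(model.parameters(), lr=1e-3, optim_bits=8)
    opt = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)

    loss = model(torch.randn(2, 4096, device="cuda")).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()
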
@@ -68,8 +136,8 @@ class AnalysisAdam(torch.optim.Optimizer):
eps=1e-8,
weight_decay=0,
amsgrad=False,
- bnb_analysis='dynamic-blockwise',
- savedir=None
+ bnb_analysis="dynamic-blockwise",
+ savedir=None,
):
defaults = dict(
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
@@ -124,9 +192,13 @@ class AnalysisAdam(torch.optim.Optimizer):
state["exp_avg"] = torch.zeros_like(p_data_fp32)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
- state['abserrors'] = torch.zeros((256, 256), device=p_data_fp32.device)
- state['relerrors'] = torch.zeros((256, 256), device=p_data_fp32.device)
- state['counts'] = torch.zeros((256, 256), device=p_data_fp32.device)
+ state["abserrors"] = torch.zeros(
+ (256, 256), device=p_data_fp32.device
+ )
+ state["relerrors"] = torch.zeros(
+ (256, 256), device=p_data_fp32.device
+ )
+ state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
@@ -143,9 +215,9 @@ class AnalysisAdam(torch.optim.Optimizer):
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
- e = state['abserrors']
- rele = state['relerrors']
- counts = state['counts']
+ e = state["abserrors"]
+ rele = state["relerrors"]
+ counts = state["counts"]
if group["weight_decay"] != 0:
p_data_fp32.add_(
@@ -156,77 +228,84 @@ class AnalysisAdam(torch.optim.Optimizer):
if amsgrad:
max_exp_avg_sq = state["max_exp_avg_sq"]
-
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
denom = exp_avg_sq.sqrt().add_(group["eps"])
- update_fp32 = exp_avg/denom
+ update_fp32 = exp_avg / denom
- if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000*1000:
+ if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000:
# embedding layer or too small
- p_data_fp32 += -step_size*update_fp32
+ p_data_fp32 += -step_size * update_fp32
else:
- if self.analysis == 'dynamic-blockwise':
+ if self.analysis == "dynamic-blockwise":
code1 = F.create_dynamic_map(signed=True).to(p.device)
code2 = F.create_dynamic_map(signed=False).to(p.device)
C1, S1 = F.quantize_blockwise(exp_avg, code=code1)
state1 = F.dequantize_blockwise(C1, S1)
C2, S2 = F.quantize_blockwise(exp_avg_sq, code=code2)
state2 = F.dequantize_blockwise(C2, S2)
- elif self.analysis == 'dynamic':
+ elif self.analysis == "dynamic":
code1 = F.create_dynamic_map(signed=True).to(p.device)
code2 = F.create_dynamic_map(signed=False).to(p.device)
C1, S1 = F.quantize(exp_avg, code=code1)
state1 = F.dequantize(C1, S1)
C2, S2 = F.quantize(exp_avg_sq, code=code2)
state2 = F.dequantize(C2, S2)
- elif self.analysis == 'linear':
+ elif self.analysis == "linear":
code1 = F.create_linear_map(signed=True).to(p.device)
code2 = F.create_linear_map(signed=False).to(p.device)
C1, S1 = F.quantize(exp_avg, code=code1)
state1 = F.dequantize(C1, S1)
C2, S2 = F.quantize(exp_avg_sq, code=code2)
state2 = F.dequantize(C2, S2)
- elif self.analysis == 'quantile':
+ elif self.analysis == "quantile":
code1 = F.estimate_quantiles(exp_avg)
code2 = F.estimate_quantiles(exp_avg_sq)
C1 = F.quantize_no_absmax(exp_avg, code=code1)
state1 = F.dequantize_no_absmax(C1, code1)
C2 = F.quantize_no_absmax(exp_avg_sq, code=code2)
state2 = F.dequantize_no_absmax(C2, code2)
- elif self.analysis == 'my-quantization-routine':
+ elif self.analysis == "my-quantization-routine":
pass
# 1. get code
# 2. quantize
# 3. dequantize
# Error will be calculated automatically!
else:
- raise ValueError(f'Invalid analysis value: {self.analysis}!')
+ raise ValueError(f"Invalid analysis value: {self.analysis}!")
denom = state2.sqrt().add_(group["eps"])
- update_8bit = state1/denom
+ update_8bit = state1 / denom
- abserr = torch.abs(update_8bit-update_fp32)
- relerr = abserr/torch.abs(update_fp32+1e-6)
+ abserr = torch.abs(update_8bit - update_fp32)
+ relerr = abserr / torch.abs(update_fp32 + 1e-6)
C1, C2 = C1.int(), C2.int()
F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr)
F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr)
- F.histogram_scatter_add_2d(counts, C1.int(), C2.int(), torch.ones_like(abserr))
-
- p_data_fp32 += -step_size*update_fp32
+ F.histogram_scatter_add_2d(
+ counts, C1.int(), C2.int(), torch.ones_like(abserr)
+ )
+ p_data_fp32 += -step_size * update_fp32
if not dist.is_initialized() or dist.get_rank() == 0:
- if self.savedir != '' and state['step'] % 100 == 0:
- if not os.path.exists(self.savedir): os.makedirs(self.savedir)
- shapestr = '_'.join([str(dim) for dim in p_data_fp32.shape])
- pathe = os.path.join(self.savedir, f'{p_id}_{shapestr}_abserr.pkl')
- pathrele = os.path.join(self.savedir, f'{p_id}_{shapestr}_relerr.pkl')
- pathcounts = os.path.join(self.savedir, f'{p_id}_{shapestr}_counts.pkl')
+ if self.savedir != "" and state["step"] % 100 == 0:
+ if not os.path.exists(self.savedir):
+ os.makedirs(self.savedir)
+ shapestr = "_".join([str(dim) for dim in p_data_fp32.shape])
+ pathe = os.path.join(
+ self.savedir, f"{p_id}_{shapestr}_abserr.pkl"
+ )
+ pathrele = os.path.join(
+ self.savedir, f"{p_id}_{shapestr}_relerr.pkl"
+ )
+ pathcounts = os.path.join(
+ self.savedir, f"{p_id}_{shapestr}_counts.pkl"
+ )
torch.save(e, pathe)
torch.save(rele, pathrele)
torch.save(counts, pathcounts)
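The analysis branch above round-trips both optimizer states through an 8-bit quantizer and compares the resulting Adam update against the fp32 update. A stand-alone sketch of the "dynamic-blockwise" round trip, using only the bitsandbytes.functional calls that appear in this diff (the input tensor is made up, and a CUDA device is assumed):

    import torch
    import bitsandbytes.functional as F

    exp_avg = torch.randn(16384, device="cuda")  # stand-in for the first-moment state

    # Dynamic signed 8-bit code, applied block-wise with a per-block absmax.
    code = F.create_dynamic_map(signed=True).to(exp_avg.device)
    C, S = F.quantize_blockwise(exp_avg, code=code)
    exp_avg_deq = F.dequantize_blockwise(C, S)

    # Reconstruction error of the round trip; AnalysisAdam measures the same
    # quantities on the updates (state / denom) and scatters them into the
    # 256x256 histograms keyed by the two 8-bit codes.
    abserr = (exp_avg_deq - exp_avg).abs()
    relerr = abserr / (exp_avg.abs() + 1e-6)
    print(abserr.mean().item(), relerr.mean().item())
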
@@ -234,6 +313,4 @@ class AnalysisAdam(torch.optim.Optimizer):
if p.data.dtype in {torch.float16, torch.bfloat16}:
p.data.copy_(p_data_fp32)
-
-
return loss
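
Finally, a hedged sketch of driving AnalysisAdam itself. The toy model, step count, and savedir path are invented for illustration; per the code above, histograms are only gathered for tensors with more than 8192 and at most 50M elements, and are written every 100 steps when savedir is set.

    import os
    import torch
    from bitsandbytes.optim.adam import AnalysisAdam

    model = torch.nn.Linear(512, 512).cuda()  # weight has 262144 elements, so it is analyzed
    opt = AnalysisAdam(model.parameters(), lr=1e-3,
                       bnb_analysis="dynamic-blockwise",
                       savedir="./adam_analysis")

    for step in range(200):
        loss = model(torch.randn(8, 512, device="cuda")).pow(2).mean()
        loss.backward()
        opt.step()
        opt.zero_grad()

    # One 256x256 tensor per parameter and error type (abserr, relerr, counts).
    for fname in sorted(os.listdir("./adam_analysis")):
        print(fname, torch.load(os.path.join("./adam_analysis", fname)).shape)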