summaryrefslogtreecommitdiff
path: root/bitsandbytes
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes')
-rw-r--r--bitsandbytes/optim/adam.py8
1 files changed, 7 insertions, 1 deletions
diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
index f00e5db..c1f455f 100644
--- a/bitsandbytes/optim/adam.py
+++ b/bitsandbytes/optim/adam.py
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer2State
+import bitsandbytes.functional as F
class Adam(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
@@ -28,7 +29,7 @@ class Adam32bit(Optimizer2State):
class AnalysisAdam(torch.optim.Optimizer):
- """Implements 8-bit Adam and performs error analysis.
+ """Adam that performs 8-bit vs 32-bit error analysis.
This implementation is modified from torch.optim.Adam based on:
`Fixed Weight Decay Regularization in Adam`
@@ -190,6 +191,11 @@ class AnalysisAdam(torch.optim.Optimizer):
state1 = F.dequantize_no_absmax(C1, code1)
C2 = F.quantize_no_absmax(exp_avg_sq, code=code2)
state2 = F.dequantize_no_absmax(C2, code2)
+ elif self.analysis == 'my-quantization-routine':
+ # 1. get code
+ # 2. quantize
+ # 3. dequantize
+ # Error will be calculated automatically!
else:
raise ValueError(f'Invalid analysis value: {self.analysis}!')