summaryrefslogtreecommitdiff
path: root/bitsandbytes/optim/adam.py
diff options
context:
space:
mode:
authorTim Dettmers <TimDettmers@users.noreply.github.com>2021-11-15 07:58:44 -0800
committerGitHub <noreply@github.com>2021-11-15 07:58:44 -0800
commit037022e878974b5dfeb354098a467a46618f9d85 (patch)
tree338d59d52520978111853dcc72b9f54500f02123 /bitsandbytes/optim/adam.py
parent22b2877c7f8277317a073ea7cf49231d33fe79fd (diff)
parent56f527484814e3a9cfeb1d615a772d4a746bb071 (diff)
Merge pull request #9 from ditschuk/fix_adam_imports
Add missing imports to adam
Diffstat (limited to 'bitsandbytes/optim/adam.py')
-rw-r--r--bitsandbytes/optim/adam.py11
1 files changed, 8 insertions, 3 deletions
diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
index eb951ee..f3e5e81 100644
--- a/bitsandbytes/optim/adam.py
+++ b/bitsandbytes/optim/adam.py
@@ -2,7 +2,12 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+
+import math
+import os
+
import torch
+import torch.distributed as dist
from bitsandbytes.optim.optimizer import Optimizer2State
import bitsandbytes.functional as F
@@ -220,9 +225,9 @@ class AnalysisAdam(torch.optim.Optimizer):
if self.savedir != '' and state['step'] % 100 == 0:
if not os.path.exists(self.savedir): os.makedirs(self.savedir)
shapestr = '_'.join([str(dim) for dim in p_data_fp32.shape])
- pathe = join(self.savedir, f'{p_id}_{shapestr}_abserr.pkl')
- pathrele = join(self.savedir, f'{p_id}_{shapestr}_relerr.pkl')
- pathcounts = join(self.savedir, f'{p_id}_{shapestr}_counts.pkl')
+ pathe = os.path.join(self.savedir, f'{p_id}_{shapestr}_abserr.pkl')
+ pathrele = os.path.join(self.savedir, f'{p_id}_{shapestr}_relerr.pkl')
+ pathcounts = os.path.join(self.savedir, f'{p_id}_{shapestr}_counts.pkl')
torch.save(e, pathe)
torch.save(rele, pathrele)
torch.save(counts, pathcounts)