diff options
author | Tim Dettmers <TimDettmers@users.noreply.github.com> | 2021-11-15 07:58:44 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-11-15 07:58:44 -0800 |
commit | 037022e878974b5dfeb354098a467a46618f9d85 (patch) | |
tree | 338d59d52520978111853dcc72b9f54500f02123 /bitsandbytes/optim/adam.py | |
parent | 22b2877c7f8277317a073ea7cf49231d33fe79fd (diff) | |
parent | 56f527484814e3a9cfeb1d615a772d4a746bb071 (diff) |
Merge pull request #9 from ditschuk/fix_adam_imports
Add missing imports to adam
Diffstat (limited to 'bitsandbytes/optim/adam.py')
-rw-r--r-- | bitsandbytes/optim/adam.py | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py index eb951ee..f3e5e81 100644 --- a/bitsandbytes/optim/adam.py +++ b/bitsandbytes/optim/adam.py @@ -2,7 +2,12 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + +import math +import os + import torch +import torch.distributed as dist from bitsandbytes.optim.optimizer import Optimizer2State import bitsandbytes.functional as F @@ -220,9 +225,9 @@ class AnalysisAdam(torch.optim.Optimizer): if self.savedir != '' and state['step'] % 100 == 0: if not os.path.exists(self.savedir): os.makedirs(self.savedir) shapestr = '_'.join([str(dim) for dim in p_data_fp32.shape]) - pathe = join(self.savedir, f'{p_id}_{shapestr}_abserr.pkl') - pathrele = join(self.savedir, f'{p_id}_{shapestr}_relerr.pkl') - pathcounts = join(self.savedir, f'{p_id}_{shapestr}_counts.pkl') + pathe = os.path.join(self.savedir, f'{p_id}_{shapestr}_abserr.pkl') + pathrele = os.path.join(self.savedir, f'{p_id}_{shapestr}_relerr.pkl') + pathcounts = os.path.join(self.savedir, f'{p_id}_{shapestr}_counts.pkl') torch.save(e, pathe) torch.save(rele, pathrele) torch.save(counts, pathcounts) |