summaryrefslogtreecommitdiff
path: root/bitsandbytes
diff options
context:
space:
mode:
Diffstat (limited to 'bitsandbytes')
-rw-r--r--bitsandbytes/nn/modules.py4
-rw-r--r--bitsandbytes/optim/adam.py11
2 files changed, 10 insertions, 5 deletions
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index bf0945c..ce2f3a4 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -15,8 +15,8 @@ from bitsandbytes.optim import GlobalOptimManager
class StableEmbedding(torch.nn.Embedding):
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
- sparse: bool = True, _weight: Optional[Tensor] = None) -> None:
- super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, False, _weight)
+ sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
+ super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight)
self.norm = torch.nn.LayerNorm(embedding_dim)
GlobalOptimManager.get_instance().register_parameters(self.weight)
GlobalOptimManager.get_instance().override_config(self.weight, 'optim_bits', 32)
diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
index 1e93a60..ed1b9f0 100644
--- a/bitsandbytes/optim/adam.py
+++ b/bitsandbytes/optim/adam.py
@@ -2,7 +2,12 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+
+import math
+import os
+
import torch
+import torch.distributed as dist
from bitsandbytes.optim.optimizer import Optimizer2State
import bitsandbytes.functional as F
@@ -219,9 +224,9 @@ class AnalysisAdam(torch.optim.Optimizer):
if self.savedir != '' and state['step'] % 100 == 0:
if not os.path.exists(self.savedir): os.makedirs(self.savedir)
shapestr = '_'.join([str(dim) for dim in p_data_fp32.shape])
- pathe = join(self.savedir, f'{p_id}_{shapestr}_abserr.pkl')
- pathrele = join(self.savedir, f'{p_id}_{shapestr}_relerr.pkl')
- pathcounts = join(self.savedir, f'{p_id}_{shapestr}_counts.pkl')
+ pathe = os.path.join(self.savedir, f'{p_id}_{shapestr}_abserr.pkl')
+ pathrele = os.path.join(self.savedir, f'{p_id}_{shapestr}_relerr.pkl')
+ pathcounts = os.path.join(self.savedir, f'{p_id}_{shapestr}_counts.pkl')
torch.save(e, pathe)
torch.save(rele, pathrele)
torch.save(counts, pathcounts)