From 56f527484814e3a9cfeb1d615a772d4a746bb071 Mon Sep 17 00:00:00 2001
From: Konstantin Ditschuneit
Date: Sun, 31 Oct 2021 16:38:38 +0100
Subject: Add missing imports to adam

---
 bitsandbytes/optim/adam.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

(limited to 'bitsandbytes')

diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
index eb951ee..f3e5e81 100644
--- a/bitsandbytes/optim/adam.py
+++ b/bitsandbytes/optim/adam.py
@@ -2,7 +2,12 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+
+import math
+import os
+
 import torch
+import torch.distributed as dist
 
 from bitsandbytes.optim.optimizer import Optimizer2State
 import bitsandbytes.functional as F
@@ -220,9 +225,9 @@ class AnalysisAdam(torch.optim.Optimizer):
                 if self.savedir != '' and state['step'] % 100 == 0:
                     if not os.path.exists(self.savedir): os.makedirs(self.savedir)
                     shapestr = '_'.join([str(dim) for dim in p_data_fp32.shape])
-                    pathe = join(self.savedir, f'{p_id}_{shapestr}_abserr.pkl')
-                    pathrele = join(self.savedir, f'{p_id}_{shapestr}_relerr.pkl')
-                    pathcounts = join(self.savedir, f'{p_id}_{shapestr}_counts.pkl')
+                    pathe = os.path.join(self.savedir, f'{p_id}_{shapestr}_abserr.pkl')
+                    pathrele = os.path.join(self.savedir, f'{p_id}_{shapestr}_relerr.pkl')
+                    pathcounts = os.path.join(self.savedir, f'{p_id}_{shapestr}_counts.pkl')
                     torch.save(e, pathe)
                     torch.save(rele, pathrele)
                     torch.save(counts, pathcounts)
--
cgit v1.2.3


From 67a1283501fa24d346f8e8efb4fc888a9ed8d193 Mon Sep 17 00:00:00 2001
From: Robin Schmidt
Date: Mon, 15 Nov 2021 17:27:02 +0100
Subject: [FIX] passing of sparse in StableEmbedding

---
 bitsandbytes/nn/modules.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

(limited to 'bitsandbytes')

diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index bf0945c..ce2f3a4 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -15,8 +15,8 @@ from bitsandbytes.optim import GlobalOptimManager
 class StableEmbedding(torch.nn.Embedding):
     def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                  max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
-                 sparse: bool = True, _weight: Optional[Tensor] = None) -> None:
-        super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, False, _weight)
+                 sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
+        super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight)
         self.norm = torch.nn.LayerNorm(embedding_dim)
         GlobalOptimManager.get_instance().register_parameters(self.weight)
         GlobalOptimManager.get_instance().override_config(self.weight, 'optim_bits', 32)
--
cgit v1.2.3
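Context for the first patch: AnalysisAdam's periodic save branch called a bare join() (plus os.path.exists and os.makedirs) even though adam.py never imported os, so reaching that branch raised a NameError. Below is a minimal, self-contained sketch of the corrected path handling; the directory, parameter id, shape, and tensor are hypothetical stand-ins for values the optimizer computes at runtime:

import os

import torch

savedir = "/tmp/adam_analysis"    # hypothetical save directory
p_id = "0"                        # hypothetical parameter id
shape = (1024, 768)               # stands in for p_data_fp32.shape
abserr = torch.zeros(256, 256)    # stands in for the error tensor e

if savedir != "":
    if not os.path.exists(savedir):
        os.makedirs(savedir)
    shapestr = "_".join(str(dim) for dim in shape)
    # The fix qualifies the call as os.path.join; the old code's bare
    # join() was never imported and failed at runtime.
    pathe = os.path.join(savedir, f"{p_id}_{shapestr}_abserr.pkl")
    torch.save(abserr, pathe)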
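Context for the second patch: StableEmbedding advertised sparse: bool = True in its signature but hard-coded False in the super().__init__ call, so whatever the caller passed was silently dropped. The fix forwards the argument and flips the default to False, matching torch.nn.Embedding's own default. A minimal usage sketch, assuming a bitsandbytes build that exposes StableEmbedding under bitsandbytes.nn as this file does:

import torch
from bitsandbytes.nn import StableEmbedding

# After the fix, the sparse flag actually reaches torch.nn.Embedding
# instead of being overwritten with False.
emb = StableEmbedding(num_embeddings=100, embedding_dim=16, sparse=False)

ids = torch.tensor([[1, 2, 3]])
out = emb(ids)    # embedding lookup followed by the extra LayerNorm
print(out.shape)  # torch.Size([1, 3, 16])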