From 4e19348d7f151800afb080b83441e71747ea476d Mon Sep 17 00:00:00 2001
From: JieShenAI <2360467524@qq.com>
Date: Fri, 20 Jun 2025 09:20:43 +0800
Subject: [PATCH] Make this code run properly in non-distributed mode (such as
 Debug or single-machine training) to avoid the error: `ValueError: Default
 process group has not been initialized`.

---
 FlagEmbedding/abc/finetune/embedder/AbsDataset.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/FlagEmbedding/abc/finetune/embedder/AbsDataset.py b/FlagEmbedding/abc/finetune/embedder/AbsDataset.py
index a38a8a5b..f40949a1 100644
--- a/FlagEmbedding/abc/finetune/embedder/AbsDataset.py
+++ b/FlagEmbedding/abc/finetune/embedder/AbsDataset.py
@@ -8,7 +8,7 @@
 from dataclasses import dataclass
 from torch.utils.data import Dataset
 from transformers import (
-    PreTrainedTokenizer,
+    PreTrainedTokenizer, 
     DataCollatorWithPadding,
     TrainerCallback,
     TrainerState,
@@ -63,7 +63,8 @@ def _load_dataset(self, file_path: str):
         Returns:
             datasets.Dataset: Loaded HF dataset.
         """
-        if dist.get_rank() == 0:
+        get_rank_safe = lambda: dist.get_rank() if dist.is_initialized() else 0
+        if get_rank_safe() == 0:
             logger.info(f'loading data from {file_path} ...')
 
         temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=self.args.cache_path)
@@ -342,7 +343,8 @@ def _load_dataset(self, file_path: str):
         Returns:
             datasets.Dataset: The loaded dataset.
         """
-        if dist.get_rank() == 0:
+        get_rank_safe = lambda: dist.get_rank() if dist.is_initialized() else 0
+        if get_rank_safe() == 0:
             logger.info(f'loading data from {file_path} ...')
 
         temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=self.args.cache_path)
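For context, a minimal standalone sketch of the guarded rank check this patch introduces; the surrounding demo script is illustrative only and is not part of the change:

    import torch.distributed as dist

    # Fall back to rank 0 when no process group has been initialized,
    # e.g. in debug runs or plain single-process training.
    get_rank_safe = lambda: dist.get_rank() if dist.is_initialized() else 0

    if get_rank_safe() == 0:
        # In AbsDataset._load_dataset this is where rank 0 logs its progress.
        print('loading data ...')

dist.is_initialized() returns False whenever torch.distributed.init_process_group() has not been called, which is why the unguarded dist.get_rank() call raised `ValueError: Default process group has not been initialized` outside of a distributed launch.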