[Trainer] Fix distributed dataloader #8932
@@ -1398,12 +1398,16 @@ def get_train_dataloader(self):
raise ValueError("We don't need train_dataset when should_load_dataset is False.")

train_dataset = self.train_dataset
if self.args.distributed_dataloader:
is_iterable_dataset = self._is_iterable_dataset_distributed(train_dataset)
else:
is_iterable_dataset = self._is_iterable_dataset(train_dataset)
if is_datasets_available() and train_dataset is not None and isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(train_dataset, description="training")
_DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader

if self._is_iterable_dataset(train_dataset):
if self.args.dataset_world_size > 1:
if is_iterable_dataset:  # For iterable dataset
if self.args.dataset_world_size > 1 and train_dataset is not None:
train_dataset = IterableDatasetShard(
train_dataset,
batch_size=self.args.per_device_train_batch_size,
@@ -1412,24 +1416,28 @@ def get_train_dataloader(self):
process_index=self.args.dataset_rank,
)

if self.args.distributed_dataloader:
logger.info("Training using DistDataLoader.")
additional_configs = {"is_iterable_dataset": True}
else:
additional_configs = {}
return _DataLoader(
train_dataset,
batch_size=self.args.per_device_train_batch_size,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
**additional_configs,
)
else:
train_sampler = self._get_train_sampler()
if self.args.distributed_dataloader:
logger.info("Training using DistDataLoader.")
return _DataLoader(
train_dataset,
batch_sampler=train_sampler,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
)

train_sampler = self._get_train_sampler()

if self.args.distributed_dataloader:
logger.info("Training using DistDataLoader.")

return _DataLoader(
train_dataset,
batch_sampler=train_sampler,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
)
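The change above converges on one construction path for get_train_dataloader (the eval and test loaders below follow the same shape): `_DataLoader` is chosen up front (DistDataLoader when `distributed_dataloader` is on, plain DataLoader otherwise) and the DistDataLoader-only keyword arguments are funneled through `additional_configs`. A rough sketch of that pattern, with `args`, `dataset`, `collator`, and `train_sampler` as stand-in names rather than the exact Trainer code:

```python
from paddle.io import DataLoader  # DistDataLoader is the trainer's own wrapper class

# Sketch of the selection pattern above (assumed names: args, dataset, collator,
# train_sampler; DistDataLoader accepts the extra is_iterable_dataset / eval
# keywords shown in the diff).
loader_cls = DistDataLoader if args.distributed_dataloader else DataLoader

if is_iterable_dataset:
    # Iterable datasets take batch_size directly (no batch sampler); only
    # DistDataLoader receives the extra flag.
    extra = {"is_iterable_dataset": True} if args.distributed_dataloader else {}
    loader = loader_cls(
        dataset,
        batch_size=args.per_device_train_batch_size,
        collate_fn=collator,
        num_workers=args.dataloader_num_workers,
        **extra,
    )
else:
    # Map-style datasets keep going through the distributed batch sampler.
    loader = loader_cls(
        dataset,
        batch_sampler=train_sampler,
        collate_fn=collator,
        num_workers=args.dataloader_num_workers,
    )
```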
def _get_eval_sampler(self, eval_dataset: Dataset):
if eval_dataset is None or not has_length(eval_dataset):
@@ -1476,54 +1484,48 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
raise ValueError("We don't need eval_dataset when should_load_dataset is False.")

eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

if self.args.distributed_dataloader:
is_iterable_dataset = self._is_iterable_dataset_distributed(eval_dataset)
else:
is_iterable_dataset = self._is_iterable_dataset(eval_dataset)
if is_datasets_available() and eval_dataset is not None and isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
_DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader

if self._is_iterable_dataset(eval_dataset):
if self.args.dataset_world_size > 1:
if is_iterable_dataset:
if self.args.dataset_world_size > 1 and eval_dataset is not None:
eval_dataset = IterableDatasetShard(
eval_dataset,
batch_size=self.args.per_device_eval_batch_size,
drop_last=self.args.dataloader_drop_last,
num_processes=self.args.dataset_world_size,
process_index=self.args.dataset_rank,
)

if self.args.distributed_dataloader:
return DistDataLoader(
eval_dataset,
batch_size=self.args.per_device_eval_batch_size,
collate_fn=self.data_collator,
num_workers=0,
eval=True,
)
logger.info("Eval using DistDataLoader.")
additional_configs = {"eval": True, "is_iterable_dataset": True}
else:
return DataLoader(
eval_dataset,
batch_size=self.args.per_device_eval_batch_size,
collate_fn=self.data_collator,
num_workers=0,
)

eval_sampler = self._get_eval_sampler(eval_dataset)

if self.args.distributed_dataloader:
logger.info("Eval using DistDataLoader.")

return DistDataLoader(
additional_configs = {}
return _DataLoader(
eval_dataset,
batch_sampler=eval_sampler,
batch_size=self.args.per_device_eval_batch_size,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
eval=True,
num_workers=0,
**additional_configs,
)
else:
return DataLoader(
eval_sampler = self._get_eval_sampler(eval_dataset)
if self.args.distributed_dataloader:
logger.info("Eval using DistDataLoader.")
additional_configs = {"eval": True}
else:
additional_configs = {}
return _DataLoader(
eval_dataset,
batch_sampler=eval_sampler,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
**additional_configs,
)

def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
@@ -1542,11 +1544,16 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
if not self.args.should_load_dataset and test_dataset is not None:
raise ValueError("We don't need test_dataset when should_load_dataset is False.")

if self.args.distributed_dataloader:
is_iterable_dataset = self._is_iterable_dataset_distributed(test_dataset)
else:
is_iterable_dataset = self._is_iterable_dataset(test_dataset)
if is_datasets_available() and test_dataset is not None and isinstance(test_dataset, datasets.Dataset):
test_dataset = self._remove_unused_columns(test_dataset, description="test")
_DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader

if self._is_iterable_dataset(test_dataset):
if self.args.dataset_world_size > 1:
if is_iterable_dataset:
if self.args.dataset_world_size > 1 and test_dataset is not None:
test_dataset = IterableDatasetShard(
test_dataset,
batch_size=self.args.per_device_eval_batch_size,

@@ -1556,40 +1563,31 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
)

if self.args.distributed_dataloader:
return DistDataLoader(
test_dataset,
batch_size=self.args.per_device_eval_batch_size * self.world_size,
collate_fn=self.data_collator,  # _get_collator_with_removed_columns
num_workers=self.args.dataloader_num_workers,
eval=True,
)
logger.info("Test using DistDataLoader.")
additional_config = {"eval": True, "is_iterable_dataset": True}
else:
return DataLoader(
test_dataset,
batch_size=self.args.per_device_eval_batch_size * self.world_size,
collate_fn=self.data_collator,  # _get_collator_with_removed_columns
num_workers=self.args.dataloader_num_workers,
)

test_sampler = self._get_eval_sampler(test_dataset)

if self.args.distributed_dataloader:
logger.info("Test using DistDataLoader.")

# We use the same batch_size as for eval.
return DistDataLoader(
additional_config = {}
return _DataLoader(
test_dataset,
batch_sampler=test_sampler,
batch_size=self.args.per_device_eval_batch_size * self.world_size,
collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last,
eval=True,
num_workers=self.args.dataloader_num_workers,
**additional_config,
)
else:
return DataLoader(
test_sampler = self._get_eval_sampler(test_dataset)
if self.args.distributed_dataloader:
logger.info("Test using DistDataLoader.")
additional_config = {"eval": True}
else:
additional_config = {}
# We use the same batch_size as for eval.
return _DataLoader(
test_dataset,
batch_sampler=test_sampler,
collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last,
**additional_config,
)

def create_optimizer_and_scheduler(self, num_training_steps: int):
@@ -1694,6 +1692,8 @@ def _load_rng_state(self, checkpoint):

if self.args.use_hybrid_parallel:
if "hybrid_parallel_rng_state_tracker" in checkpoint_rng_state:
if self.args.tensor_parallel_degree <= 1:
checkpoint_rng_state["hybrid_parallel_rng_state_tracker"].pop("model_parallel_rng", None)
fleet.meta_parallel.get_rng_state_tracker().set_states_tracker(
checkpoint_rng_state["hybrid_parallel_rng_state_tracker"]
)

Review thread on the lines above:
- "Please explain clearly, in writing, what causes the hang here."
- "+1"
- "This line does not trigger the hang; it only fixes a bug. If tensor parallelism is off but the rng_state still contains the TP seed, loading it raises an error."
- "Explained in the PR description."
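A condensed restatement of the fix discussed in that thread, as a sketch rather than a drop-in replacement (all names are taken from the diff):

```python
# Sketch: a checkpoint saved with tensor parallelism enabled may carry a
# "model_parallel_rng" entry in its RNG tracker state. On a run with
# tensor_parallel_degree <= 1 that entry has no matching tracker and, per the
# author's comment, restoring it raises an error, so it is dropped first.
tracker_state = checkpoint_rng_state["hybrid_parallel_rng_state_tracker"]
if self.args.tensor_parallel_degree <= 1:
    tracker_state.pop("model_parallel_rng", None)  # no-op if the key is absent
fleet.meta_parallel.get_rng_state_tracker().set_states_tracker(tracker_state)
```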
@@ -3201,6 +3201,15 @@ def _get_collator_with_removed_columns(
def _is_iterable_dataset(self, dataset):
return isinstance(dataset, paddle.io.IterableDataset)

def _is_iterable_dataset_distributed(self, dataset):
# For distributed dataloader.
is_iterable_dataset_tensor = paddle.to_tensor(self._is_iterable_dataset(dataset)).reshape([1])
if dist.get_world_size() > 1:
dist.all_reduce(is_iterable_dataset_tensor, op=dist.ReduceOp.MAX)

if is_iterable_dataset_tensor.item() == 1:
return True
return False

Review comment on the all_reduce line: "NPU does not support communicating bool tensors; this needs a compatibility path."
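One way to address that review comment, sketched here as an assumption rather than what the PR actually ships: materialize the local flag as an integer tensor before the collective, so no bool tensor ever crosses the communication backend.

```python
import paddle
import paddle.distributed as dist

def _is_iterable_dataset_distributed(self, dataset):
    # Same consensus check as above, but the local bool is cast to int32 so
    # backends that cannot all_reduce bool tensors (e.g. NPU, per the review
    # comment) still work. MAX means: if any rank holds an iterable dataset
    # (ranks whose dataset is None report 0), every rank reports True and
    # therefore takes the same dataloader branch.
    local_flag = int(self._is_iterable_dataset(dataset))
    flag_tensor = paddle.to_tensor([local_flag], dtype="int32")
    if dist.get_world_size() > 1:
        dist.all_reduce(flag_tensor, op=dist.ReduceOp.MAX)
    return bool(flag_tensor.item())
```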
def print_config(self, args=None, key=""):
"""
print config values
@@ -659,7 +659,7 @@ def setUp(self):
self.need_allclose = True
self.rtol = 1e-7

self.run_pretrain_file = "llm/llama/run_pretrain.py"
self.run_pretrain_file = "llm/run_pretrain.py"

def runfirst(self, train_args):
train_args["unified_checkpoint"] = 0

@@ -701,7 +701,7 @@ def setUp(self):
self.need_allclose = True
self.rtol = 1e-7

self.run_pretrain_file = "llm/llama/run_pretrain.py"
self.run_pretrain_file = "llm/run_pretrain.py"
self.filelists = [
"config.json",
"master_weights-00001-of-00002.safetensors",

@@ -1132,24 +1132,6 @@ def rerun(self, train_args):
np.testing.assert_allclose(res[0], res[-1], rtol=self.rtol)
@pytest.mark.skipif(True, reason="Skip for None CE")
class TestUnifiedCheckpointOnN1C8EnableAll(TestUnifiedCheckpointBase):
def setUp(self):
super().setUp()
for config_key in self.configs:
self.configs[config_key]["unified_checkpoint"] = 1
self.configs[config_key]["unified_checkpoint_config"] = "enable_all_options"

self.need_allclose = True
self.rtol = 1e-7

def runfirst(self, train_args):
self.run_n1c8(self.run_pretrain_file, **train_args)

def rerun(self, train_args):
self.run_n1c8(self.run_pretrain_file, **train_args)

Review thread on the removed TestUnifiedCheckpointOnN1C8EnableAll class:
- "Why was this removed?"
- "If an ignore_merge_optimizer option is added in the future, it would conflict with skip_save_model_weight, so the class was removed."

@pytest.mark.skipif(True, reason="Skip for None CE")
class TestUnifiedCheckpointOnN1C8SaveLoadSpeed(TestUnifiedCheckpointFull):
def setUp(self):
Review discussion:
- "I'm wondering whether, on the dataset side, we could construct a fake dataset ourselves."
- "I don't quite understand what you mean; the way it is written now seems fine to me."
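For reference, the suggestion could look something like the sketch below (illustrative only; the class name, vocabulary size, and field names are assumptions, not part of this PR):

```python
import numpy as np
import paddle

class FakeCausalLMDataset(paddle.io.Dataset):
    """Synthetic pretraining-style dataset: random token ids of fixed length."""

    def __init__(self, num_samples=128, seq_len=512, vocab_size=32000, seed=42):
        rng = np.random.default_rng(seed)
        self.samples = [
            rng.integers(0, vocab_size, size=(seq_len,), dtype="int64")
            for _ in range(num_samples)
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        tokens = self.samples[idx]
        # Causal-LM convention: labels mirror the inputs; the model shifts them.
        return {"input_ids": tokens, "labels": tokens}
```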