Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Trainer] Fix distributed dataloader #8932

Merged
merged 4 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions paddlenlp/data/dist_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ def __len__(self):
return 0


class IterableDummyDataset(paddle.io.IterableDataset):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我在想,是不是可以 数据集那里,自己去构造 Fake的 dataset

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不太理解什么意思,现在这么写我感觉没啥问题?

def __iter__(self):
return None


class DistDataLoader(paddle.io.DataLoader):
"""
DistDataLoader is a wrapper of paddle.io.DataLoader.
Expand All @@ -56,11 +61,14 @@ def __init__(
timeout=0,
worker_init_fn=None,
persistent_workers=False,
eval=False,
**kwargs,
):

eval = kwargs.pop("eval", False)
is_iterable_dataset = kwargs.pop("is_iterable_dataset", False)

if dataset is None:
dataset = DummyDataset()
dataset = DummyDataset() if not is_iterable_dataset else IterableDummyDataset()
logger.info("rank has no data, use Dummpy dataset")

super().__init__(dataset=dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=num_workers)
Expand Down Expand Up @@ -200,7 +208,7 @@ def __next__(self):
try:
data = next(self._dataloader_iter)
data = nested_copy_place(data, place=paddle.framework._current_expected_place())
except:
pass
except Exception as e:
logger.debug(e)
data = self._broadcast_data(data)
return data
147 changes: 78 additions & 69 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1398,12 +1398,16 @@ def get_train_dataloader(self):
raise ValueError("We don't need train_dataset when should_load_dataset is False.")

train_dataset = self.train_dataset
if self.args.distributed_dataloader:
is_iterable_dataset = self._is_iterable_dataset_distributed(train_dataset)
else:
is_iterable_dataset = self._is_iterable_dataset(train_dataset)
if is_datasets_available() and train_dataset is not None and isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(train_dataset, description="training")
_DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader

if self._is_iterable_dataset(train_dataset):
if self.args.dataset_world_size > 1:
if is_iterable_dataset: # For iterable dataset
if self.args.dataset_world_size > 1 and train_dataset is not None:
train_dataset = IterableDatasetShard(
train_dataset,
batch_size=self.args.per_device_train_batch_size,
Expand All @@ -1412,24 +1416,28 @@ def get_train_dataloader(self):
process_index=self.args.dataset_rank,
)

if self.args.distributed_dataloader:
logger.info("Training using DistDataLoader.")
additional_configs = {"is_iterable_dataset": True}
else:
additional_configs = {}
return _DataLoader(
train_dataset,
batch_size=self.args.per_device_train_batch_size,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
**additional_configs,
)
else:
train_sampler = self._get_train_sampler()
if self.args.distributed_dataloader:
logger.info("Training using DistDataLoader.")
return _DataLoader(
train_dataset,
batch_sampler=train_sampler,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
)

train_sampler = self._get_train_sampler()

if self.args.distributed_dataloader:
logger.info("Training using DistDataLoader.")

return _DataLoader(
train_dataset,
batch_sampler=train_sampler,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
)

def _get_eval_sampler(self, eval_dataset: Dataset):
if eval_dataset is None or not has_length(eval_dataset):
Expand Down Expand Up @@ -1476,54 +1484,48 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
raise ValueError("We don't need eval_dataset when should_load_dataset is False.")

eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

if self.args.distributed_dataloader:
is_iterable_dataset = self._is_iterable_dataset_distributed(eval_dataset)
else:
is_iterable_dataset = self._is_iterable_dataset(eval_dataset)
if is_datasets_available() and eval_dataset is not None and isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
_DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader

if self._is_iterable_dataset(eval_dataset):
if self.args.dataset_world_size > 1:
if is_iterable_dataset:
if self.args.dataset_world_size > 1 and eval_dataset is not None:
eval_dataset = IterableDatasetShard(
eval_dataset,
batch_size=self.args.per_device_eval_batch_size,
drop_last=self.args.dataloader_drop_last,
num_processes=self.args.dataset_world_size,
process_index=self.args.dataset_rank,
)

if self.args.distributed_dataloader:
return DistDataLoader(
eval_dataset,
batch_size=self.args.per_device_eval_batch_size,
collate_fn=self.data_collator,
num_workers=0,
eval=True,
)
logger.info("Eval using DistDataLoader.")
additional_configs = {"eval": True, "is_iterable_dataset": True}
else:
return DataLoader(
eval_dataset,
batch_size=self.args.per_device_eval_batch_size,
collate_fn=self.data_collator,
num_workers=0,
)

eval_sampler = self._get_eval_sampler(eval_dataset)

if self.args.distributed_dataloader:
logger.info("Eval using DistDataLoader.")

return DistDataLoader(
additional_configs = {}
return _DataLoader(
eval_dataset,
batch_sampler=eval_sampler,
batch_size=self.args.per_device_eval_batch_size,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
eval=True,
num_workers=0,
**additional_configs,
)
else:
return DataLoader(
eval_sampler = self._get_eval_sampler(eval_dataset)
if self.args.distributed_dataloader:
logger.info("Eval using DistDataLoader.")
additional_configs = {"eval": True}
else:
additional_configs = {}
return _DataLoader(
eval_dataset,
batch_sampler=eval_sampler,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
**additional_configs,
)

def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
Expand All @@ -1542,11 +1544,16 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
if not self.args.should_load_dataset and test_dataset is not None:
raise ValueError("We don't need test_dataset when should_load_dataset is False.")

if self.args.distributed_dataloader:
is_iterable_dataset = self._is_iterable_dataset_distributed(test_dataset)
else:
is_iterable_dataset = self._is_iterable_dataset(test_dataset)
if is_datasets_available() and test_dataset is not None and isinstance(test_dataset, datasets.Dataset):
test_dataset = self._remove_unused_columns(test_dataset, description="test")
_DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader

if self._is_iterable_dataset(test_dataset):
if self.args.dataset_world_size > 1:
if is_iterable_dataset:
if self.args.dataset_world_size > 1 and test_dataset is not None:
test_dataset = IterableDatasetShard(
test_dataset,
batch_size=self.args.per_device_eval_batch_size,
Expand All @@ -1556,40 +1563,31 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
)

if self.args.distributed_dataloader:
return DistDataLoader(
test_dataset,
batch_size=self.args.per_device_eval_batch_size * self.world_size,
collate_fn=self.data_collator, # _get_collator_with_removed_columns
num_workers=self.args.dataloader_num_workers,
eval=True,
)
logger.info("Test using DistDataLoader.")
additional_config = {"eval": True, "is_iterable_dataset": True}
else:
return DataLoader(
test_dataset,
batch_size=self.args.per_device_eval_batch_size * self.world_size,
collate_fn=self.data_collator, # _get_collator_with_removed_columns
num_workers=self.args.dataloader_num_workers,
)

test_sampler = self._get_eval_sampler(test_dataset)

if self.args.distributed_dataloader:
logger.info("Test using DistDataLoader.")

# We use the same batch_size as for eval.
return DistDataLoader(
additional_config = {}
return _DataLoader(
test_dataset,
batch_sampler=test_sampler,
batch_size=self.args.per_device_eval_batch_size * self.world_size,
collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last,
eval=True,
num_workers=self.args.dataloader_num_workers,
**additional_config,
)
else:
return DataLoader(
test_sampler = self._get_eval_sampler(test_dataset)
if self.args.distributed_dataloader:
logger.info("Test using DistDataLoader.")
additional_config = {"eval": True}
else:
additional_config = {}
# We use the same batch_size as for eval.
return _DataLoader(
test_dataset,
batch_sampler=test_sampler,
collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last,
**additional_config,
)

def create_optimizer_and_scheduler(self, num_training_steps: int):
Expand Down Expand Up @@ -1694,6 +1692,8 @@ def _load_rng_state(self, checkpoint):

if self.args.use_hybrid_parallel:
if "hybrid_parallel_rng_state_tracker" in checkpoint_rng_state:
if self.args.tensor_parallel_degree <= 1:
checkpoint_rng_state["hybrid_parallel_rng_state_tracker"].pop("model_parallel_rng", None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里触发hang住的原因,请文字说明清楚

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这行不会触发hang住,只是修了bug。如果非tp但是rng_state里面有tp的种子,加载起来会报错。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在PR描述里解释了

fleet.meta_parallel.get_rng_state_tracker().set_states_tracker(
checkpoint_rng_state["hybrid_parallel_rng_state_tracker"]
)
Expand Down Expand Up @@ -3201,6 +3201,15 @@ def _get_collator_with_removed_columns(
def _is_iterable_dataset(self, dataset):
return isinstance(dataset, paddle.io.IterableDataset)

def _is_iterable_dataset_distributed(self, dataset):
# For distributed dataloaer.
is_iterable_dataset_tensor = paddle.to_tensor(self._is_iterable_dataset(dataset)).reshape([1])
if dist.get_world_size() > 1:
dist.all_reduce(is_iterable_dataset_tensor, op=dist.ReduceOp.MAX)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NPU不支持bool类型通信,需要兼容

if is_iterable_dataset_tensor.item() == 1:
return True
return False

def print_config(self, args=None, key=""):
"""
print config values
Expand Down
22 changes: 2 additions & 20 deletions tests/trainer/test_unified_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ def setUp(self):
self.need_allclose = True
self.rtol = 1e-7

self.run_pretrain_file = "llm/llama/run_pretrain.py"
self.run_pretrain_file = "llm/run_pretrain.py"

def runfirst(self, train_args):
train_args["unified_checkpoint"] = 0
Expand Down Expand Up @@ -701,7 +701,7 @@ def setUp(self):
self.need_allclose = True
self.rtol = 1e-7

self.run_pretrain_file = "llm/llama/run_pretrain.py"
self.run_pretrain_file = "llm/run_pretrain.py"
self.filelists = [
"config.json",
"master_weights-00001-of-00002.safetensors",
Expand Down Expand Up @@ -1132,24 +1132,6 @@ def rerun(self, train_args):
np.testing.assert_allclose(res[0], res[-1], rtol=self.rtol)


@pytest.mark.skipif(True, reason="Skip for None CE")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为啥删了?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

日后如果增加了ignore_merge_optimizer的选项,会和skip_save_model_weight产生冲突,所以删掉了。

class TestUnifiedCheckpointOnN1C8EnableAll(TestUnifiedCheckpointBase):
def setUp(self):
super().setUp()
for config_key in self.configs:
self.configs[config_key]["unified_checkpoint"] = 1
self.configs[config_key]["unified_checkpoint_config"] = "enable_all_options"

self.need_allclose = True
self.rtol = 1e-7

def runfirst(self, train_args):
self.run_n1c8(self.run_pretrain_file, **train_args)

def rerun(self, train_args):
self.run_n1c8(self.run_pretrain_file, **train_args)


@pytest.mark.skipif(True, reason="Skip for None CE")
class TestUnifiedCheckpointOnN1C8SaveLoadSpeed(TestUnifiedCheckpointFull):
def setUp(self):
Expand Down
Loading