[Trainer] Fix distributed dataloader #8932
Conversation
Thanks for your contribution!
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##           develop    #8932      +/-   ##
===========================================
- Coverage    55.05%   55.04%    -0.02%
===========================================
  Files          635      635
  Lines        99412    99449      +37
===========================================
+ Hits         54730    54739       +9
- Misses       44682    44710      +28
===========================================

☔ View full report in Codecov by Sentry.
paddlenlp/trainer/trainer.py (outdated)
    train_dataset,
    batch_size=self.args.per_device_train_batch_size,
    collate_fn=self.data_collator,
    num_workers=self.args.dataloader_num_workers,
Does the same problem occur with the plain DataLoader?
No. See the PR description for the cause of the hang.
paddlenlp/trainer/trainer.py (outdated)
    batch_size=self.args.per_device_train_batch_size,
    collate_fn=self.data_collator,
    num_workers=self.args.dataloader_num_workers,
)

train_sampler = self._get_train_sampler()
The logic above handles the is_iterable_dataset case, so the code below is the non-iterable-dataset path?
@@ -1694,6 +1726,8 @@ def _load_rng_state(self, checkpoint):

        if self.args.use_hybrid_parallel:
            if "hybrid_parallel_rng_state_tracker" in checkpoint_rng_state:
                if self.args.tensor_parallel_degree <= 1:
                    checkpoint_rng_state["hybrid_parallel_rng_state_tracker"].pop("model_parallel_rng", None)
Please explain clearly, in words, why this triggers the hang.
+1
This line does not trigger the hang; it just fixes a bug: if tensor parallelism is not used but the rng_state still contains the TP seed, loading it raises an error.
Explained in the PR description.
paddlenlp/trainer/trainer.py (outdated)
@@ -1398,12 +1398,15 @@ def get_train_dataloader(self):
            raise ValueError("We don't need train_dataset when should_load_dataset is False.")

        train_dataset = self.train_dataset
        if self.args.distributed_dataloader:
            is_iterable_dataset = self._is_iterable_dataset_dd(train_dataset)
Suggested change:
- is_iterable_dataset = self._is_iterable_dataset_dd(train_dataset)
+ is_iterable_dataset = self._is_iterable_dataset_distributed(train_dataset)
paddlenlp/trainer/trainer.py (outdated)
    batch_size=self.args.per_device_train_batch_size,
    collate_fn=self.data_collator,
    num_workers=self.args.dataloader_num_workers,
    is_iterable_dataset=True,
You could build an additional_args = {} dict and pass it with **additional_args, keeping the DistDataLoader and DataLoader construction merged.
That doesn't really work, because Paddle's DataLoader does not accept a variable number of input arguments, so it would require modifying Paddle itself.
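For context, here is a minimal sketch of the keyword-dict pattern being proposed above. It is illustrative only, not the merged code: additional_args and dataloader_cls are made-up names, and the fragment is assumed to sit inside get_train_dataloader, where self.args, train_dataset, is_iterable_dataset and DistDataLoader are already in scope.

```python
# Collect class-specific kwargs in a dict and unpack them, so a single call
# site can construct either DataLoader or DistDataLoader.
additional_args = {}
if self.args.distributed_dataloader:
    dataloader_cls = DistDataLoader
    # Only DistDataLoader understands this flag.
    additional_args["is_iterable_dataset"] = is_iterable_dataset
else:
    dataloader_cls = paddle.io.DataLoader

train_dataloader = dataloader_cls(
    train_dataset,
    batch_size=self.args.per_device_train_batch_size,
    collate_fn=self.data_collator,
    num_workers=self.args.dataloader_num_workers,
    **additional_args,
)
```

The objection above is that paddle.io.DataLoader's signature does not accept arbitrary extra keyword arguments, so this pattern only works as long as additional_args stays empty on the plain-DataLoader path.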
@@ -1694,6 +1726,8 @@ def _load_rng_state(self, checkpoint):

        if self.args.use_hybrid_parallel:
            if "hybrid_parallel_rng_state_tracker" in checkpoint_rng_state:
                if self.args.tensor_parallel_degree <= 1:
                    checkpoint_rng_state["hybrid_parallel_rng_state_tracker"].pop("model_parallel_rng", None)
+1
@@ -1132,24 +1132,6 @@ def rerun(self, train_args):
        np.testing.assert_allclose(res[0], res[-1], rtol=self.rtol)

    @pytest.mark.skipif(True, reason="Skip for None CE")
Why was this removed?
If an ignore_merge_optimizer option is added later, it would conflict with skip_save_model_weight, so this test was removed.
@@ -33,6 +33,11 @@ def __len__(self):
        return 0


class IterableDummyDataset(paddle.io.IterableDataset):
I'm wondering whether the dataset side could construct the fake dataset itself.
I don't quite follow what you mean; the current approach seems fine to me?
LGTM
Force-pushed from 2384c4d to fd9ffba
This needs to be fixed.
# For distributed dataloader.
is_iterable_dataset_tensor = paddle.to_tensor(self._is_iterable_dataset(dataset)).reshape([1])
if dist.get_world_size() > 1:
    dist.all_reduce(is_iterable_dataset_tensor, op=dist.ReduceOp.MAX)
NPU does not support collective communication on bool tensors; this needs to be handled for compatibility.
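One possible workaround, sketched under the assumption that integer collectives are supported on NPU (this is not the code that was merged), is to encode the flag as an int before the all_reduce:

```python
import paddle
import paddle.distributed as dist


def sync_is_iterable_dataset(is_iterable: bool) -> bool:
    """Agree on the iterable-dataset flag across ranks without a bool collective."""
    # Encode the bool as int32 so devices that cannot all_reduce bool tensors
    # (e.g. NPU) can still take part in the collective.
    flag = paddle.to_tensor([1 if is_iterable else 0], dtype="int32")
    if dist.get_world_size() > 1:
        # MAX: if any rank sees an iterable dataset, every rank treats it as iterable.
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.item())
```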
* fix ddloader, fix uc unittest
* update dataloader
PR types
Bug fixes
PR changes
Others
Description
Why the distributed dataloader hangs: the problem mainly affects warm-start (resume-from-checkpoint) runs with iterable datasets. In the original implementation, the data ranks feed the iterable dataset into the dataloader, so their sampler ends up being the Infinite type, while the non-data ranks pass None as the dataset, and when the input is None Paddle's DataLoader automatically falls back to a batch sampler. After a warm start the trainer normally runs the skip-data logic, which branches roughly as sketched below:
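A hypothetical reconstruction of that skip-data branching follows; the condition and the variable names (train_dataloader, steps_trained_in_current_epoch) are assumptions for illustration, not the actual trainer.py code:

```python
import paddle

# Ranks pick a skip branch based on the sampler type their DataLoader ended up with.
batch_sampler = getattr(train_dataloader, "batch_sampler", None)
if isinstance(batch_sampler, paddle.io.BatchSampler):
    # Branch 1: non-data ranks. Their dataset input was None, so Paddle's
    # DataLoader silently created a plain batch sampler. Skipping here is
    # purely local bookkeeping and performs no communication.
    skipped_steps = steps_trained_in_current_epoch
else:
    # Branch 2: data ranks. The iterable dataset gave them an Infinite-type
    # sampler, so skipping means actually pulling batches from the distributed
    # dataloader, which issues cross-rank reads that the branch-1 ranks never join.
    data_iter = iter(train_dataloader)
    for _ in range(steps_trained_in_current_epoch):
        next(data_iter)  # each call may involve cross-rank communication
```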
The data ranks therefore enter the second branch while the non-data ranks enter the first; the two groups diverge into inconsistent control flow and training hangs. The stack traces captured at the time of the hang show the concrete problem.
[Screenshot: stack trace of the hang on the data ranks]
[Screenshot: stack trace of the hang on the non-data ranks]