[AutoParallel] Refine auto_trainer save load #8767
Conversation
Thanks for your contribution!
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##           develop    #8767      +/-   ##
===========================================
- Coverage    55.43%   55.22%   -0.21%
===========================================
  Files          626      631       +5
  Lines        98070   100091    +2021
===========================================
+ Hits         54366    55277     +911
- Misses       43704    44814    +1110

☔ View full report in Codecov by Sentry.
LGTM
self._memory_tracker.start()

if not self.args.enable_auto_parallel:
    if not self.args.should_load_sharding_stage1_model:
Let's leave it like this for now; later we can see whether it can be extracted into a function so that auto parallel can override it more easily.
OK.
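For reference, a minimal sketch of that refactor (only should_load_sharding_stage1_model and enable_auto_parallel come from the diff above; every other name is an illustrative assumption, not the actual implementation):

```python
# Hypothetical sketch: pull the resume branch into one overridable helper
# so AutoTrainer can supply its own load path without touching the base
# training loop. All names except should_load_sharding_stage1_model and
# enable_auto_parallel are assumptions.
class Trainer:
    def _maybe_load_from_checkpoint(self, resume_from_checkpoint):
        # Default (non auto-parallel) resume path.
        if not self.args.should_load_sharding_stage1_model:
            self._load_from_checkpoint(resume_from_checkpoint)


class AutoTrainer(Trainer):
    def _maybe_load_from_checkpoint(self, resume_from_checkpoint):
        # Auto-parallel resume path plugs in here, so the shared loop no
        # longer needs an enable_auto_parallel branch.
        self._load_auto_parallel_checkpoint(resume_from_checkpoint)
```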
paddlenlp/trainer/auto_trainer.py
Outdated
for p_name, p in model.state_dict().items():
    if paddle.distributed.get_rank() not in p.process_mesh.process_ids:
        var_name = p.name
        if (
Variables like _moment1_0 here should be managed in a global variable instead of being hard-coded.
OK.
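One possible shape for that suggestion (a sketch only; the constant name and the full suffix list are assumptions, only _moment1_0 comes from this thread):

```python
# Sketch: keep optimizer-state name suffixes in one module-level constant
# instead of hard-coding strings like "_moment1_0" at each use site.
# The constant name and the exact suffix list are illustrative assumptions.
OPTIMIZER_STATE_NAME_SUFFIXES = (
    "_moment1_0",
    "_moment2_0",
    "_beta1_pow_acc_0",
    "_beta2_pow_acc_0",
)


def is_optimizer_state_name(var_name: str) -> bool:
    """Return True if var_name ends with a known optimizer-state suffix."""
    return any(var_name.endswith(suffix) for suffix in OPTIMIZER_STATE_NAME_SUFFIXES)
```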
@@ -216,7 +214,6 @@ def _inner_training_loop(
     epochs_trained = self.state.global_step // num_update_steps_per_epoch
     if not args.ignore_data_skip:
         steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
-        steps_trained_in_current_epoch *= args.gradient_accumulation_steps
Is this line no longer needed, or was it wrong before?
It was wrong before; this change fixes it.
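For context, the arithmetic being discussed (illustrative numbers only; the diff above removes the final multiplication, so the mid-epoch skip count stays in optimizer-update steps rather than micro-batches):

```python
# Illustrative resume arithmetic; the concrete values are made up.
global_step = 250                   # optimizer updates completed before restart
num_update_steps_per_epoch = 100    # optimizer updates per epoch
gradient_accumulation_steps = 4     # micro-batches per optimizer update

epochs_trained = global_step // num_update_steps_per_epoch                  # 2 full epochs
steps_trained_in_current_epoch = global_step % num_update_steps_per_epoch   # 50 update steps

# The removed line converted that count into micro-batches: 50 * 4 = 200.
as_micro_batches = steps_trained_in_current_epoch * gradient_accumulation_steps

print(epochs_trained, steps_trained_in_current_epoch, as_micro_batches)
```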
PR types
Bug fixes
PR changes
Others
Description
Refine save/load for auto_trainer.