
[Hybrid Parallel] Add model parallel support in dygraph #32248

Merged · 19 commits · Apr 16, 2021

Conversation

@ForFishes (Member) commented Apr 13, 2021

PR types

New features

PR changes

APIs

Describe

[Hybrid Parallel] Add model parallel support in dygraph.

@paddle-bot-old commented:
Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

message HybridConfig {
  optional int32 num_data_parallel = 1 [ default = -1 ];
  optional int32 num_model_parallel = 2 [ default = 1 ];
  optional int32 num_pipeline_parallel = 3 [ default = 1 ];
}


Use dp_degree, mp_degree, and pp_degree to be consistent with the static graph.

ForFishes (Member, Author) replied:

done
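The rename settles on dp_degree, mp_degree, and pp_degree. One property such a config must satisfy is that the three degrees factor the total number of ranks. A minimal, hypothetical sketch of that check (check_hybrid_degrees is not a Paddle API):

```python
def check_hybrid_degrees(world_size, dp_degree, mp_degree, pp_degree):
    """Hypothetical check: the parallel degrees must factor the world size."""
    if dp_degree * mp_degree * pp_degree != world_size:
        raise ValueError(
            "dp_degree * mp_degree * pp_degree must equal world size: "
            "%d * %d * %d != %d" % (dp_degree, mp_degree, pp_degree, world_size))
    return {"dp_degree": dp_degree, "mp_degree": mp_degree, "pp_degree": pp_degree}
```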

@@ -21,6 +21,7 @@
from .data_generator import MultiSlotDataGenerator, MultiSlotStringDataGenerator
from . import metrics
from .base.topology import CommunicateTopology, HybridCommunicateGroup
from .mpu import random, layers


Why import these two files?

ForFishes (Member, Author) replied:

rm it

"num_data_parallel": 1,
"num_model_parallel": 2,
"num_pipeline_parallel": 1}
"""


use dp_degree, mp_degree and pp_degree

ForFishes (Member, Author) replied:

done

check_configs_key(self.strategy.model_parallel_configs, configs,
"model_parallel_configs")
assign_configs_value(self.strategy.model_parallel_configs, configs)


Delete the unnecessary model_parallel_configs in this PR.

ForFishes (Member, Author) replied:

ok, fix it
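For context, the snippet above first validates the user-supplied config keys and then assigns their values onto the strategy. A hypothetical stand-in for such a key check (not the actual Fleet helper, which inspects the protobuf message fields):

```python
def check_configs_key(valid_keys, configs, field_name):
    """Reject any config key that the strategy field does not define.

    Illustrative only: `valid_keys` stands in for the field names of the
    underlying protobuf message.
    """
    for key in configs:
        if key not in valid_keys:
            raise ValueError("key '%s' is not valid for %s" % (key, field_name))
```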


self._topology = tp.CommunicateTopology(
hybrid_group_names=["data", "model", "pipe"],
dims=[self.dp_num, self.mp_num, self.pp_num])


use ['data', 'pipe', 'model'] instead.

ForFishes (Member, Author) replied:

done
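With the axis order fixed as ['data', 'pipe', 'model'], a communicate topology essentially maps each global rank to a coordinate along those axes. A row-major sketch of that mapping (illustrative, not Paddle's internal code):

```python
def rank_to_coord(rank, dims):
    """Decompose a global rank into per-axis coordinates, row-major order."""
    coord = []
    for d in reversed(dims):        # peel off the fastest-varying axis first
        coord.append(rank % d)
        rank //= d
    return tuple(reversed(coord))

# With dims [2, 1, 2] for (data, pipe, model), rank 3 maps to (1, 0, 1):
# second data-parallel group, pipeline stage 0, second model-parallel shard.
```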

per_part_size += 1 # make the last row as the padding index
self.per_part_size = per_part_size

with get_rng_state_tracker().rng_state():


Delete the unused with-block.

ForFishes (Member, Author) replied:

rm it
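The per_part_size += 1 above reserves the last row of each shard as a padding index. Splitting a vocabulary evenly across model-parallel ranks can be sketched as follows (hypothetical helper, not Paddle code):

```python
def vocab_range_for_rank(vocab_size, mp_degree, rank):
    """Half-open [start, end) slice of the vocabulary owned by one rank."""
    per_part = (vocab_size + mp_degree - 1) // mp_degree  # ceiling division
    start = rank * per_part
    end = min(start + per_part, vocab_size)
    return start, end
```

Each rank then embeds only tokens inside its slice, mapping out-of-range tokens to the extra padding row.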

self._weight_attr = weight_attr
self._dtype = self._helper.get_default_dtype()

with get_rng_state_tracker().rng_state():


same as above

ForFishes (Member, Author) replied:

fix it
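Although these particular with-blocks are removed, an RNG state tracker of this kind exists so model-parallel ranks can keep separate, reproducible random streams (e.g. for dropout) without disturbing the global state. A stdlib-based stand-in, not PaddlePaddle's implementation:

```python
import random
from contextlib import contextmanager


class RNGStateTracker:
    """Swap in a named random state inside the context, restore it after."""

    def __init__(self):
        self._states = {}

    def add(self, name, seed):
        outer = random.getstate()          # save the caller's state
        random.seed(seed)
        self._states[name] = random.getstate()
        random.setstate(outer)             # leave global RNG untouched

    @contextmanager
    def rng_state(self, name):
        outer = random.getstate()
        random.setstate(self._states[name])
        try:
            yield
        finally:
            self._states[name] = random.getstate()  # remember the progress
            random.setstate(outer)
```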


self.input_size_per_partition = in_features // self.world_size

with get_rng_state_tracker().rng_state():


same as above

ForFishes (Member, Author) replied:

done
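input_size_per_partition = in_features // self.world_size is the row-parallel split: each rank multiplies its slice of the input by its shard of the weight rows, and the partial outputs are summed across ranks (an all-reduce). A pure-Python sketch of the arithmetic, not Paddle code:

```python
def row_parallel_matmul(x, weight_shards):
    """Simulate a row-parallel linear layer on one process.

    x: flat input of length in_features.
    weight_shards: one [rows][out_features] block per rank; the row counts
    sum to in_features.
    """
    partials = []
    offset = 0
    for shard in weight_shards:
        rows = len(shard)
        xs = x[offset:offset + rows]       # this rank's slice of the input
        offset += rows
        out = [sum(xs[i] * shard[i][j] for i in range(rows))
               for j in range(len(shard[0]))]
        partials.append(out)
    # The "all-reduce": elementwise sum of the per-rank partial outputs.
    return [sum(p[j] for p in partials) for j in range(len(partials[0]))]
```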

XieYunshen previously approved these changes Apr 16, 2021
sandyhouse previously approved these changes Apr 16, 2021
@sandyhouse left a comment:
LGTM

@ForFishes ForFishes dismissed stale reviews from sandyhouse and XieYunshen via 99e0099 April 16, 2021 09:40
@sandyhouse left a comment:

LGTM

@ForFishes ForFishes merged commit 66d4622 into PaddlePaddle:develop Apr 16, 2021
@ForFishes ForFishes deleted the add_mp_layer branch April 16, 2021 17:07