[Hybrid Parallel] Add model parallel support in dygraph #32248
Conversation
Thanks for your contribution!
message HybridConfig {
  optional int32 num_data_parallel = 1 [ default = -1 ];
  optional int32 num_model_parallel = 2 [ default = 1 ];
  optional int32 num_pipeline_parallel = 3 [ default = 1 ];
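For context on these defaults: the three degrees must factor the total number of ranks, and the `-1` default for the data-parallel degree plausibly means "infer it from the other two". A minimal sketch of that validation (hypothetical helper names, not the actual Paddle API):

```python
def check_hybrid_degrees(world_size, dp_degree, mp_degree, pp_degree):
    """Validate that the parallel degrees factor the world size.

    Assumption: dp_degree == -1 means "infer from the other two
    degrees", matching the default = -1 in the proto excerpt above.
    """
    if dp_degree == -1:
        assert world_size % (mp_degree * pp_degree) == 0, \
            "world size must be divisible by mp_degree * pp_degree"
        dp_degree = world_size // (mp_degree * pp_degree)
    assert dp_degree * mp_degree * pp_degree == world_size, \
        "degrees must multiply to the world size"
    return dp_degree, mp_degree, pp_degree
```

For example, on 8 ranks with mp_degree=2 and pp_degree=1, the data-parallel degree would be inferred as 4.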
Use dp_degree, mp_degree, and pp_degree to be consistent with the static graph.
done
@@ -21,6 +21,7 @@
from .data_generator import MultiSlotDataGenerator, MultiSlotStringDataGenerator
from . import metrics
from .base.topology import CommunicateTopology, HybridCommunicateGroup
from .mpu import random, layers
Why import these two modules here?
rm it
"num_data_parallel": 1, | ||
"num_model_parallel": 2, | ||
"num_pipeline_parallel": 1} | ||
""" |
Use dp_degree, mp_degree, and pp_degree here as well.
done
check_configs_key(self.strategy.model_parallel_configs, configs,
                  "model_parallel_configs")
assign_configs_value(self.strategy.model_parallel_configs, configs)
Delete the unnecessary model_parallel_configs in this PR.
OK, fixed it.
self._topology = tp.CommunicateTopology(
    hybrid_group_names=["data", "model", "pipe"],
    dims=[self.dp_num, self.mp_num, self.pp_num])
use ['data', 'pipe', 'model'] instead.
done
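The suggested ordering matters because the last axis varies fastest in a rank-to-coordinate mapping, so with ["data", "pipe", "model"] the model-parallel peers land on consecutive ranks (typically the same machine, where intra-node bandwidth is highest). A minimal sketch of that mapping (hypothetical helper, not the actual CommunicateTopology implementation):

```python
from itertools import product

def build_topology(group_names, dims):
    """Map each global rank to a coordinate along each parallel axis.

    itertools.product varies the last dimension fastest, so with
    group_names = ["data", "pipe", "model"] consecutive ranks differ
    only in the "model" coordinate: model-parallel peers stay adjacent.
    """
    coords = product(*[range(d) for d in dims])
    return {rank: dict(zip(group_names, c)) for rank, c in enumerate(coords)}
```

With dims=[2, 1, 2], ranks 0 and 1 share the same data/pipe coordinates and differ only in the model coordinate.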
per_part_size += 1  # make the last row as the padding index
self.per_part_size = per_part_size

with get_rng_state_tracker().rng_state():
Delete the unused with statement.
rm it
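For readers unfamiliar with the rng_state() pattern being discussed: a state tracker keeps named RNG streams so each model-parallel rank can draw partitioned weights from its own stream while replicated weights come from the shared global stream. A self-contained sketch of the idea using Python's random module (the real tracker manages the framework's device RNG, and these names are illustrative):

```python
import random
from contextlib import contextmanager

class RNGStateTracker:
    """Keep named RNG states separate from the global stream (sketch)."""

    def __init__(self):
        self._states = {}

    def add(self, name, seed):
        # Record a fresh state seeded for this stream, then restore
        # the global state so adding a stream has no side effects.
        orig = random.getstate()
        random.seed(seed)
        self._states[name] = random.getstate()
        random.setstate(orig)

    @contextmanager
    def rng_state(self, name="model_parallel_rng"):
        # Temporarily swap in the named stream's state; on exit, save
        # its advanced state and restore the global stream untouched.
        orig = random.getstate()
        random.setstate(self._states[name])
        try:
            yield
        finally:
            self._states[name] = random.getstate()
            random.setstate(orig)
```

Two trackers seeded identically draw identical values inside rng_state(), which is what makes sharded weight initialization reproducible across runs.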
self._weight_attr = weight_attr
self._dtype = self._helper.get_default_dtype()

with get_rng_state_tracker().rng_state():
same as above
Fixed it.
self.input_size_per_partition = in_features // self.world_size

with get_rng_state_tracker().rng_state():
same as above
done
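The two partitioning rules quoted in the diffs above (row-parallel linear splits the input dimension evenly; the vocab-parallel embedding adds one extra row per shard as a padding index) can be sketched together. This assumes the embedding shards the vocabulary by ceiling division, which the excerpt does not show, so treat that part as an illustrative assumption:

```python
def partition_sizes(in_features, vocab_size, world_size):
    """Per-rank sizes for a row-parallel linear and a vocab-parallel
    embedding (sketch; ceil-division sharding is an assumption).
    """
    # Row-parallel linear: input dimension must split evenly.
    assert in_features % world_size == 0
    input_size_per_partition = in_features // world_size

    # Vocab-parallel embedding: shard the vocab, then reserve one
    # extra row per shard -- "make the last row as the padding index".
    per_part_size = (vocab_size + world_size - 1) // world_size
    per_part_size += 1
    return input_size_per_partition, per_part_size
```

For example, with in_features=1024, vocab_size=50000, and world_size=4, each rank holds 256 input columns and 12501 embedding rows.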
LGTM
LGTM
PR types
New features
PR changes
APIs
Describe
[Hybrid Parallel] Add model parallel support in dygraph.