Skip to content

Commit

Permalink
[Cherry-pick] Add config for broadcast data in model parallel (#57843)
Browse files Browse the repository at this point in the history
* add config for broadcast data for mp

* fix mp bug (#58037)

---------

Co-authored-by: Yuang Liu <liuyuang@baidu.com>
  • Loading branch information
ForFishes and FeixLiu authored Oct 17, 2023
1 parent ff9d102 commit 1be4c8d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/framework/distributed_strategy.proto
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ message MpConfig {
// Whether to synchronize gradients across model-parallel ranks
// (presumably after backward — confirm against the strategy runtime).
optional bool sync_grad= 2 [ default = false ];
// Whether to synchronize optimizer moments across model-parallel ranks
// — TODO confirm semantics with the optimizer integration.
optional bool sync_moment= 3 [ default = false ];
// Collective used for the sync above; 'broadcast' is the default mode.
// NOTE(review): other accepted values are not visible here — verify.
optional string sync_mode= 4 [ default = 'broadcast' ];
// Broadcast mp input data
// When true, input data is broadcast to all model-parallel ranks
// before forward (see the corresponding check in _pre_forward).
optional bool need_broadcast_data=8 [default = true];
}

message PpConfig {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,10 @@ def _prepare_for_model(self):
logger.info("mp's parameters is ready")

def _pre_forward(self, *inputs, **kwargs):
    """Optionally broadcast the input data across model-parallel ranks
    before running forward.

    The broadcast is controlled by the ``need_broadcast_data`` flag in the
    strategy's ``mp_configs`` (defaults to True when no strategy is set).

    Returns:
        The (possibly broadcast) ``(inputs, kwargs)`` from
        ``broadcast_input_data`` when broadcasting is enabled, otherwise
        ``None`` — NOTE(review): callers must handle the ``None`` case;
        confirm against the call site.
    """
    # Default to broadcasting for backward compatibility with configs
    # created before the need_broadcast_data option existed.
    need_broadcast_data = True
    if self._strategy is not None:
        mp_configs = self._strategy.hybrid_configs["mp_configs"]
        need_broadcast_data = mp_configs.need_broadcast_data
    if need_broadcast_data:
        logger.debug("mp start broadcast input data")
        return broadcast_input_data(self._hcg, *inputs, **kwargs)

0 comments on commit 1be4c8d

Please sign in to comment.