From 1be4c8de6fa3a583a15e53b68a9634fbc08b3a85 Mon Sep 17 00:00:00 2001
From: ShenLiang <1422485404@qq.com>
Date: Tue, 17 Oct 2023 11:25:38 +0800
Subject: [PATCH] [Cherry-pick] Add config for broadcast data in model
 parallel (#57843)

* add config for broadcast data for mp

* fix mp bug (#58037)

---------

Co-authored-by: Yuang Liu
---
 paddle/fluid/framework/distributed_strategy.proto      | 2 ++
 .../distributed/fleet/meta_parallel/tensor_parallel.py | 9 +++++++--
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index 85bafbef2b63e..1af0d447da29c 100755
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -56,6 +56,8 @@ message MpConfig {
   optional bool sync_grad= 2 [ default = false ];
   optional bool sync_moment= 3 [ default = false ];
   optional string sync_mode= 4 [ default = 'broadcast' ];
+  // Broadcast mp input data
+  optional bool need_broadcast_data=8 [default = true];
 }
 
 message PpConfig {
diff --git a/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py b/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py
index 883533d8e1724..13546a02b5bd2 100755
--- a/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py
@@ -42,5 +42,10 @@ def _prepare_for_model(self):
         logger.info("mp's parameters is ready")
 
     def _pre_forward(self, *inputs, **kwargs):
-        logger.debug("mp start broadcast input data")
-        return broadcast_input_data(self._hcg, *inputs, **kwargs)
+        need_broadcast_data = True
+        if self._strategy is not None:
+            mp_configs = self._strategy.hybrid_configs["mp_configs"]
+            need_broadcast_data = mp_configs.need_broadcast_data
+        if need_broadcast_data:
+            logger.debug("mp start broadcast input data")
+            return broadcast_input_data(self._hcg, *inputs, **kwargs)
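
For context, `_pre_forward` reads the new switch from `hybrid_configs["mp_configs"]`.
Below is a minimal sketch of how a training script might disable the input
broadcast through `fleet.DistributedStrategy`, assuming the nested-dict form of
`hybrid_configs`; the 2/2/2 parallel degrees are illustrative assumptions, not
part of this patch:

    # Minimal sketch: turn off mp input-data broadcast.
    # Assumed launch context: `python -m paddle.distributed.launch`
    # on a GPU count matching the dp*mp*pp degrees chosen below.
    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 2,
        "mp_degree": 2,
        "pp_degree": 2,
        "mp_configs": {
            # New field from this patch; it defaults to True, so existing
            # setups keep broadcasting input data across the mp group.
            "need_broadcast_data": False,
        },
    }
    fleet.init(is_collective=True, strategy=strategy)

Leaving `need_broadcast_data` at its default (`True`) preserves the old
behavior; setting it to `False` skips the broadcast in `_pre_forward`, which
can help when the data loader already feeds each mp rank its own copy of the
inputs.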