From d13db15730bc3e3440d98017bdcd069bcfc02ab3 Mon Sep 17 00:00:00 2001
From: haohongxiang
Date: Thu, 9 Jun 2022 08:58:52 +0000
Subject: [PATCH 1/3] add sharding+dp

---
 .../language_model/gpt-3/dygraph/run_pretrain.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/examples/language_model/gpt-3/dygraph/run_pretrain.py b/examples/language_model/gpt-3/dygraph/run_pretrain.py
index 7d245b938db5..050e739127d1 100644
--- a/examples/language_model/gpt-3/dygraph/run_pretrain.py
+++ b/examples/language_model/gpt-3/dygraph/run_pretrain.py
@@ -37,6 +37,8 @@
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import DygraphShardingOptimizer
+from paddle.fluid.dygraph.parallel import sync_params_buffers
+from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients
 
 # add sharding stage2/3
 from paddle.distributed.sharding import group_sharded_parallel
@@ -152,7 +154,9 @@ def do_train(args):
 
     # sharding stage2/3 not support hybrid parallel
     if args.sharding_stage in [2, 3]:
-        assert args.dp_degree == args.mp_degree == args.pp_degree == 1, "sharding stage2/3 will support hybrid parallel later"
+        assert args.mp_degree == args.pp_degree == 1, "sharding stage2/3 will support tensor/pipeline parallel later"
+        dp_group = hcg.get_data_parallel_group()
+        sharding_group = hcg.get_sharding_parallel_group()
 
     sharding_size = hcg.get_sharding_parallel_world_size()
     data_world_rank = dp_rank * sharding_size + sharding_rank
@@ -275,6 +279,9 @@ def do_train(args):
         scaler = scaler if args.use_pure_fp16 else None
         model, optimizer, scaler = wrap_sharding_2_3(model, optimizer, scaler,
                                                      args.sharding_offload)
+        if args.dp_degree > 1:
+            sync_params_buffers(
+                model, comm_group=dp_group, src_rank=dp_group.ranks[0])
 
     elif paddle.distributed.get_world_size() > 1:
         model = fleet.distributed_model(model)
@@ -353,6 +360,9 @@ def do_train(args):
                             loss_mbs.backward()
                         loss = loss + loss_mbs
 
+                    if args.sharding_stage in [2, 3] and args.dp_degree > 1:
+                        fused_allreduce_gradients(model.parameters(), hcg)
+
                     if args.use_pure_fp16:
                         if args.sharding_stage in [2, 3]:
                             scaler.step(optimizer)

From 5793563e874c265965c046d99eccdc718cec1cd5 Mon Sep 17 00:00:00 2001
From: haohongxiang
Date: Fri, 10 Jun 2022 01:38:24 +0000
Subject: [PATCH 2/3] update

---
 .../language_model/gpt-3/dygraph/run_pretrain.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/examples/language_model/gpt-3/dygraph/run_pretrain.py b/examples/language_model/gpt-3/dygraph/run_pretrain.py
index 050e739127d1..c6f4dd623b0d 100644
--- a/examples/language_model/gpt-3/dygraph/run_pretrain.py
+++ b/examples/language_model/gpt-3/dygraph/run_pretrain.py
@@ -156,7 +156,6 @@ def do_train(args):
     if args.sharding_stage in [2, 3]:
         assert args.mp_degree == args.pp_degree == 1, "sharding stage2/3 will support tensor/pipeline parallel later"
         dp_group = hcg.get_data_parallel_group()
-        sharding_group = hcg.get_sharding_parallel_group()
 
     sharding_size = hcg.get_sharding_parallel_world_size()
     data_world_rank = dp_rank * sharding_size + sharding_rank
@@ -276,13 +275,14 @@ def do_train(args):
     # wrap sharding stage2/3 and add collective group
     # TODO(Baibaifan): combine ShardingStage1/2/3 and fleet.distributed_model in feature
     if args.sharding_stage in [2, 3]:
-        scaler = scaler if args.use_pure_fp16 else None
-        model, optimizer, scaler = wrap_sharding_2_3(model, optimizer, scaler,
-                                                     args.sharding_offload)
         if args.dp_degree > 1:
             sync_params_buffers(
                 model, comm_group=dp_group, src_rank=dp_group.ranks[0])
 
+        scaler = scaler if args.use_pure_fp16 else None
+        model, optimizer, scaler = wrap_sharding_2_3(model, optimizer, scaler,
+                                                     args.sharding_offload)
+
     elif paddle.distributed.get_world_size() > 1:
         model = fleet.distributed_model(model)
         optimizer = fleet.distributed_optimizer(optimizer)
@@ -362,6 +362,13 @@ def do_train(args):
 
                     if args.sharding_stage in [2, 3] and args.dp_degree > 1:
                         fused_allreduce_gradients(model.parameters(), hcg)
+                        if args.sharding_stage == 3:
+                            for p in model.parameters():
+                                if hasattr(p, "bw_storage"):
+                                    assert p.grad is None, "This case shouldn't happen."
+                                    p.bw_storage.scale_(1.0 / dp_group.nranks)
+                                    paddle.distributed.all_reduce(
+                                        p.bw_storage, group=dp_group)
 
                     if args.use_pure_fp16:
                         if args.sharding_stage in [2, 3]:

From 92910b8361a6ca90b10215b047823aedf9a9b500 Mon Sep 17 00:00:00 2001
From: haohongxiang
Date: Wed, 28 Sep 2022 04:54:37 +0000
Subject: [PATCH 3/3] code style check

---
 examples/language_model/gpt-3/dygraph/run_pretrain.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/examples/language_model/gpt-3/dygraph/run_pretrain.py b/examples/language_model/gpt-3/dygraph/run_pretrain.py
index bab42ce039a5..bf8a29342a55 100644
--- a/examples/language_model/gpt-3/dygraph/run_pretrain.py
+++ b/examples/language_model/gpt-3/dygraph/run_pretrain.py
@@ -153,7 +153,7 @@ def do_train(args):
     dp_rank = hcg.get_data_parallel_rank()
     sharding_rank = hcg.get_sharding_parallel_rank()
 
-    # sharding stage2/3 not support hybrid parallel
+    # sharding stage2/3 not support hybrid parallel now
    if args.sharding_stage in [2, 3]:
         assert args.mp_degree == args.pp_degree == 1, "sharding stage2/3 will support tensor/pipeline parallel later"
         dp_group = hcg.get_data_parallel_group()
@@ -279,8 +279,9 @@ def do_train(args):
         if args.dp_degree > 1:
-            sync_params_buffers(
-                model, comm_group=dp_group, src_rank=dp_group.ranks[0])
+            sync_params_buffers(model,
+                                comm_group=dp_group,
+                                src_rank=dp_group.ranks[0])
 
         scaler = scaler if args.use_pure_fp16 else None
         model, optimizer, scaler = wrap_sharding_2_3(model, optimizer, scaler,