diff --git a/examples/language_model/gpt-3/dygraph/run_pretrain.py b/examples/language_model/gpt-3/dygraph/run_pretrain.py
index d45250272150..bf8a29342a55 100644
--- a/examples/language_model/gpt-3/dygraph/run_pretrain.py
+++ b/examples/language_model/gpt-3/dygraph/run_pretrain.py
@@ -37,6 +37,8 @@ from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import DygraphShardingOptimizer
+from paddle.fluid.dygraph.parallel import sync_params_buffers
+from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients
 
 # add sharding stage2/3
 from paddle.distributed.sharding import group_sharded_parallel
@@ -151,9 +153,10 @@ def do_train(args):
     dp_rank = hcg.get_data_parallel_rank()
     sharding_rank = hcg.get_sharding_parallel_rank()
 
-    # sharding stage2/3 not support hybrid parallel
+    # sharding stage2/3 not support hybrid parallel now
     if args.sharding_stage in [2, 3]:
-        assert args.dp_degree == args.mp_degree == args.pp_degree == 1, "sharding stage2/3 will support hybrid parallel later"
+        assert args.mp_degree == args.pp_degree == 1, "sharding stage2/3 will support tensor/pipeline parallel later"
+        dp_group = hcg.get_data_parallel_group()
 
     sharding_size = hcg.get_sharding_parallel_world_size()
     data_world_rank = dp_rank * sharding_size + sharding_rank
@@ -275,6 +278,11 @@ def do_train(args):
     # wrap sharding stage2/3 and add collective group
     # TODO(Baibaifan): combine ShardingStage1/2/3 and fleet.distributed_model in feature
     if args.sharding_stage in [2, 3]:
+        if args.dp_degree > 1:
+            sync_params_buffers(model,
+                                comm_group=dp_group,
+                                src_rank=dp_group.ranks[0])
+
         scaler = scaler if args.use_pure_fp16 else None
         model, optimizer, scaler = wrap_sharding_2_3(model, optimizer, scaler,
                                                      args.sharding_offload)
@@ -359,6 +367,16 @@ def do_train(args):
                     loss_mbs.backward()
                     loss = loss + loss_mbs
 
+            if args.sharding_stage in [2, 3] and args.dp_degree > 1:
+                fused_allreduce_gradients(model.parameters(), hcg)
+                if args.sharding_stage == 3:
+                    for p in model.parameters():
+                        if hasattr(p, "bw_storage"):
+                            assert p.grad is None, "This case shouldn't happen."
+                            p.bw_storage.scale_(1.0 / dp_group.nranks)
+                            paddle.distributed.all_reduce(
+                                p.bw_storage, group=dp_group)
+
             if args.use_pure_fp16:
                 if args.sharding_stage in [2, 3]:
                     scaler.step(optimizer)
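Review note: the `sync_params_buffers` call added above is needed because the stage 2/3 wrapper only communicates inside the sharding group, so the data-parallel replicas must be made identical before training starts. A minimal sketch of the same idea, assuming an already-created `dp_group` (the helper name `broadcast_dp_parameters` is mine, not part of this patch):

```python
import paddle.distributed as dist

def broadcast_dp_parameters(model, dp_group):
    # Copy every parameter and buffer from the first rank of the
    # data-parallel group so all replicas start from identical weights;
    # this mirrors what sync_params_buffers does in the patch.
    src = dp_group.ranks[0]
    for p in model.parameters():
        dist.broadcast(p, src=src, group=dp_group)
    for b in model.buffers():
        dist.broadcast(b, src=src, group=dp_group)
```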
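Likewise, after `backward()` the gradients have only been reduced inside the sharding group, so the patch averages them across the data-parallel group by hand (via `fused_allreduce_gradients` for stage 2, and the `bw_storage` loop for stage 3). A rough, unfused equivalent with a hypothetical helper name, showing why stage 3 touches `p.bw_storage` (the fused gradient storage) rather than `p.grad`:

```python
import paddle
import paddle.distributed as dist

def allreduce_dp_gradients(model, dp_group):
    # Average gradients across data-parallel replicas. Stage 2 keeps
    # ordinary .grad tensors; stage 3 fuses gradients into p.bw_storage,
    # so that buffer is what must be reduced instead of .grad.
    with paddle.no_grad():
        for p in model.parameters():
            grad = p.bw_storage if hasattr(p, "bw_storage") else p.grad
            if grad is None:
                continue
            grad.scale_(1.0 / dp_group.nranks)  # pre-scale so SUM becomes a mean
            dist.all_reduce(grad, group=dp_group)
```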