Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dygraph] Support sharding stage2/3+dp in GPT-3 model #2471

Merged
merged 6 commits into from
Sep 29, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions examples/language_model/gpt-3/dygraph/run_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import DygraphShardingOptimizer
from paddle.fluid.dygraph.parallel import sync_params_buffers
from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

# add sharding stage2/3
from paddle.distributed.sharding import group_sharded_parallel
Expand Down Expand Up @@ -151,9 +153,10 @@ def do_train(args):
dp_rank = hcg.get_data_parallel_rank()
sharding_rank = hcg.get_sharding_parallel_rank()

# sharding stage2/3 not support hybrid parallel
# sharding stage2/3 not support hybrid parallel now
if args.sharding_stage in [2, 3]:
assert args.dp_degree == args.mp_degree == args.pp_degree == 1, "sharding stage2/3 will support hybrid parallel later"
assert args.mp_degree == args.pp_degree == 1, "sharding stage2/3 will support tensor/pipeline parallel later"
dp_group = hcg.get_data_parallel_group()

sharding_size = hcg.get_sharding_parallel_world_size()
data_world_rank = dp_rank * sharding_size + sharding_rank
Expand Down Expand Up @@ -275,6 +278,11 @@ def do_train(args):
# wrap sharding stage2/3 and add collective group
# TODO(Baibaifan): combine ShardingStage1/2/3 and fleet.distributed_model in feature
if args.sharding_stage in [2, 3]:
if args.dp_degree > 1:
sync_params_buffers(model,
comm_group=dp_group,
src_rank=dp_group.ranks[0])

scaler = scaler if args.use_pure_fp16 else None
model, optimizer, scaler = wrap_sharding_2_3(model, optimizer, scaler,
args.sharding_offload)
Expand Down Expand Up @@ -359,6 +367,16 @@ def do_train(args):
loss_mbs.backward()
loss = loss + loss_mbs

if args.sharding_stage in [2, 3] and args.dp_degree > 1:
fused_allreduce_gradients(model.parameters(), hcg)
Comment on lines +370 to +371
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里fused_allreduce_gradient和下面都加一下注释吧。不清楚是做了啥通信

if args.sharding_stage == 3:
for p in model.parameters():
if hasattr(p, "bw_storage"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个 bw_storage 是?

assert p.grad is None, "This case shouldn't happen."
p.bw_storage.scale_(1.0 / dp_group.nranks)
paddle.distributed.all_reduce(
p.bw_storage, group=dp_group)

if args.use_pure_fp16:
if args.sharding_stage in [2, 3]:
scaler.step(optimizer)
Expand Down