Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reordered TP parallel plan to follow execution order #445

Merged
merged 2 commits into from
Jul 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions torchtitan/parallelisms/parallelize_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,9 +341,10 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
) = get_tp_parallel_strategy(job_config)
loss_parallel = parallel_dims.loss_parallel_enabled

# 1. Parallelize the first embedding and the last linear proj layer
# 1. Parallelize the embedding and shard its outputs (which are the first
# transformer block's inputs)
# 2. Parallelize the root norm layer over the sequence dim
# 3. Shard the first transformer block's inputs
# 3. Parallelize the final linear output layer
model = parallelize_module(
model,
tp_mesh,
Expand All @@ -352,12 +353,12 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
input_layouts=Replicate(),
output_layouts=Shard(1),
),
"norm": SequenceParallel(),
"output": col_parallel_strategy(
input_layouts=Shard(1),
output_layouts=Shard(-1) if loss_parallel else Replicate(),
use_local_output=not loss_parallel,
),
"norm": SequenceParallel(),
},
)

Expand All @@ -367,6 +368,7 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
# Examples can be found at https://github.com/pytorch/torchtitan/pull/437
for layer_id, transformer_block in model.layers.items():
layer_plan = {
"attention_norm": SequenceParallel(),
"attention": prepare_module_input(
input_layouts=(Shard(1), None),
desired_input_layouts=(Replicate(), None),
Expand All @@ -375,15 +377,14 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
"attention.wk": col_parallel_strategy(),
"attention.wv": col_parallel_strategy(),
"attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
"attention_norm": SequenceParallel(),
"ffn_norm": SequenceParallel(),
"feed_forward": prepare_module_input(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": col_parallel_strategy(),
"feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
"feed_forward.w3": col_parallel_strategy(),
"ffn_norm": SequenceParallel(),
}

# Adjust attention module to use the local number of heads
Expand Down
Loading