Commit 06e30d0

Update on "Renamed parallel styles for transformer block weights"
[ghstack-poisoned]
awgu committed Jul 10, 2024
2 parents: 73265be + c52d0a6
Showing 1 changed file with 5 additions and 5 deletions.
torchtitan/parallelisms/parallelize_llama.py (10 changes: 5 additions & 5 deletions)
@@ -316,12 +316,12 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
 
     tp_mesh = world_mesh["tp"]
-    # Parallel styles for transformer block linear weights may be different for
-    # float8 linears
+    # Parallel styles used for transformer block linear weights and their
+    # inputs may be different for float8 linears
     (
         rowwise_parallel_weight,
         colwise_parallel_weight,
-        prepare_module_weight_input,
+        prepare_module_input,
     ) = get_tp_parallel_strategy(job_config)
     loss_parallel = parallel_dims.loss_parallel_enabled
 
@@ -353,7 +353,7 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     for layer_id, transformer_block in model.layers.items():
         layer_plan = {
             "attention_norm": SequenceParallel(),
-            "attention": prepare_module_weight_input(
+            "attention": prepare_module_input(
                 input_layouts=(Shard(1), None),
                 desired_input_layouts=(Replicate(), None),
             ),
@@ -362,7 +362,7 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
             "attention.wv": colwise_parallel_weight(),
             "attention.wo": rowwise_parallel_weight(output_layouts=Shard(1)),
             "ffn_norm": SequenceParallel(),
-            "feed_forward": prepare_module_weight_input(
+            "feed_forward": prepare_module_input(
                 input_layouts=(Shard(1),),
                 desired_input_layouts=(Replicate(),),
             ),
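
For context, the tuple unpacked in the first hunk comes from get_tp_parallel_strategy, which selects float8-aware parallel styles when float8 linears are enabled. A minimal sketch of what such a selector could look like, assuming a job_config.training.fp8_linear flag and float8 drop-in styles (the Float8* names and their import path are assumptions, not taken from this diff):

from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
)


def get_tp_parallel_strategy(job_config):
    """Return (rowwise, colwise, prepare_module_input) parallel styles for the TP plan."""
    if job_config.training.fp8_linear:  # hypothetical config flag
        # Assumed float8-aware drop-in replacements for the stock DTensor styles.
        from float8_experimental.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
            PrepareFloat8ModuleInput,
        )

        return Float8RowwiseParallel, Float8ColwiseParallel, PrepareFloat8ModuleInput
    return RowwiseParallel, ColwiseParallel, PrepareModuleInput

With either choice, the layer_plan treats the styles interchangeably: colwise_parallel_weight and rowwise_parallel_weight shard the attention and feed-forward linear weights, while prepare_module_input redistributes the sequence-parallel Shard(1) activations to Replicate() before the block's linears run.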
