From 6165d3d7e5e4fcd35ae7592580726b3322fdd1ae Mon Sep 17 00:00:00 2001
From: Andrew Gu <andgu@fb.com>
Date: Wed, 10 Jul 2024 07:53:25 -0700
Subject: [PATCH 1/2] Reordered TP parallel plan to follow execution order

[ghstack-poisoned]
---
 torchtitan/parallelisms/parallelize_llama.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 7becb731..c07d4c33 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -332,7 +332,6 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
     Apply tensor parallelism.
     """
-
     tp_mesh = world_mesh["tp"]
     (
         row_parallel_strategy,
@@ -341,9 +340,10 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     ) = get_tp_parallel_strategy(job_config)
     loss_parallel = parallel_dims.loss_parallel_enabled
 
-    # 1. Parallelize the first embedding and the last linear proj layer
+    # 1. Parallelize the embedding and shard its outputs (which are the first
+    # transformer block's inputs)
     # 2. Parallelize the root norm layer over the sequence dim
-    # 3. Shard the first transformer block's inputs
+    # 3. Parallelize the final linear output layer
     model = parallelize_module(
         model,
         tp_mesh,
@@ -352,12 +352,12 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
                 input_layouts=Replicate(),
                 output_layouts=Shard(1),
             ),
+            "norm": SequenceParallel(),
             "output": col_parallel_strategy(
                 input_layouts=Shard(1),
                 output_layouts=Shard(-1) if loss_parallel else Replicate(),
                 use_local_output=not loss_parallel,
             ),
-            "norm": SequenceParallel(),
         },
     )
 
@@ -367,6 +367,7 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     #       Examples can be found at https://github.com/pytorch/torchtitan/pull/437
     for layer_id, transformer_block in model.layers.items():
         layer_plan = {
+            "attention_norm": SequenceParallel(),
             "attention": prepare_module_input(
                 input_layouts=(Shard(1), None),
                 desired_input_layouts=(Replicate(), None),
@@ -375,7 +376,7 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
             "attention.wk": col_parallel_strategy(),
             "attention.wv": col_parallel_strategy(),
             "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
-            "attention_norm": SequenceParallel(),
+            "ffn_norm": SequenceParallel(),
             "feed_forward": prepare_module_input(
                 input_layouts=(Shard(1),),
                 desired_input_layouts=(Replicate(),),
@@ -383,7 +384,6 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
             "feed_forward.w1": col_parallel_strategy(),
             "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
             "feed_forward.w3": col_parallel_strategy(),
-            "ffn_norm": SequenceParallel(),
         }
 
         # Adjust attention module to use the local number of heads

From 74304ba949fe260f9160380762c3b8f5d4fc6e0a Mon Sep 17 00:00:00 2001
From: Andrew Gu <andgu@fb.com>
Date: Wed, 10 Jul 2024 08:05:38 -0700
Subject: [PATCH 2/2] Update on "Reordered TP parallel plan to follow execution
 order"

- Llama uses pre-norm (norm before attention and before FFN), so we can move these up.
- The root norm is before output, so we can swap this order too.



[ghstack-poisoned]
---
 torchtitan/parallelisms/parallelize_llama.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index c07d4c33..32fbcc63 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -332,6 +332,7 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
     Apply tensor parallelism.
     """
+
     tp_mesh = world_mesh["tp"]
     (
         row_parallel_strategy,