diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 7becb731b..32fbcc633 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -341,9 +341,10 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     ) = get_tp_parallel_strategy(job_config)
     loss_parallel = parallel_dims.loss_parallel_enabled
 
-    # 1. Parallelize the first embedding and the last linear proj layer
+    # 1. Parallelize the embedding and shard its outputs (which are the first
+    # transformer block's inputs)
     # 2. Parallelize the root norm layer over the sequence dim
-    # 3. Shard the first transformer block's inputs
+    # 3. Parallelize the final linear output layer
     model = parallelize_module(
         model,
         tp_mesh,
@@ -352,12 +353,12 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
                 input_layouts=Replicate(),
                 output_layouts=Shard(1),
             ),
+            "norm": SequenceParallel(),
             "output": col_parallel_strategy(
                 input_layouts=Shard(1),
                 output_layouts=Shard(-1) if loss_parallel else Replicate(),
                 use_local_output=not loss_parallel,
             ),
-            "norm": SequenceParallel(),
         },
     )
 
@@ -367,6 +368,7 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     #       Examples can be found at https://github.com/pytorch/torchtitan/pull/437
     for layer_id, transformer_block in model.layers.items():
         layer_plan = {
+            "attention_norm": SequenceParallel(),
             "attention": prepare_module_input(
                 input_layouts=(Shard(1), None),
                 desired_input_layouts=(Replicate(), None),
@@ -375,7 +377,7 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
             "attention.wk": col_parallel_strategy(),
             "attention.wv": col_parallel_strategy(),
             "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
-            "attention_norm": SequenceParallel(),
+            "ffn_norm": SequenceParallel(),
             "feed_forward": prepare_module_input(
                 input_layouts=(Shard(1),),
                 desired_input_layouts=(Replicate(),),
@@ -383,7 +385,6 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
             "feed_forward.w1": col_parallel_strategy(),
             "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
             "feed_forward.w3": col_parallel_strategy(),
-            "ffn_norm": SequenceParallel(),
         }
 
         # Adjust attention module to use the local number of heads
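For reference (not part of the patch), a minimal self-contained sketch of the per-block `layer_plan` this change produces, with each `SequenceParallel` norm entry placed directly before the attention / feed_forward module it feeds. It substitutes the plain `ColwiseParallel` / `RowwiseParallel` / `PrepareModuleInput` styles for the `col_parallel_strategy` / `row_parallel_strategy` / `prepare_module_input` helpers selected by `get_tp_parallel_strategy`, and assumes the torch.distributed.tensor.parallel API used elsewhere in this file:

    # Sketch only: plain TP styles stand in for the strategy helpers above.
    from torch.distributed._tensor import Replicate, Shard
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        PrepareModuleInput,
        RowwiseParallel,
        SequenceParallel,
    )

    # Sequence-parallel norms emit Shard(1) (sequence-dim sharded) activations;
    # PrepareModuleInput all-gathers them back to Replicate() before the
    # tensor-parallel projections of the module each norm feeds.
    layer_plan = {
        "attention_norm": SequenceParallel(),
        "attention": PrepareModuleInput(
            input_layouts=(Shard(1), None),
            desired_input_layouts=(Replicate(), None),
        ),
        "attention.wq": ColwiseParallel(),
        "attention.wk": ColwiseParallel(),
        "attention.wv": ColwiseParallel(),
        "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
        "ffn_norm": SequenceParallel(),
        "feed_forward": PrepareModuleInput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
        ),
        "feed_forward.w1": ColwiseParallel(),
        "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
        "feed_forward.w3": ColwiseParallel(),
    }

The dict reordering is purely cosmetic (parallelize_module applies the plan by module name, not dict order); grouping each norm with its consumer just makes the norm -> all-gather -> projection pattern easier to read.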