Update on "compiled RMSNorm"
On the Llama3 8B model with no activation checkpointing (AC):
`compiled_rmsnorm` is ~9% faster than `rmsnorm`, but ~2% slower than `fused_rmsnorm`.
Please see below for details.

rmsnorm
<img width="757" alt="image" src="https://github.com/pytorch/torchtitan/assets/150487191/79645518-e38b-4ddb-b01d-b0c93ec27dd4">

compiled_rmsnorm
<img width="754" alt="image" src="https://github.com/pytorch/torchtitan/assets/150487191/c457b388-793f-452b-9bce-17bc1823df66">

fused_rmsnorm
<img width="753" alt="image" src="https://github.com/pytorch/torchtitan/assets/150487191/ea1db7ad-5887-4efa-9788-e708e4b40428">



[ghstack-poisoned]
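As the names suggest, the three `norm_type` values compared above are different ways of executing the same normalization. For reference, here is a minimal pure-Python sketch of what RMSNorm computes for a single vector. It is not torchtitan's implementation (which operates on PyTorch tensors); the function name `rms_norm_reference` and the `eps` default are assumptions made for illustration.

```python
import math


def rms_norm_reference(x, weight, eps=1e-6):
    """Hypothetical reference: scale x by the reciprocal of its root mean square.

    x and weight are equal-length lists of floats; eps guards against
    division by zero (its default here is an assumption, not a torchtitan value).
    """
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    # Normalize each element, then apply the learned per-element gain.
    return [v * inv_rms * w for v, w in zip(x, weight)]


if __name__ == "__main__":
    out = rms_norm_reference([3.0, 4.0], [1.0, 1.0])
    # With unit weights the output has a root mean square close to 1.
    print(math.sqrt(sum(v * v for v in out) / len(out)))
```

An eager, compiled, or fused kernel would be expected to agree with this arithmetic up to floating-point tolerance; the timings above differ only in how the work is scheduled on the device.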
tianyu-l committed Jul 10, 2024
1 parent 4c33e52 commit 2071dfb
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 5 additions & 1 deletion estimation.py
@@ -57,8 +57,12 @@ def estimate_memory(job_config: JobConfig):
         )
         job_config.model.norm_type = "rmsnorm"
 
+    if job_config.model.norm_type == "compiled_rmsnorm":
+        logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.")
+        job_config.model.norm_type = "rmsnorm"
+
     if job_config.training.compile:
-        logger.info("Compile mode is not supported yet. " "Switching to Eager mode.")
+        logger.info("Compile mode is not supported yet. Switching to eager mode.")
         job_config.training.compile = False
 
     parallel_dims = ParallelDims(
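The fallback added in this hunk can be exercised in isolation. The sketch below is a self-contained approximation: the dataclasses are hypothetical stand-ins for torchtitan's `JobConfig`, and `apply_estimation_fallbacks` is an illustrative helper name, not a function in the repository. It reproduces only the two checks visible in the diff.

```python
import logging
from dataclasses import dataclass, field

logger = logging.getLogger("estimation_sketch")


@dataclass
class ModelConfig:  # hypothetical stand-in for job_config.model
    norm_type: str = "rmsnorm"


@dataclass
class TrainingConfig:  # hypothetical stand-in for job_config.training
    compile: bool = False


@dataclass
class JobConfig:  # hypothetical stand-in, not torchtitan's JobConfig
    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)


def apply_estimation_fallbacks(job_config: JobConfig) -> JobConfig:
    """Downgrade settings the memory estimator does not support (illustrative)."""
    if job_config.model.norm_type == "compiled_rmsnorm":
        logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.")
        job_config.model.norm_type = "rmsnorm"

    if job_config.training.compile:
        logger.info("Compile mode is not supported yet. Switching to eager mode.")
        job_config.training.compile = False

    return job_config


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    cfg = JobConfig(ModelConfig("compiled_rmsnorm"), TrainingConfig(compile=True))
    print(apply_estimation_fallbacks(cfg))
```

Mutating the config in place and logging the downgrade mirrors the pattern in the diff: estimation proceeds with a supported configuration rather than failing outright.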
2 changes: 1 addition & 1 deletion test_runner.py
@@ -266,7 +266,7 @@ def build_test_list():
         OverrideDefinitions(
             [
                 [
-                    "--memory_estimation.enabled",
+                    "--memory_estimation.enabled --model.norm_type rmsnorm",
                 ]
             ],
             "FSDP2 Memory Tracking and Estimation",
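To illustrate what the extra `--model.norm_type rmsnorm` token does, here is a small sketch of how such an override string might be split and parsed. The `argparse` parser is a hypothetical stand-in, not torchtitan's config machinery, and the `compiled_rmsnorm` default is an assumption chosen only to make the difference visible.

```python
import argparse
import shlex


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser exposing only the two options used in the test."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--memory_estimation.enabled", action="store_true")
    # The default below is an assumption for illustration only.
    parser.add_argument("--model.norm_type", default="compiled_rmsnorm")
    return parser


def parse_override(override: str) -> dict:
    """Split one override string into tokens and parse it into a dict."""
    tokens = shlex.split(override)
    return vars(build_parser().parse_args(tokens))


if __name__ == "__main__":
    before = parse_override("--memory_estimation.enabled")
    after = parse_override("--memory_estimation.enabled --model.norm_type rmsnorm")
    print(before["model.norm_type"], "->", after["model.norm_type"])
```

Passing the norm type explicitly means the test exercises memory estimation with plain `rmsnorm` and does not depend on whichever norm the test's base configuration selects.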