From 2071dfb57035dec601e7108c6834ea9e8d4d5c96 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Wed, 10 Jul 2024 00:07:13 -0700 Subject: [PATCH] Update on "compiled RMSNorm" On Llama3 8B model, no AC `compiled_rmsnorm` is ~9% faster than `rmsnorm`, but ~2% slower than `fused_rmsnorm`. Please see below for details. rmsnorm image compiled_rmsnorm image fused_rmsnorm image [ghstack-poisoned] --- estimation.py | 6 +++++- test_runner.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/estimation.py b/estimation.py index e82a7b71..ddf24d8a 100644 --- a/estimation.py +++ b/estimation.py @@ -57,8 +57,12 @@ def estimate_memory(job_config: JobConfig): ) job_config.model.norm_type = "rmsnorm" + if job_config.model.norm_type == "compiled_rmsnorm": + logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.") + job_config.model.norm_type = "rmsnorm" + if job_config.training.compile: - logger.info("Compile mode is not supported yet. " "Switching to Eager mode.") + logger.info("Compile mode is not supported yet. Switching to eager mode.") job_config.training.compile = False parallel_dims = ParallelDims( diff --git a/test_runner.py b/test_runner.py index cba63544..319f99d7 100755 --- a/test_runner.py +++ b/test_runner.py @@ -266,7 +266,7 @@ def build_test_list(): OverrideDefinitions( [ [ - "--memory_estimation.enabled", + "--memory_estimation.enabled --model.norm_type rmsnorm", ] ], "FSDP2 Memory Tracking and Estimation",