From 70bbf833688e19e5ec144b06689f46fa9ec1c43f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 28 Aug 2024 16:47:34 -0700 Subject: [PATCH 1/2] remove reset --- tests/tpu/test_compilation.py | 15 +++++++-------- vllm/worker/model_runner.py | 4 ---- vllm/worker/tpu_worker.py | 4 ---- 3 files changed, 7 insertions(+), 16 deletions(-) diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py index 5a432fb78b3da..69c2306f44ebf 100644 --- a/tests/tpu/test_compilation.py +++ b/tests/tpu/test_compilation.py @@ -5,6 +5,10 @@ import depyf +# disable custom dispatcher, let Dynamo take over +# all the control +os.environ['VLLM_DYNAMO_USE_CUSTOM_DISPATCHER'] = "0" + temp_dir = tempfile.mkdtemp() with depyf.prepare_debug(temp_dir): cur_dir = os.path.dirname(__file__) @@ -16,19 +20,14 @@ compiled_code = sorted( glob.glob(os.path.join(temp_dir, "__transformed_code*.py"))) -full_code = glob.glob(os.path.join(temp_dir, "full_code*.py"))[0] + # we should only trigger Dynamo compilation three times: -# one for the profiling phase (and the compiled artifact will be discarded) +# one for the profiling phase without kv cache # one for the prefill phase with symbolic shapes # one for the decode phase with symbolic shapes # and later calls should not trigger Dynamo compilation again. # NOTE: it might still trigger XLA compilation. 
# check we have three compiled code +# this is the assumption when we use the custom dispatcher assert len(compiled_code) == 3 - -# check the first compilation is discarded -with open(full_code) as f: - full_code_content = f.read() - profile_function = compiled_code[0].split(".")[0] - assert profile_function not in full_code_content diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2b287a5d27157..de1a2e3235a8c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1123,10 +1123,6 @@ def profile_run(self) -> None: device=self.device) self.execute_model(model_input, kv_caches, intermediate_tensors) torch.cuda.synchronize() - - # reset and discard the guard and compiled bytecode for profiling runs - torch._dynamo.reset() - return def remove_all_loras(self): diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 320b15d3604bc..44fa3aed5816d 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -143,10 +143,6 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: num_cpu_blocks = int(self.cache_config.swap_space_bytes // block_size_bytes) num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8. 
- - # reset and discard the guard and compiled bytecode for profiling runs - torch._dynamo.reset() - return num_tpu_blocks, num_cpu_blocks def initialize_cache( From 15b083a2de4b55613ae4253ffcd4f51fbdb0e6b1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 28 Aug 2024 17:13:30 -0700 Subject: [PATCH 2/2] add more tests --- tests/tpu/test_compilation.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py index 69c2306f44ebf..d8df86b2aaa14 100644 --- a/tests/tpu/test_compilation.py +++ b/tests/tpu/test_compilation.py @@ -31,3 +31,25 @@ # check we have three compiled code # this is the assumption when we use the custom dispatcher assert len(compiled_code) == 3 + +# check all the compilations are as expected +compiled_fn = sorted( + glob.glob(os.path.join(temp_dir, "__compiled_fn*Captured*.py"))) + +# the first compilation is the profiling phase, +# it should not have any kv cache +with open(compiled_fn[0]) as f: + content = f.read() + assert "kv_caches" not in content + +# the second compilation is the prefill phase, +# it should have kv cache and the flash_attention op +with open(compiled_fn[1]) as f: + content = f.read() + assert "kv_caches" in content and "torch.ops.xla.flash_attention" in content + +# the third compilation is the decode phase, +# it should have kv cache and the paged_attention op +with open(compiled_fn[2]) as f: + content = f.read() + assert "kv_caches" in content and "torch.ops.xla.paged_attention" in content