From 23cad7be37f83e266059937535146b5eaac6d739 Mon Sep 17 00:00:00 2001 From: Taekyung Heo Date: Wed, 20 Nov 2024 13:00:23 -0500 Subject: [PATCH] Update Chakra replay to use tdef --- .../slurm_command_gen_strategy.py | 14 +++++--- ...hakra_replay_slurm_command_gen_strategy.py | 32 ++++++++----------- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/src/cloudai/schema/test_template/chakra_replay/slurm_command_gen_strategy.py b/src/cloudai/schema/test_template/chakra_replay/slurm_command_gen_strategy.py index aba0430d..e6534343 100644 --- a/src/cloudai/schema/test_template/chakra_replay/slurm_command_gen_strategy.py +++ b/src/cloudai/schema/test_template/chakra_replay/slurm_command_gen_strategy.py @@ -14,10 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List +from typing import Any, Dict, List, cast from cloudai import TestRun from cloudai.systems.slurm.strategy import SlurmCommandGenStrategy +from cloudai.test_definitions.chakra_replay import ChakraReplayTestDefinition class ChakraReplaySlurmCommandGenStrategy(SlurmCommandGenStrategy): @@ -28,10 +29,13 @@ def _parse_slurm_args( ) -> Dict[str, Any]: base_args = super()._parse_slurm_args(job_name_prefix, env_vars, cmd_args, tr) - image_path = cmd_args["docker_image_url"] - container_mounts = f"{cmd_args['trace_path']}:{cmd_args['trace_path']}" - - base_args.update({"image_path": image_path, "container_mounts": container_mounts}) + tdef: ChakraReplayTestDefinition = cast(ChakraReplayTestDefinition, tr.test.test_definition) + base_args.update( + { + "image_path": tdef.docker_image.installed_path, + "container_mounts": f"{tdef.cmd_args.trace_path}:{tdef.cmd_args.trace_path}", + } + ) return base_args diff --git a/tests/slurm_command_gen_strategy/test_chakra_replay_slurm_command_gen_strategy.py b/tests/slurm_command_gen_strategy/test_chakra_replay_slurm_command_gen_strategy.py index bade5762..15dd1ddb 100644 --- a/tests/slurm_command_gen_strategy/test_chakra_replay_slurm_command_gen_strategy.py +++ b/tests/slurm_command_gen_strategy/test_chakra_replay_slurm_command_gen_strategy.py @@ -31,7 +31,7 @@ def cmd_gen_strategy(self, slurm_system: SlurmSystem) -> ChakraReplaySlurmComman return ChakraReplaySlurmCommandGenStrategy(slurm_system, {}) @pytest.mark.parametrize( - "job_name_prefix, env_vars, cmd_args, num_nodes, nodes, expected_result", + "job_name_prefix, env_vars, cmd_args_attrs, num_nodes, nodes, expected_result", [ ( "chakra_replay", @@ -62,30 +62,24 @@ def test_parse_slurm_args( cmd_gen_strategy: ChakraReplaySlurmCommandGenStrategy, job_name_prefix: str, env_vars: Dict[str, str], - cmd_args: Dict[str, str], + cmd_args_attrs: Dict[str, Any], num_nodes: int, nodes: List[str], expected_result: Dict[str, Any], ) -> None: - tr = Mock(spec=TestRun) - tr.num_nodes = num_nodes - tr.nodes = nodes - slurm_args = cmd_gen_strategy._parse_slurm_args(job_name_prefix, env_vars, cmd_args, tr) - assert slurm_args["image_path"] == expected_result["image_path"] - assert slurm_args["container_mounts"] == expected_result["container_mounts"] + mock_cmd_args = Mock(**cmd_args_attrs) + mock_docker_image = Mock(installed_path=cmd_args_attrs["docker_image_url"]) + mock_test_definition = Mock(cmd_args=mock_cmd_args, docker_image=mock_docker_image) - def test_parse_slurm_args_invalid_cmd_args(self, cmd_gen_strategy: ChakraReplaySlurmCommandGenStrategy) -> None: - job_name_prefix = "chakra_replay" - env_vars = {"NCCL_DEBUG": "INFO"} - cmd_args = {"trace_path": "/workspace/traces/"} # Missing "docker_image_url" - tr = Mock(spec=TestRun) - tr.num_nodes = 2 - tr.nodes = ["node1", "node2"] + mock_test_run = Mock( + num_nodes=num_nodes, + nodes=nodes, + test=Mock(test_definition=mock_test_definition), + ) - with pytest.raises(KeyError) as exc_info: - cmd_gen_strategy._parse_slurm_args(job_name_prefix, env_vars, cmd_args, tr) - - assert str(exc_info.value) == "'docker_image_url'", "Expected missing docker_image_url key" + slurm_args = cmd_gen_strategy._parse_slurm_args(job_name_prefix, env_vars, {}, mock_test_run) + assert slurm_args["image_path"] == expected_result["image_path"] + assert slurm_args["container_mounts"] == expected_result["container_mounts"] @pytest.mark.parametrize( "cmd_args, extra_cmd_args, expected_result",