Skip to content

Commit

Permalink
Update Chakra replay to use tdef
Browse files Browse the repository at this point in the history
  • Loading branch information
TaekyungHeo committed Nov 20, 2024
1 parent 5c3fd22 commit 23cad7b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List
from typing import Any, Dict, List, cast

from cloudai import TestRun
from cloudai.systems.slurm.strategy import SlurmCommandGenStrategy
from cloudai.test_definitions.chakra_replay import ChakraReplayTestDefinition


class ChakraReplaySlurmCommandGenStrategy(SlurmCommandGenStrategy):
Expand All @@ -28,10 +29,13 @@ def _parse_slurm_args(
) -> Dict[str, Any]:
base_args = super()._parse_slurm_args(job_name_prefix, env_vars, cmd_args, tr)

image_path = cmd_args["docker_image_url"]
container_mounts = f"{cmd_args['trace_path']}:{cmd_args['trace_path']}"

base_args.update({"image_path": image_path, "container_mounts": container_mounts})
tdef: ChakraReplayTestDefinition = cast(ChakraReplayTestDefinition, tr.test.test_definition)
base_args.update(
{
"image_path": tdef.docker_image.installed_path,
"container_mounts": f"{tdef.cmd_args.trace_path}:{tdef.cmd_args.trace_path}",
}
)

return base_args

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def cmd_gen_strategy(self, slurm_system: SlurmSystem) -> ChakraReplaySlurmComman
return ChakraReplaySlurmCommandGenStrategy(slurm_system, {})

@pytest.mark.parametrize(
"job_name_prefix, env_vars, cmd_args, num_nodes, nodes, expected_result",
"job_name_prefix, env_vars, cmd_args_attrs, num_nodes, nodes, expected_result",
[
(
"chakra_replay",
Expand Down Expand Up @@ -62,30 +62,24 @@ def test_parse_slurm_args(
cmd_gen_strategy: ChakraReplaySlurmCommandGenStrategy,
job_name_prefix: str,
env_vars: Dict[str, str],
cmd_args: Dict[str, str],
cmd_args_attrs: Dict[str, Any],
num_nodes: int,
nodes: List[str],
expected_result: Dict[str, Any],
) -> None:
tr = Mock(spec=TestRun)
tr.num_nodes = num_nodes
tr.nodes = nodes
slurm_args = cmd_gen_strategy._parse_slurm_args(job_name_prefix, env_vars, cmd_args, tr)
assert slurm_args["image_path"] == expected_result["image_path"]
assert slurm_args["container_mounts"] == expected_result["container_mounts"]
mock_cmd_args = Mock(**cmd_args_attrs)
mock_docker_image = Mock(installed_path=cmd_args_attrs["docker_image_url"])
mock_test_definition = Mock(cmd_args=mock_cmd_args, docker_image=mock_docker_image)

def test_parse_slurm_args_invalid_cmd_args(self, cmd_gen_strategy: ChakraReplaySlurmCommandGenStrategy) -> None:
job_name_prefix = "chakra_replay"
env_vars = {"NCCL_DEBUG": "INFO"}
cmd_args = {"trace_path": "/workspace/traces/"} # Missing "docker_image_url"
tr = Mock(spec=TestRun)
tr.num_nodes = 2
tr.nodes = ["node1", "node2"]
mock_test_run = Mock(
num_nodes=num_nodes,
nodes=nodes,
test=Mock(test_definition=mock_test_definition),
)

with pytest.raises(KeyError) as exc_info:
cmd_gen_strategy._parse_slurm_args(job_name_prefix, env_vars, cmd_args, tr)

assert str(exc_info.value) == "'docker_image_url'", "Expected missing docker_image_url key"
slurm_args = cmd_gen_strategy._parse_slurm_args(job_name_prefix, env_vars, {}, mock_test_run)
assert slurm_args["image_path"] == expected_result["image_path"]
assert slurm_args["container_mounts"] == expected_result["container_mounts"]

@pytest.mark.parametrize(
"cmd_args, extra_cmd_args, expected_result",
Expand Down

0 comments on commit 23cad7b

Please sign in to comment.