From 2c7203fdc9a8b20cc965f34889410ed286647c51 Mon Sep 17 00:00:00 2001 From: Joris Henstra Date: Sun, 11 Feb 2024 16:54:08 +0100 Subject: [PATCH 1/2] Fix incorrect mapping update for routing env from grid size. Relabel the nodes to be indexed when creating connection graph from grid size. --- qgym/utils/input_parsing.py | 8 +++++-- qgym/utils/input_validation.py | 7 +++++-- tests/envs/routing/test_routing_env.py | 13 +++++++++--- tests/utils/test_input_validation.py | 29 +++++++++++++++++--------- 4 files changed, 40 insertions(+), 17 deletions(-) diff --git a/qgym/utils/input_parsing.py b/qgym/utils/input_parsing.py index 4434b48b..39c7493e 100644 --- a/qgym/utils/input_parsing.py +++ b/qgym/utils/input_parsing.py @@ -1,7 +1,7 @@ """This module contains function which parse user input. With parsing we mean that the user input is validated and transformed to a predictable -format. In this way, user can give different input formats, but internally we are +format. In this way, user can give different input formats, but internally we are assured that the data has the same format.""" from __future__ import annotations @@ -105,7 +105,11 @@ def parse_connection_graph( if grid_size is not None: # Generate connection grid graph - return nx.grid_graph(grid_size) + graph = nx.grid_graph(grid_size) + + # Relabel the nodes to be indexed + mapping = {node: i for i, node in enumerate(graph.nodes())} + return nx.relabel_nodes(graph, mapping) raise ValueError("No valid arguments for a connection graph were given") diff --git a/qgym/utils/input_validation.py b/qgym/utils/input_validation.py index b753fd9f..3d6229fb 100644 --- a/qgym/utils/input_validation.py +++ b/qgym/utils/input_validation.py @@ -206,8 +206,8 @@ def check_adjacency_matrix(adjacency_matrix: ArrayLike) -> NDArray[Any]: def check_graph_is_valid_topology(graph: nx.Graph, name: str) -> None: - """Check if `graph` with name 'name' is an instance of ``networkx.Graph`` and check - if the graph is valid topology graph. + """Check if `graph` with name 'name' is an instance of ``networkx.Graph``, check + if the graph is valid topology graph and check if the nodes are integers. Args: graph: Graph to check. @@ -226,6 +226,9 @@ def check_graph_is_valid_topology(graph: nx.Graph, name: str) -> None: if len(graph) == 0: raise ValueError(f"'{name}' has no nodes") + if not all(isinstance(node, int) for node in graph.nodes()): + raise TypeError(f"'{name}' has nodes that are not integers") + def check_instance(x: Any, name: str, dtype: type) -> None: """Check if `x` with name 'name' is an instance of dtype. diff --git a/tests/envs/routing/test_routing_env.py b/tests/envs/routing/test_routing_env.py index 20f8704c..c27e8866 100644 --- a/tests/envs/routing/test_routing_env.py +++ b/tests/envs/routing/test_routing_env.py @@ -1,5 +1,6 @@ from __future__ import annotations +import numpy as np import pytest from stable_baselines3.common.env_checker import check_env @@ -19,6 +20,12 @@ }, ], ) -def test_validity(kwargs: dict[str, tuple[int, int] | bool]) -> None: - env = Routing(**kwargs) # type: ignore[arg-type] - check_env(env, warn=True) # todo: maybe switch this to the gym env checker +class TestEnvironment: + def test_validity(self, kwargs: dict[str, tuple[int, int] | bool]) -> None: + env = Routing(**kwargs) # type: ignore[arg-type] + check_env(env, warn=True) # todo: maybe switch this to the gym env checker + + def test_step(self, kwargs): + env = Routing(**kwargs) # type: ignore[arg-type] + obs = env.step(0)[0] + assert np.array_equal(obs["mapping"], [2, 1, 0, 3]) diff --git a/tests/utils/test_input_validation.py b/tests/utils/test_input_validation.py index b09c2547..d5dec165 100644 --- a/tests/utils/test_input_validation.py +++ b/tests/utils/test_input_validation.py @@ -153,19 +153,28 @@ def test_check_adjacency_matrix_errors(self, arg: Any) -> None: check_adjacency_matrix(arg) -def test_check_graph_is_valid_topology() -> None: - graph = nx.Graph() - msg = "'test' has no nodes" - with pytest.raises(ValueError, match=msg): +class TestGraphValidTopology: + def test_check_graph_is_valid_topology(self) -> None: + graph = nx.Graph() + msg = "'test' has no nodes" + with pytest.raises(ValueError, match=msg): + check_graph_is_valid_topology(graph, "test") + + graph.add_edge(1, 2) check_graph_is_valid_topology(graph, "test") - graph.add_edge(1, 2) - check_graph_is_valid_topology(graph, "test") + graph.add_edge(1, 1) + msg = "'test' contains self-loops" + with pytest.raises(ValueError, match=msg): + check_graph_is_valid_topology(graph, "test") + + def test_check_graph_is_valid_topology_nodes(self) -> None: + graph = nx.Graph() + graph.add_node((0, 0)) - graph.add_edge(1, 1) - msg = "'test' contains self-loops" - with pytest.raises(ValueError, match=msg): - check_graph_is_valid_topology(graph, "test") + msg = "'test' has nodes that are not integers" + with pytest.raises(TypeError, match=msg): + check_graph_is_valid_topology(graph, "test") class TestCheckInstance: From 38305a745fe6676a091194401adb835a944a3592 Mon Sep 17 00:00:00 2001 From: jhenstra <74776892+jhenstra@users.noreply.github.com> Date: Wed, 14 Feb 2024 17:39:02 +0100 Subject: [PATCH 2/2] Use networkx builtin function to convert node labels to integers. --- qgym/utils/input_parsing.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/qgym/utils/input_parsing.py b/qgym/utils/input_parsing.py index 39c7493e..17823cd2 100644 --- a/qgym/utils/input_parsing.py +++ b/qgym/utils/input_parsing.py @@ -107,9 +107,9 @@ def parse_connection_graph( # Generate connection grid graph graph = nx.grid_graph(grid_size) - # Relabel the nodes to be indexed - mapping = {node: i for i, node in enumerate(graph.nodes())} - return nx.relabel_nodes(graph, mapping) + # Relabel the nodes to be integers + graph = nx.convert_node_labels_to_integers(graph) + return graph raise ValueError("No valid arguments for a connection graph were given")