diff --git a/flatland/envs/agent_chains.py b/flatland/envs/agent_chains.py index 745f0113..ce5deb55 100644 --- a/flatland/envs/agent_chains.py +++ b/flatland/envs/agent_chains.py @@ -1,8 +1,37 @@ +""" +Agent Chains: Unordered Close Following Agents + +Think of a chain of agents, in random order, moving in the same direction. +For any adjacent pair of agents, there's a 0.5 chance that it is in index order, ie index(A) < index(B) where A is in front of B. +So roughly half the adjacent pairs will need to leave a gap and half won't, and the chain of agents will typically be one-third empty space. +By removing the restriction, we can keep the agents close together and +so move up to 50% more agents through a junction or segment of rail in the same number of steps. + +We are still using index order to resolve conflicts between two agents trying to move into the same spot, for example, head-on collisions, or agents "merging" at junctions. + +Implementation: We did it by storing an agent's position as a graph node, and a movement as a directed edge, using the NetworkX graph library. +We create an empty graph for each step, and add the agents into the graph in order, +using their (row, column) location for the node. In this way, agents staying in the same cell (stop action or not at cell exit yet) get a self-loop. +Agents in an adjacent chain naturally get "connected up". + +Pseudocode: +* purple = deadlocked if in deadlock or predecessor of deadlocked (`mark_preds(find_swaps(), 'purple')`) +* red = blocked, i.e. wanting to move, but blocked by an agent ahead not wanting to move or blocked itself (`mark_preds(find_stopped_agents(), 'red')`) +* blue = no agent and >1 wanting to enter, blocking after conflict resolution (`mark_preds(losers, 'red')`) +* magenta: agent (able to move) and >1 wanting to enter, blocking after conflict resolution (`mark_preds(losers, 'red')`) + + +We then use some NetworkX algorithms (https://github.com/networkx/networkx): + * `weakly_connected_components` to find the chains. + * `selfloop_edges` to find the stopped agents + * `dfs_postorder_nodes` to traverse a chain +""" +from typing import Tuple, Set, Dict + import networkx as nx -import numpy as np -from typing import List, Tuple, Set, Union -import graphviz as gv +AgentHandle = int +Cell = Tuple[int, int] class MotionCheck(object): @@ -12,23 +41,16 @@ class MotionCheck(object): """ def __init__(self): - self.G = nx.DiGraph() + self.G = nx.DiGraph() # nodes of type `Cell` + # TODO do we need the reversed graph at all? self.Grev = nx.DiGraph() # reversed graph for finding predecessors self.nDeadlocks = 0 self.svDeadlocked = set() - self._G_reversed: Union[nx.DiGraph, None] = None - def get_G_reversed(self): - #if self._G_reversed is None: - # self._G_reversed = self.G.reverse() - #return self._G_reversed return self.Grev - def reset_G_reversed(self): - self._G_reversed = None - - def addAgent(self, iAg, rc1, rc2, xlabel=None): + def addAgent(self, iAg: AgentHandle, rc1: Cell, rc2: Cell, xlabel=None): """ add an agent and its motion as row,col tuples of current and next position. The agent's current position is given an "agent" attribute recording the agent index. If an agent does not want to move this round (rc1 == rc2) then a self-loop edge is created. @@ -36,7 +58,7 @@ def addAgent(self, iAg, rc1, rc2, xlabel=None): """ # Agents which have not yet entered the env have position None. - # Substitute this for the row = -1, column = agent index + # Substitute this for the row = -1, column = agent index, i.e. 
they are isolated nodes in the graph! if rc1 is None: rc1 = (-1, iAg) @@ -49,66 +71,42 @@ def addAgent(self, iAg, rc1, rc2, xlabel=None): self.G.add_edge(rc1, rc2) self.Grev.add_edge(rc2, rc1) - def find_stops(self): - """ find all the stopped agents as a set of rc position nodes - A stopped agent is a self-loop on a cell node. + def find_stopped_agents(self) -> Set[Cell]: """ - - # get the (sparse) adjacency matrix - spAdj = nx.linalg.adjacency_matrix(self.G) - - # the stopped agents appear as 1s on the diagonal - # the where turns this into a list of indices of the 1s - giStops = np.where(spAdj.diagonal())[0] - - # convert the cell/node indices into the node rc values - lvAll = list(self.G.nodes()) - # pick out the stops by their indices - lvStops = [lvAll[i] for i in giStops] - # make it into a set ready for a set intersection - svStops = set(lvStops) - return svStops - - def find_stops2(self): - """ alternative method to find stopped agents, using a networkx call to find selfloop edges + Find stopped agents, using a networkx call to find self-loop nodes. + :return: set of stopped agents """ svStops = {u for u, v in nx.classes.function.selfloop_edges(self.G)} return svStops - def find_stop_preds(self, svStops=None): - """ Find the predecessors to a list of stopped agents (ie the nodes / vertices) - Returns the set of predecessors. - Includes "chained" predecessors. + def find_stop_preds(self, svStops: Set[Cell]) -> Set[Cell]: + """ Find the predecessors to a list of stopped agents (ie the nodes / vertices). Includes "chained" predecessors. + :param svStops: list of voluntarily stopped agents + :return: the set of predecessors. """ - if svStops is None: - svStops = self.find_stops2() - # Get all the chains of agents - weakly connected components. # Weakly connected because it's a directed graph and you can traverse a chain of agents # in only one direction + # TODO why do we need weakly connected components at all? Just use reverse traversal of directed edges? lWCC = list(nx.algorithms.components.weakly_connected_components(self.G)) svBlocked = set() - reversed_G = None + reversed_G = self.get_G_reversed() for oWCC in lWCC: if (len(oWCC) == 1): continue - # print("Component:", len(oWCC), oWCC) # Get the node details for this WCC in a subgraph - Gwcc = self.G.subgraph(oWCC) + Gwcc: Set[Cell] = self.G.subgraph(oWCC) # Find all the stops in this chain or tree - svCompStops = svStops.intersection(Gwcc) - # print(svCompStops) + svCompStops: Set[Cell] = svStops.intersection(Gwcc) if len(svCompStops) > 0: - if reversed_G is None: - reversed_G = self.get_G_reversed() # We need to traverse it in reverse - back up the movement edges - Gwcc_rev = reversed_G.subgraph(oWCC) # Gwcc.reverse() + Gwcc_rev = reversed_G.subgraph(oWCC) for vStop in svCompStops: # Find all the agents stopped by vStop by following the (reversed) edges # This traverses a tree - dfs = depth first seearch @@ -119,49 +117,33 @@ def find_stop_preds(self, svStops=None): # the set of all the nodes/agents blocked by this set of stopped nodes return svBlocked - def find_swaps(self): - """ find all the swap conflicts where two agents are trying to exchange places. - These appear as simple cycles of length 2. - These agents are necessarily deadlocked (since they can't change direction in flatland) - - meaning they will now be stuck for the rest of the episode. + def find_swaps(self) -> Set[Cell]: + """ + Find loops of size 2 in the graph, i.e. swaps leading to head-on collisions. + :return: set of all cells in swaps. 
""" - # svStops = self.find_stops2() - llvLoops = list(nx.algorithms.cycles.simple_cycles(self.G)) - llvSwaps = [lvLoop for lvLoop in llvLoops if len(lvLoop) == 2] - svSwaps = {v for lvSwap in llvSwaps for v in lvSwap} - return svSwaps - - def find_swaps2(self) -> Set[Tuple[int, int]]: svSwaps = set() sEdges = self.G.edges() for u, v in sEdges: if u == v: - # print("self loop", u, v) pass else: if (v, u) in sEdges: - # print("swap", uv) svSwaps.update([u, v]) return svSwaps - def find_same_dest(self): - """ find groups of agents which are trying to land on the same cell. - ie there is a gap of one cell between them and they are both landing on it. - """ - pass - - def block_preds(self, svStops, color="red"): + def mark_preds(self, svStops: Set[Cell], color: object = "red") -> Set[Cell]: """ Take a list of stopped agents, and apply a stop color to any chains/trees of agents trying to head toward those cells. - Count the number of agents blocked, ignoring those which are already marked. - (Otherwise it can double count swaps) + :param svStops: list of stopped agents + :param color: color to apply to predecessor of stopped agents + :return: all predecessors of any stopped agent """ - iCount = 0 - svBlocked = set() + predecessors = set() if len(svStops) == 0: - return svBlocked + return predecessors # The reversed graph allows us to follow directed edges to find affected agents. Grev = self.get_G_reversed() @@ -169,49 +151,44 @@ def block_preds(self, svStops, color="red"): # Use depth-first-search to find a tree of agents heading toward the blocked cell. lvPred = list(nx.traversal.dfs_postorder_nodes(Grev, source=v)) - svBlocked |= set(lvPred) - svBlocked.add(v) - # print("node:", v, "set", svBlocked) - # only count those not already marked + predecessors |= set(lvPred) + predecessors.add(v) + + # only color those not already marked (not updating previous colors) for v2 in [v] + lvPred: if self.G.nodes[v2].get("color") != color: self.G.nodes[v2]["color"] = color - iCount += 1 - - return svBlocked + return predecessors def find_conflicts(self): - self.reset_G_reversed() + """Called in env.step() before the agents execute their actions.""" - svStops = self.find_stops2() # voluntarily stopped agents - have self-loops - # svSwaps = self.find_swaps() # deadlocks - adjacent head-on collisions - svSwaps = self.find_swaps2() # faster version of find_swaps + svStops: Set[Cell] = self.find_stopped_agents() # voluntarily stopped agents - have self-loops ("same cell to same cell") + svSwaps: Set[Cell] = self.find_swaps() # deadlocks - adjacent head-on collisions - # Block all swaps and their tree of predessors - self.svDeadlocked = self.block_preds(svSwaps, color="purple") + # Mark all swaps and their tree of predecessors with purple - these are directly deadlocked + self.svDeadlocked: Set[Cell] = self.mark_preds(svSwaps, color="purple") - # Take the union of the above, and find all the predecessors - # svBlocked = self.find_stop_preds(svStops.union(svSwaps)) - - # Just look for the tree of preds for each voluntarily stopped agent - svBlocked = self.find_stop_preds(svStops) + # Just look for the tree of preds for each voluntarily stopped agent (i.e. not wanting to move) + # TODO why not re-use mark_preds(swStops, color="red")? + # TODO refactoring suggestion: only one "blocked" red = 1. all deadlocked and their predecessor, 2. all predecessors of self-loopers, 3. 
+ svBlocked: Set[Cell] = self.find_stop_preds(svStops) # iterate the nodes v with their predecessors dPred (dict of nodes->{}) for (v, dPred) in self.G.pred.items(): - # mark any swaps with purple - these are directly deadlocked - # if v in svSwaps: - # self.G.nodes[v]["color"] = "purple" - # If they are not directly deadlocked, but are in the union of stopped + deadlocked - # elif v in svBlocked: - # if in blocked, it will not also be in a swap pred tree, so no need to worry about overwriting + dPred: Set[Cell] = dPred + + # if in blocked, it will not also be in a swap pred tree, so no need to worry about overwriting (outdegree always <= 1!) + # TODO why not mark outside of the loop? The loop would then only need to go over nodes with indegree >2 not marked purple or red yet if v in svBlocked: self.G.nodes[v]["color"] = "red" + # not blocked but has two or more predecessors, ie >=2 agents waiting to enter this node elif len(dPred) > 1: + # if this agent is already red or purple, all its predecessors are in svDeadlocked or svBlocked and will eventually be marked red or purple - # if this agent is already red/blocked, ignore. CHECK: why? - # certainly we want to ignore purple so we don't overwrite with red. + # no conflict resolution if deadlocked or blocked if self.G.nodes[v].get("color") in ("red", "purple"): continue @@ -223,29 +200,28 @@ def find_conflicts(self): self.G.nodes[v]["color"] = "magenta" # predecessors of a contended cell: {agent index -> node} - diAgCell = {self.G.nodes[vPred].get("agent"): vPred for vPred in dPred} + diAgCell: Dict[AgentHandle, Cell] = {self.G.nodes[vPred].get("agent"): vPred for vPred in dPred} # remove the agent with the lowest index, who wins iAgWinner = min(diAgCell) diAgCell.pop(iAgWinner) - # Block all the remaining predessors, and their tree of preds - # for iAg, v in diAgCell.items(): - # self.G.nodes[v]["color"] = "red" - # for vPred in nx.traversal.dfs_postorder_nodes(self.G.reverse(), source=v): - # self.G.nodes[vPred]["color"] = "red" - self.block_preds(diAgCell.values(), "red") + self.mark_preds(set(diAgCell.values()), "red") - def check_motion(self, iAgent, rcPos): + def check_motion(self, iAgent: AgentHandle, rcPos: Cell) -> bool: """ Returns tuple of boolean can the agent move, and the cell it will move into. - If agent position is None, we use a dummy position of (-1, iAgent) + If agent position is None, we use a dummy position of (-1, iAgent). + Called in env.step() after conflicts are collected in find_conflicts() - each agent now can execute their position update independently (valid_movement) by calling check_motion. 
+ :param iAgent: agent handle + :param rcPos: cell + :return: true iff the agent wants to move and it has no conflict """ if rcPos is None: + # no successor rcPos = (-1, iAgent) dAttr = self.G.nodes.get(rcPos) - # print("pos:", rcPos, "dAttr:", dAttr) if dAttr is None: dAttr = {} @@ -269,207 +245,3 @@ def check_motion(self, iAgent, rcPos): return False # The agent wanted to move, and it can return True - - -def render(omc: MotionCheck, horizontal=True): - try: - oAG = nx.drawing.nx_agraph.to_agraph(omc.G) - oAG.layout("dot") - sDot = oAG.to_string() - if horizontal: - sDot = sDot.replace('{', '{ rankdir="LR" ') - # return oAG.draw(format="png") - # This returns a graphviz object which implements __repr_svg - return gv.Source(sDot) - except ImportError as oError: - print("Flatland agent_chains ignoring ImportError - install pygraphviz to render graphs") - return None - - -class ChainTestEnv(object): - """ Just for testing agent chains - """ - - def __init__(self, omc: MotionCheck): - self.iAgNext = 0 - self.iRowNext = 1 - self.omc = omc - - def addAgent(self, rc1, rc2, xlabel=None): - self.omc.addAgent(self.iAgNext, rc1, rc2, xlabel=xlabel) - self.iAgNext += 1 - - def addAgentToRow(self, c1, c2, xlabel=None): - self.addAgent((self.iRowNext, c1), (self.iRowNext, c2), xlabel=xlabel) - - def create_test_chain(self, - nAgents: int, - rcVel: Tuple[int] = (0, 1), - liStopped: List[int] = [], - xlabel=None): - """ create a chain of agents - """ - lrcAgPos = [(self.iRowNext, i * rcVel[1]) for i in range(nAgents)] - - for iAg, rcPos in zip(range(nAgents), lrcAgPos): - if iAg in liStopped: - rcVel1 = (0, 0) - else: - rcVel1 = rcVel - self.omc.addAgent(iAg + self.iAgNext, rcPos, (rcPos[0] + rcVel1[0], rcPos[1] + rcVel1[1])) - - if xlabel: - self.omc.G.nodes[lrcAgPos[0]]["xlabel"] = xlabel - - self.iAgNext += nAgents - self.iRowNext += 1 - - def nextRow(self): - self.iRowNext += 1 - - -def create_test_agents(omc: MotionCheck): - # blocked chain - omc.addAgent(1, (1, 2), (1, 3)) - omc.addAgent(2, (1, 3), (1, 4)) - omc.addAgent(3, (1, 4), (1, 5)) - omc.addAgent(31, (1, 5), (1, 5)) - - # unblocked chain - omc.addAgent(4, (2, 1), (2, 2)) - omc.addAgent(5, (2, 2), (2, 3)) - - # blocked short chain - omc.addAgent(6, (3, 1), (3, 2)) - omc.addAgent(7, (3, 2), (3, 2)) - - # solitary agent - omc.addAgent(8, (4, 1), (4, 2)) - - # solitary stopped agent - omc.addAgent(9, (5, 1), (5, 1)) - - # blocked short chain (opposite direction) - omc.addAgent(10, (6, 4), (6, 3)) - omc.addAgent(11, (6, 3), (6, 3)) - - # swap conflict - omc.addAgent(12, (7, 1), (7, 2)) - omc.addAgent(13, (7, 2), (7, 1)) - - -def create_test_agents2(omc: MotionCheck): - # blocked chain - cte = ChainTestEnv(omc) - cte.create_test_chain(4, liStopped=[3], xlabel="stopped\nchain") - cte.create_test_chain(4, xlabel="running\nchain") - - cte.create_test_chain(2, liStopped=[1], xlabel="stopped \nshort\n chain") - - cte.addAgentToRow(1, 2, "swap") - cte.addAgentToRow(2, 1) - - cte.nextRow() - - cte.addAgentToRow(1, 2, "chain\nswap") - cte.addAgentToRow(2, 3) - cte.addAgentToRow(3, 2) - - cte.nextRow() - - cte.addAgentToRow(1, 2, "midchain\nstop") - cte.addAgentToRow(2, 3) - cte.addAgentToRow(3, 4) - cte.addAgentToRow(4, 4) - cte.addAgentToRow(5, 6) - cte.addAgentToRow(6, 7) - - cte.nextRow() - - cte.addAgentToRow(1, 2, "midchain\nswap") - cte.addAgentToRow(2, 3) - cte.addAgentToRow(3, 4) - cte.addAgentToRow(4, 3) - cte.addAgentToRow(5, 4) - cte.addAgentToRow(6, 5) - - cte.nextRow() - - cte.addAgentToRow(1, 2, "Land on\nSame") - cte.addAgentToRow(3, 2) - 
- cte.nextRow() - cte.addAgentToRow(1, 2, "chains\nonto\nsame") - cte.addAgentToRow(2, 3) - cte.addAgentToRow(3, 4) - cte.addAgentToRow(5, 4) - cte.addAgentToRow(6, 5) - cte.addAgentToRow(7, 6) - - cte.nextRow() - cte.addAgentToRow(1, 2, "3-way\nsame") - cte.addAgentToRow(3, 2) - cte.addAgent((cte.iRowNext + 1, 2), (cte.iRowNext, 2)) - cte.nextRow() - - if False: - cte.nextRow() - cte.nextRow() - cte.addAgentToRow(1, 2, "4-way\nsame") - cte.addAgentToRow(3, 2) - cte.addAgent((cte.iRowNext + 1, 2), (cte.iRowNext, 2)) - cte.addAgent((cte.iRowNext - 1, 2), (cte.iRowNext, 2)) - cte.nextRow() - - cte.nextRow() - cte.addAgentToRow(1, 2, "Tee") - cte.addAgentToRow(2, 3) - cte.addAgentToRow(3, 4) - cte.addAgent((cte.iRowNext + 1, 3), (cte.iRowNext, 3)) - cte.nextRow() - - cte.nextRow() - cte.addAgentToRow(1, 2, "Tree") - cte.addAgentToRow(2, 3) - cte.addAgentToRow(3, 4) - r1 = cte.iRowNext - r2 = cte.iRowNext + 1 - r3 = cte.iRowNext + 2 - cte.addAgent((r2, 3), (r1, 3)) - cte.addAgent((r2, 2), (r2, 3)) - cte.addAgent((r3, 2), (r2, 3)) - - cte.nextRow() - - -def test_agent_following(): - omc = MotionCheck() - create_test_agents2(omc) - - svStops = omc.find_stops() - svBlocked = omc.find_stop_preds() - llvSwaps = omc.find_swaps() - svSwaps = {v for lvSwap in llvSwaps for v in lvSwap} - print(list(svBlocked)) - - lvCells = omc.G.nodes() - - lColours = ["magenta" if v in svStops - else "red" if v in svBlocked - else "purple" if v in svSwaps - else "lightblue" - for v in lvCells] - dPos = dict(zip(lvCells, lvCells)) - - nx.draw(omc.G, - with_labels=True, arrowsize=20, - pos=dPos, - node_color=lColours) - - -def main(): - test_agent_following() - - -if __name__ == "__main__": - main() diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index a6aeea16..5bcd6f81 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -5,7 +5,6 @@ from typing import List, Optional, Dict, Tuple import numpy as np -from flatland.utils import seeding # from flatland.envs.timetable_generators import timetable_generator import flatland.envs.timetable_generators as ttg @@ -27,6 +26,7 @@ from flatland.envs.step_utils import env_utils from flatland.envs.step_utils.states import TrainState, StateTransitionSignals from flatland.envs.step_utils.transition_utils import check_valid_action +from flatland.utils import seeding from flatland.utils.decorators import send_infrastructure_data_change_signal_to_reset_lru_cache, \ enable_infrastructure_lru_cache from flatland.utils.rendertools import RenderTool, AgentRenderVariant @@ -52,9 +52,12 @@ class RailEnv(Environment): Moving forward in a dead-end cell makes the agent turn 180 degrees and step to the cell it came from. + In order for agents to be able to "understand" the simulation behaviour from the observations, + the execution order of actions should not matter (i.e. not depend on the agent handle). + However, the agent ordering is still used to resolve conflicts between two agents trying to move into the same cell, + for example, head-on collisions, or agents "merging" at junctions. + See `MotionCheck` for more details. - The actions of the agents are executed in order of their handle to prevent - deadlocks and to allow them to learn relative priorities. 
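
A minimal sketch (illustrative only, not part of this patch) of the conflict-resolution rule just described: when two agents target the same free cell, the agent with the lowest handle wins and the loser is held back for this step.

```python
from flatland.envs.agent_chains import MotionCheck

mc = MotionCheck()
mc.addAgent(0, (0, 1), (0, 2))  # agent 0 wants to enter cell (0, 2)
mc.addAgent(1, (1, 2), (0, 2))  # agent 1 wants to enter the same free cell
mc.find_conflicts()             # lowest handle wins, losers are marked "red"

assert mc.check_motion(0, (0, 1))      # agent 0 may move into (0, 2)
assert not mc.check_motion(1, (1, 2))  # agent 1 is held back this step
```
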
Reward Function: diff --git a/notebooks/Agent-Close-Following.ipynb b/notebooks/Agent-Close-Following.ipynb index 2224ddd9..e8fa6b3c 100644 --- a/notebooks/Agent-Close-Following.ipynb +++ b/notebooks/Agent-Close-Following.ipynb @@ -94,7 +94,30 @@ "from flatland.envs.persistence import RailEnvPersister\n", "from flatland.utils.rendertools import RenderTool\n", "from flatland.utils import env_edit_utils as eeu\n", - "from flatland.utils import jupyter_utils as ju" + "from flatland.utils import jupyter_utils as ju\n", + "from tests.test_agent_chains import create_test_agents2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import graphviz as gv\n", + "def render(omc: ac.MotionCheck, horizontal=True):\n", + " try:\n", + " oAG = nx.drawing.nx_agraph.to_agraph(omc.G)\n", + " oAG.layout(\"dot\")\n", + " sDot = oAG.to_string()\n", + " if horizontal:\n", + " sDot = sDot.replace('{', '{ rankdir=\"LR\" ')\n", + " # return oAG.draw(format=\"png\")\n", + " # This returns a graphviz object which implements __repr_svg\n", + " return gv.Source(sDot)\n", + " except ImportError as oError:\n", + " print(\"Flatland agent_chains ignoring ImportError - install pygraphviz to render graphs\")\n", + " return None" ] }, { @@ -113,8 +136,8 @@ "outputs": [], "source": [ "omc = ac.MotionCheck()\n", - "ac.create_test_agents2(omc)\n", - "rv = ac.render(omc)\n", + "create_test_agents2(omc)\n", + "rv = render(omc)\n", "print(type(rv))" ] }, @@ -147,7 +170,7 @@ "metadata": {}, "outputs": [], "source": [ - "gvDot = ac.render(omc)\n", + "gvDot = render(omc)\n", "gvDot" ] }, @@ -257,7 +280,7 @@ " oEC.render()\n", " \n", " display.display_html(f\"
Step: {i}\\n\", raw=True)\n", - " display.display_svg(ac.render(env.motionCheck, horizontal=(i>=3)))\n", + " display.display_svg(render(env.motionCheck, horizontal=(i>=3)))\n", " time.sleep(0.1)" ] }, @@ -375,6 +398,15 @@ "env.motionCheck.svDeadlocked" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "render(env.motionCheck)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -442,7 +474,7 @@ "metadata": { "hide_input": false, "kernelspec": { - "display_name": "ve310fl", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -456,7 +488,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.15" }, "latex_envs": { "LaTeX_envs_menu_present": true, @@ -607,5 +639,5 @@ } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/tests/test_agent_chains.py b/tests/test_agent_chains.py new file mode 100644 index 00000000..80af2b55 --- /dev/null +++ b/tests/test_agent_chains.py @@ -0,0 +1,207 @@ +from typing import Tuple, List + +import networkx as nx + +from flatland.envs.agent_chains import MotionCheck + + +def create_test_agents(omc: MotionCheck): + # blocked chain + omc.addAgent(1, (1, 2), (1, 3)) + omc.addAgent(2, (1, 3), (1, 4)) + omc.addAgent(3, (1, 4), (1, 5)) + omc.addAgent(31, (1, 5), (1, 5)) + + # unblocked chain + omc.addAgent(4, (2, 1), (2, 2)) + omc.addAgent(5, (2, 2), (2, 3)) + + # blocked short chain + omc.addAgent(6, (3, 1), (3, 2)) + omc.addAgent(7, (3, 2), (3, 2)) + + # solitary agent + omc.addAgent(8, (4, 1), (4, 2)) + + # solitary stopped agent + omc.addAgent(9, (5, 1), (5, 1)) + + # blocked short chain (opposite direction) + omc.addAgent(10, (6, 4), (6, 3)) + omc.addAgent(11, (6, 3), (6, 3)) + + # swap conflict + omc.addAgent(12, (7, 1), (7, 2)) + omc.addAgent(13, (7, 2), (7, 1)) + + +class ChainTestEnv(object): + """ Just for testing agent chains + """ + + def __init__(self, omc: MotionCheck): + self.iAgNext = 0 + self.iRowNext = 1 + self.omc = omc + + def addAgent(self, rc1, rc2, xlabel=None): + self.omc.addAgent(self.iAgNext, rc1, rc2, xlabel=xlabel) + self.iAgNext += 1 + + def addAgentToRow(self, c1, c2, xlabel=None): + self.addAgent((self.iRowNext, c1), (self.iRowNext, c2), xlabel=xlabel) + + def create_test_chain(self, + nAgents: int, + rcVel: Tuple[int] = (0, 1), + liStopped: List[int] = [], + xlabel=None): + """ create a chain of agents + """ + lrcAgPos = [(self.iRowNext, i * rcVel[1]) for i in range(nAgents)] + + for iAg, rcPos in zip(range(nAgents), lrcAgPos): + if iAg in liStopped: + rcVel1 = (0, 0) + else: + rcVel1 = rcVel + self.omc.addAgent(iAg + self.iAgNext, rcPos, (rcPos[0] + rcVel1[0], rcPos[1] + rcVel1[1])) + + if xlabel: + self.omc.G.nodes[lrcAgPos[0]]["xlabel"] = xlabel + + self.iAgNext += nAgents + self.iRowNext += 1 + + def nextRow(self): + self.iRowNext += 1 + + +def create_test_agents2(omc: MotionCheck): + # blocked chain + cte = ChainTestEnv(omc) + cte.create_test_chain(4, liStopped=[3], xlabel="stopped\nchain") + cte.create_test_chain(4, xlabel="running\nchain") + + cte.create_test_chain(2, liStopped=[1], xlabel="stopped \nshort\n chain") + + cte.addAgentToRow(1, 2, "swap") + cte.addAgentToRow(2, 1) + + cte.nextRow() + + cte.addAgentToRow(1, 2, "chain\nswap") + cte.addAgentToRow(2, 3) + cte.addAgentToRow(3, 2) + + cte.nextRow() + + cte.addAgentToRow(1, 2, "midchain\nstop") + cte.addAgentToRow(2, 3) + cte.addAgentToRow(3, 4) + cte.addAgentToRow(4, 4) + 
cte.addAgentToRow(5, 6) + cte.addAgentToRow(6, 7) + + cte.nextRow() + + cte.addAgentToRow(1, 2, "midchain\nswap") + cte.addAgentToRow(2, 3) + cte.addAgentToRow(3, 4) + cte.addAgentToRow(4, 3) + cte.addAgentToRow(5, 4) + cte.addAgentToRow(6, 5) + + cte.nextRow() + + cte.addAgentToRow(1, 2, "Land on\nSame") + cte.addAgentToRow(3, 2) + + cte.nextRow() + cte.addAgentToRow(1, 2, "chains\nonto\nsame") + cte.addAgentToRow(2, 3) + cte.addAgentToRow(3, 4) + cte.addAgentToRow(5, 4) + cte.addAgentToRow(6, 5) + cte.addAgentToRow(7, 6) + + cte.nextRow() + cte.addAgentToRow(1, 2, "3-way\nsame") + cte.addAgentToRow(3, 2) + cte.addAgent((cte.iRowNext + 1, 2), (cte.iRowNext, 2)) + cte.nextRow() + + + cte.nextRow() + cte.addAgentToRow(1, 2, "Tee") + cte.addAgentToRow(2, 3) + cte.addAgentToRow(3, 4) + cte.addAgent((cte.iRowNext + 1, 3), (cte.iRowNext, 3)) + cte.nextRow() + + cte.nextRow() + cte.addAgentToRow(1, 2, "Tree") + cte.addAgentToRow(2, 3) + cte.addAgentToRow(3, 4) + r1 = cte.iRowNext + r2 = cte.iRowNext + 1 + r3 = cte.iRowNext + 2 + cte.addAgent((r2, 3), (r1, 3)) + cte.addAgent((r2, 2), (r2, 3)) + cte.addAgent((r3, 2), (r2, 3)) + + cte.nextRow() + + +def test_agent_following(): + expected = { + (1, 0): "red", + (1, 1): "red", + (1, 2): "red", + (1, 3): "red", + (3, 0): "red", + (3, 1): "red", + (4, 1): "purple", + (4, 2): "purple", + (5, 1): "purple", + (5, 2): "purple", + (5, 3): "purple", + (6, 1): "red", + (6, 2): "red", + (6, 3): "red", + (6, 4): "red", + (7, 1): "purple", + (7, 2): "purple", + (7, 3): "purple", + (7, 4): "purple", + (7, 5): "purple", + (7, 6): "purple", + (8, 2): "blue", + (8, 3): "red", + (9, 4): "blue", + (9, 5): "red", + (9, 6): "red", + (9, 7): "red", + (10, 2): "blue", + (10, 3): "red", + (11, 2): "red", + (12, 3): "magenta", + (13, 3): "red", + (14, 3): "magenta", + (15, 3): "red", + (15, 2): "red", + (16, 2): "red", + } + omc = MotionCheck() + create_test_agents2(omc) + omc.find_conflicts() + nx.draw(omc.G, + with_labels=True, arrowsize=20, + pos={p: p for p in omc.G.nodes}, + node_color=[n["color"] if "color" in n else "lightblue" for _, n in omc.G.nodes.data()] + ) + actual = {i: n['color'] for i, n in omc.G.nodes.data() if 'color' in n} + + assert set(actual.keys()) == set(expected.keys()) + for k in actual.keys(): + assert expected[k] == actual[k], (k, expected[k], actual[k])
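
To complement the color assertions above, here is a minimal, illustrative pair of tests (not part of the original patch; they reuse the `MotionCheck` import at the top of this file) exercising the two behaviours the new `agent_chains` module docstring describes: close following without gaps, and swap deadlocks.

```python
def test_close_following_sketch():
    """Illustrative sketch: a follower with a lower handle than its leader can
    still advance in the same step - no gap is needed between them."""
    omc = MotionCheck()
    omc.addAgent(0, (0, 1), (0, 2))  # follower, directly behind the leader
    omc.addAgent(1, (0, 2), (0, 3))  # leader, one cell ahead
    omc.find_conflicts()
    assert omc.check_motion(0, (0, 1))  # both agents may move together
    assert omc.check_motion(1, (0, 2))


def test_swap_deadlock_sketch():
    """Illustrative sketch: two agents trying to exchange cells form a swap
    and are marked deadlocked (purple), so neither may move."""
    omc = MotionCheck()
    omc.addAgent(0, (0, 1), (0, 2))
    omc.addAgent(1, (0, 2), (0, 1))
    omc.find_conflicts()
    assert omc.svDeadlocked == {(0, 1), (0, 2)}
    assert not omc.check_motion(0, (0, 1))
    assert not omc.check_motion(1, (0, 2))
```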