Skip to content

Commit

Permalink
Merge branch 'master' into vigneshka/head-node-tag
Browse files Browse the repository at this point in the history
Signed-off-by: Vignesh Hirudayakanth <vignesh@anyscale.com>
  • Loading branch information
vigneshka committed Oct 24, 2024
2 parents 0f9efb0 + adcdfe8 commit 0dfa402
Show file tree
Hide file tree
Showing 2 changed files with 300 additions and 0 deletions.
129 changes: 129 additions & 0 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2000,6 +2000,135 @@ async def execute_async(
self._execution_index += 1
return fut

def visualize(
    self, filename="compiled_graph", format="png", view=False, return_dot=False
):
    """
    Visualize the compiled graph using Graphviz.

    This method generates a graphical representation of the compiled graph,
    showing tasks and their dependencies. This method should be called
    **after** the graph has been compiled using `experimental_compile()`.

    Each entry of `self.idx_to_task` becomes one graph node, named by its
    task index and styled by the type of its `dag_node`. Each positional
    argument of a task that is itself a `DAGNode` becomes an edge from the
    upstream task to this task, labelled with the class name of the
    matching entry in `task.arg_type_hints`.

    Args:
        filename: The name of the output file (without extension).
        format: The format of the output file (e.g., 'png', 'pdf').
        view: Whether to open the file with the default viewer.
        return_dot: If True, returns the DOT source as a string instead of
            rendering a figure.

    Returns:
        The DOT source string if `return_dot` is True; otherwise None (the
        graph is rendered to `filename` as a side effect).

    Raises:
        ValueError: If the graph is empty or not properly compiled.
        ImportError: If the `graphviz` package is not installed.
    """
    # Check that the DAG has been compiled. This is done before touching the
    # optional `graphviz` dependency so that misuse is reported even on
    # machines where graphviz is not installed.
    if not getattr(self, "idx_to_task", None):
        raise ValueError(
            "The DAG must be compiled before calling 'visualize()'. "
            "Please call 'experimental_compile()' first."
        )

    # graphviz is an optional dependency: turn the bare import failure into
    # an actionable message while keeping the documented exception type.
    try:
        import graphviz
    except ImportError as e:
        raise ImportError(
            "The 'graphviz' Python package is required to call 'visualize()'. "
            "You can install it by running 'pip install graphviz'."
        ) from e

    from ray.dag import (
        InputAttributeNode,
        InputNode,
        MultiOutputNode,
        ClassMethodNode,
        DAGNode,
    )

    # Check that each CompiledTask has a valid dag_node
    for idx, task in self.idx_to_task.items():
        if not hasattr(task, "dag_node") or not isinstance(task.dag_node, DAGNode):
            raise ValueError(
                f"Task at index {idx} does not have a valid 'dag_node'. "
                "Ensure that 'experimental_compile()' completed successfully."
            )

    # Dot file for debugging
    dot = graphviz.Digraph(name="compiled_graph", format=format)

    # Add nodes with task information
    for idx, task in self.idx_to_task.items():
        dag_node = task.dag_node

        # Initialize the label and attributes
        label = f"Task {idx}\n"
        shape = "oval"  # Default shape
        style = "filled"
        fillcolor = ""

        # Handle different types of dag_node
        if isinstance(dag_node, InputNode):
            label += "InputNode"
            shape = "rectangle"
            fillcolor = "lightblue"
        elif isinstance(dag_node, InputAttributeNode):
            label += f"InputAttributeNode[{dag_node.key}]"
            shape = "rectangle"
            fillcolor = "lightblue"
        elif isinstance(dag_node, MultiOutputNode):
            label += "MultiOutputNode"
            shape = "rectangle"
            fillcolor = "yellow"
        elif isinstance(dag_node, ClassMethodNode):
            if dag_node.is_class_method_call:
                # Class Method Call Node
                method_name = dag_node.get_method_name()
                actor_handle = dag_node._get_actor_handle()
                if actor_handle:
                    # Only a short prefix of the actor id is shown to keep
                    # the node label compact.
                    actor_id = actor_handle._actor_id.hex()
                    label += f"Actor: {actor_id[:6]}...\nMethod: {method_name}"
                else:
                    label += f"Method: {method_name}"
                shape = "oval"
                fillcolor = "lightgreen"
            elif dag_node.is_class_method_output:
                # Class Method Output Node
                label += f"ClassMethodOutputNode[{dag_node.output_idx}]"
                shape = "rectangle"
                fillcolor = "orange"
            else:
                # Unexpected ClassMethodNode
                label += "ClassMethodNode"
                shape = "diamond"
                fillcolor = "red"
        else:
            # Unexpected node type: drawn as a red diamond so it stands out
            label += type(dag_node).__name__
            shape = "diamond"
            fillcolor = "red"

        # Add the node to the graph with attributes
        dot.node(str(idx), label, shape=shape, style=style, fillcolor=fillcolor)

    # Add edges with type hints based on argument mappings
    for idx, task in self.idx_to_task.items():
        for arg_index, arg in enumerate(task.dag_node.get_args()):
            # Only arguments produced by another DAG node create an edge;
            # plain Python values are skipped.
            if not isinstance(arg, DAGNode):
                continue

            # Get the upstream task index
            upstream_task_idx = self.dag_node_to_idx[arg]

            # Get the type hint for this argument
            if arg_index < len(task.arg_type_hints):
                type_hint = type(task.arg_type_hints[arg_index]).__name__
            else:
                type_hint = "UnknownType"

            # Draw an edge from the upstream task to the
            # current task with the type hint
            dot.edge(str(upstream_task_idx), str(idx), label=type_hint)

    if return_dot:
        return dot.source

    # Render the graph to a file
    dot.render(filename, view=view)

def teardown(self):
"""Teardown and cancel all actor tasks for this DAG. After this
function returns, the actors should be available to execute new tasks
Expand Down
171 changes: 171 additions & 0 deletions python/ray/dag/tests/experimental/test_accelerated_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@

logger = logging.getLogger(__name__)

# pydot is an optional, test-only dependency used to parse the DOT output of
# the visualization tests. The import is best-effort: any failure must not
# prevent the rest of this test module from being collected.
try:
    import pydot
except Exception as e:
    # Keep the name bound so that later references see an explicit None
    # instead of raising NameError.
    pydot = None
    logger.info(
        "pydot is not available (%s), visualization tests will be skipped", e
    )

pytestmark = [
pytest.mark.skipif(
Expand Down Expand Up @@ -2493,6 +2497,173 @@ async def main():
compiled_dag.teardown()


class TestVisualization:

    """Tests for the visualize method of compiled DAGs.

    Every test compiles a small DAG, asks `visualize(return_dot=True)` for
    the DOT source, parses it with pydot and checks that the expected task
    nodes and edges are present. The compiled DAG is always torn down, even
    when an assertion fails.
    """

    # TODO(zhilong): "pip install pydot"
    # and "sudo apt-get install graphviz" to run test.
    @pytest.fixture(autouse=True)
    def skip_if_pydot_graphviz_not_available(self):
        # Skip the test if pydot or graphviz is not available
        pytest.importorskip("pydot")
        pytest.importorskip("graphviz")

    @staticmethod
    def _assert_dot_contains(dot_source, expected_nodes, expected_edges):
        """Parse `dot_source` and assert it contains the given nodes and edges.

        Args:
            dot_source: DOT source string returned by `visualize()`.
            expected_nodes: Set of node names (task indices as strings).
            expected_edges: Set of `(source, destination)` node-name pairs.
        """
        # Resolve pydot here rather than relying on a module-level name, so
        # the helper skips cleanly if the package is missing.
        pydot_module = pytest.importorskip("pydot")

        graphs = pydot_module.graph_from_dot_data(dot_source)
        graph = graphs[0]

        node_names = {node.get_name() for node in graph.get_nodes()}
        edge_pairs = {
            (edge.get_source(), edge.get_destination()) for edge in graph.get_edges()
        }

        # Subset checks: the graph may contain additional entries.
        assert expected_nodes.issubset(
            node_names
        ), f"Expected nodes {expected_nodes} not found."
        assert expected_edges.issubset(
            edge_pairs
        ), f"Expected edges {expected_edges} not found."

    def test_visualize_basic(self, ray_start_regular):
        """
        Expected edges in the DOT source:
            0 -> 1 [label=SharedMemoryType]
            1 -> 2 [label=SharedMemoryType]
        """

        @ray.remote
        class Actor:
            def echo(self, x):
                return x

        actor = Actor.remote()

        with InputNode() as i:
            dag = actor.echo.bind(i)

        compiled_dag = dag.experimental_compile()
        try:
            # Call the visualize method
            dot_source = compiled_dag.visualize(return_dot=True)

            self._assert_dot_contains(
                dot_source,
                expected_nodes={"0", "1", "2"},
                expected_edges={("0", "1"), ("1", "2")},
            )
        finally:
            # Always release the actors, even if an assertion above failed.
            compiled_dag.teardown()

    def test_visualize_multi_return(self, ray_start_regular):
        """
        Expected edges in the DOT source:
            0 -> 1 [label=SharedMemoryType]
            1 -> 2 [label=SharedMemoryType]
            1 -> 3 [label=SharedMemoryType]
            2 -> 4 [label=SharedMemoryType]
            3 -> 4 [label=SharedMemoryType]
        """

        @ray.remote
        class Actor:
            @ray.method(num_returns=2)
            def return_two(self, x):
                return x, x + 1

        actor = Actor.remote()

        with InputNode() as i:
            o1, o2 = actor.return_two.bind(i)
            dag = MultiOutputNode([o1, o2])

        compiled_dag = dag.experimental_compile()
        try:
            # Get the DOT source
            dot_source = compiled_dag.visualize(return_dot=True)

            self._assert_dot_contains(
                dot_source,
                expected_nodes={"0", "1", "2", "3", "4"},
                expected_edges={
                    ("0", "1"),
                    ("1", "2"),
                    ("1", "3"),
                    ("2", "4"),
                    ("3", "4"),
                },
            )
        finally:
            # Always release the actors, even if an assertion above failed.
            compiled_dag.teardown()

    def test_visualize_multi_return2(self, ray_start_regular):
        """
        Expected edges in the DOT source:
            0 -> 1 [label=SharedMemoryType]
            1 -> 2 [label=SharedMemoryType]
            1 -> 3 [label=SharedMemoryType]
            2 -> 4 [label=SharedMemoryType]
            3 -> 5 [label=SharedMemoryType]
            4 -> 6 [label=SharedMemoryType]
            5 -> 6 [label=SharedMemoryType]
        """

        @ray.remote
        class Actor:
            @ray.method(num_returns=2)
            def return_two(self, x):
                return x, x + 1

            def echo(self, x):
                return x

        a = Actor.remote()
        b = Actor.remote()
        with InputNode() as i:
            o1, o2 = a.return_two.bind(i)
            o3 = b.echo.bind(o1)
            o4 = b.echo.bind(o2)
            dag = MultiOutputNode([o3, o4])

        compiled_dag = dag.experimental_compile()
        try:
            # Get the DOT source
            dot_source = compiled_dag.visualize(return_dot=True)

            self._assert_dot_contains(
                dot_source,
                expected_nodes={"0", "1", "2", "3", "4", "5", "6"},
                expected_edges={
                    ("0", "1"),
                    ("1", "2"),
                    ("1", "3"),
                    ("2", "4"),
                    ("3", "5"),
                    ("4", "6"),
                    ("5", "6"),
                },
            )
        finally:
            # Always release the actors, even if an assertion above failed.
            compiled_dag.teardown()


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
Expand Down

0 comments on commit 0dfa402

Please sign in to comment.