Skip to content

Commit

Permalink
Show number of tasks prominently on graph visualization (#371)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Feb 5, 2024
1 parent 8831b94 commit 8f11466
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions cubed/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def visualize(

# now set node attributes with visualization info
for n, d in dag.nodes(data=True):
label = n
tooltip = f"name: {n}\n"
node_type = d.get("type", None)
if node_type == "op":
Expand All @@ -318,12 +319,14 @@ def visualize(
op_name_summary = ""
tooltip += f"op: {op_name}"

num_tasks = None
if "primitive_op" in d:
primitive_op = d["primitive_op"]
tooltip += (
f"\nprojected memory: {memory_repr(primitive_op.projected_mem)}"
)
tooltip += f"\ntasks: {primitive_op.num_tasks}"
num_tasks = primitive_op.num_tasks
tooltip += f"\ntasks: {num_tasks}"
if primitive_op.write_chunks is not None:
tooltip += f"\nwrite chunks: {primitive_op.write_chunks}"
del d["primitive_op"]
Expand All @@ -347,7 +350,7 @@ def visualize(
first_cubed_summary = stack_summaries[first_cubed_i]
caller_summary = stack_summaries[first_cubed_i - 1]

d["label"] = f"{first_cubed_summary.name} {op_name_summary}"
label = f"{first_cubed_summary.name} {op_name_summary}"

calls = " -> ".join(
[
Expand All @@ -363,6 +366,9 @@ def visualize(
tooltip += f"\nline: {line}"
del d["stack_summaries"]

if num_tasks is not None:
label += f"\ntasks: {num_tasks}"

elif node_type == "array":
target = d["target"]
chunkmem = memory_repr(chunk_memory(target.dtype, target.chunks))
Expand All @@ -375,10 +381,8 @@ def visualize(
nbytes = memory_repr(target.nbytes)
if n in array_display_names:
var_name = array_display_names[n]
d["label"] = f"{n} ({var_name})"
label = f"{n} ({var_name})"
tooltip += f"variable: {var_name}\n"
else:
d["label"] = n
tooltip += f"shape: {target.shape}\n"
tooltip += f"chunks: {target.chunks}\n"
tooltip += f"dtype: {target.dtype}\n"
Expand All @@ -388,6 +392,7 @@ def visualize(

del d["target"]

d["label"] = label.strip()
d["tooltip"] = tooltip.strip()

if "name" in d: # pydot already has name
Expand Down

0 comments on commit 8f11466

Please sign in to comment.