Skip to content

Commit

Permalink
#grain-debug-tool Pretty print execution summary table.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 693174828
  • Loading branch information
Grain Team authored and copybara-github committed Nov 11, 2024
1 parent c662b83 commit 5b3b1a3
Showing 1 changed file with 78 additions and 26 deletions.
104 changes: 78 additions & 26 deletions grain/_src/python/dataset/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@
"output_spec": "output spec",
})

# Maximum number of characters rendered per table column; cell values longer
# than this are wrapped onto additional lines of the same row.
_MAX_COLUMN_WIDTH = 30
# Maximum number of wrapped lines rendered for a single table row; content
# beyond this many lines is not printed.
_MAX_ROW_LINES = 5


def _pretty_format_ns(value: int) -> str:
"""Pretty formats a time value in nanoseconds to human readable value."""
Expand Down Expand Up @@ -103,11 +106,19 @@ def _pretty_format_summary(
"""Returns Execution Stats Summary for the dataset pipeline in tabular format."""
tabular_summary = []
col_names = [key for key in summary.nodes[0].DESCRIPTOR.fields_by_name.keys()]
# Remove the columns `output_spec` and `is_output` as they are available in
# the visualization graph.
col_names.remove("output_spec")
col_names.remove("is_output")
# Insert the average processing time column after the max processing time
# column.
index = col_names.index("max_processing_time_ns")
col_names.insert(index + 1, _AVG_PROCESSING_TIME_COLUMN_NAME)

tabular_summary.append(
[_COLUMN_NAME_OVERRIDES.get(name, name) for name in col_names]
)

# Compute the maximum width of each column.
col_widths = []
for name in col_names:
Expand All @@ -120,24 +131,12 @@ def _pretty_format_summary(
max_width = max(len(str(value)), max_width)
col_widths.append(max_width)

col_headers = [
"| {:<{}} |".format(_COLUMN_NAME_OVERRIDES.get(name, name), width)
for name, width in zip(col_names, col_widths)
]
col_seperators = ["-" * len(header) for header in col_headers]

tabular_summary.extend(col_seperators)
tabular_summary.append("\n")
tabular_summary.extend(col_headers)
tabular_summary.append("\n")
tabular_summary.extend(col_seperators)
tabular_summary.append("\n")

for node_id in sorted(summary.nodes, reverse=True):
is_total_processing_time_zero = (
summary.nodes[node_id].total_processing_time_ns == 0
)
for name, width in zip(col_names, col_widths):
row_values = []
for name in col_names:
is_total_processing_time_zero = (
summary.nodes[node_id].total_processing_time_ns == 0
)
if name == _AVG_PROCESSING_TIME_COLUMN_NAME:
value = _get_avg_processing_time_ns(summary, node_id)
else:
Expand All @@ -154,19 +153,72 @@ def _pretty_format_summary(
# produced an element and processing times & num_produced_elements are
# not yet meaningful.
if is_total_processing_time_zero:
col_value = f"{f'| N/A':<{width+2}} |"
col_value = "N/A"
elif name != "num_produced_elements":
col_value = f"{f'| {_pretty_format_ns(value)}':<{width+2}} |"
col_value = _pretty_format_ns(value)
else:
col_value = f"{f'| {value}':<{width+2}} |"
col_value = str(value)
else:
col_value = "| {:<{}} |".format(str(value), width)
tabular_summary.append(col_value)
tabular_summary.append("\n")
col_value = str(value)
row_values.append(col_value)
tabular_summary.append(row_values)
table = Table(tabular_summary, col_widths=col_widths)
return table.get_pretty_wrapped_summary()


class Table:
    """Renders rows of strings as a delimited text table with wrapped cells.

    Every row is printed between horizontal separator lines. A cell longer
    than the maximum column width is split into chunks of that width which are
    printed on consecutive lines of the same row, up to a maximum number of
    lines per row; anything beyond that is not printed.

    Attributes:
      contents: The rows to render; each row is a sequence of strings.
      col_delim: String drawn between columns and at both table edges.
      col_widths: Width of the widest value in each column, before capping.
      col_header: Pieces of the horizontal separator line (including the
        trailing newline); joining them yields the separator.
    """

    def __init__(
        self,
        contents,
        *,
        col_widths,
        col_delim: str = "|",
        row_delim: str = "-",
        max_col_width=None,
        max_row_lines=None,
    ):
        """Initializes the table and precomputes the separator line.

        Args:
          contents: Rows to render, each a sequence of strings. May be empty.
          col_widths: Content width of every column; must have at least as
            many entries as the longest row.
          col_delim: Column delimiter.
          row_delim: Character used to draw the horizontal separator.
          max_col_width: Widest a column may be rendered. Defaults to the
            module-level `_MAX_COLUMN_WIDTH`.
          max_row_lines: Most lines a single row may wrap onto. Defaults to
            the module-level `_MAX_ROW_LINES`.

        Raises:
          ValueError: If `max_col_width` or `max_row_lines` is less than 1.
        """
        # Resolve the defaults lazily so that explicit limits never touch the
        # module-level constants.
        if max_col_width is None:
            max_col_width = _MAX_COLUMN_WIDTH
        if max_row_lines is None:
            max_row_lines = _MAX_ROW_LINES
        if max_col_width < 1 or max_row_lines < 1:
            raise ValueError(
                "max_col_width and max_row_lines must be at least 1, got"
                f" {max_col_width} and {max_row_lines}."
            )

        self.contents = contents
        self.col_delim = col_delim
        self.col_widths = col_widths
        self._max_col_width = max_col_width
        self._max_row_lines = max_row_lines

        # A rendered row is: delim, then for every column " value " + delim.
        # The separator therefore needs (width + 2) fill characters per column
        # plus one delimiter's worth of fill for each inner column boundary.
        # The column count comes from `col_widths` so an empty `contents` is
        # handled without indexing into it.
        inner_delims = len(col_delim) * max(len(col_widths) - 1, 0)
        self.col_header = [col_delim]
        self.col_header.extend(
            row_delim * (min(width, max_col_width) + 2) for width in col_widths
        )
        self.col_header.append(row_delim * inner_delims)
        self.col_header.append(col_delim + "\n")

    def get_pretty_wrapped_summary(self) -> str:
        """Returns the table as text, wrapping cells within the column width.

        The output is built from scratch on every call, so calling this
        method repeatedly always returns the same string.
        """
        chunk = self._max_col_width
        separator = "".join(self.col_header)
        pieces = [separator]

        for row in self.contents:
            longest = max((len(cell) for cell in row), default=0)
            # Ceiling division: a cell of exactly `chunk` characters fits on
            # one line. Always print at least one line, never more than the
            # per-row cap.
            num_lines = min(max(-(-longest // chunk), 1), self._max_row_lines)
            for line_no in range(num_lines):
                start = line_no * chunk
                pieces.append(self.col_delim)
                for index, cell in enumerate(row):
                    shown_width = min(self.col_widths[index], chunk)
                    pieces.append(" ")
                    pieces.append(cell[start:start + chunk].ljust(shown_width))
                    pieces.append(" ")
                    pieces.append(self.col_delim)
                pieces.append("\n")
            pieces.append(separator)

        return "".join(pieces)


class Timer:
Expand Down

0 comments on commit 5b3b1a3

Please sign in to comment.