diff --git a/grain/_src/python/dataset/stats.py b/grain/_src/python/dataset/stats.py index 7d7f43b4..897f1595 100644 --- a/grain/_src/python/dataset/stats.py +++ b/grain/_src/python/dataset/stats.py @@ -71,6 +71,9 @@ "output_spec": "output spec", }) +_MAX_COLUMN_WIDTH = 30 +_MAX_ROW_LINES = 5 + def _pretty_format_ns(value: int) -> str: """Pretty formats a time value in nanoseconds to human readable value.""" @@ -103,11 +106,19 @@ def _pretty_format_summary( """Returns Execution Stats Summary for the dataset pipeline in tabular format.""" tabular_summary = [] col_names = [key for key in summary.nodes[0].DESCRIPTOR.fields_by_name.keys()] + # Remove the columns `output_spec` and `is_output` as they are available in + # the visualization graph. + col_names.remove("output_spec") + col_names.remove("is_output") # Insert the average processing time column after the max processing time # column. index = col_names.index("max_processing_time_ns") col_names.insert(index + 1, _AVG_PROCESSING_TIME_COLUMN_NAME) + tabular_summary.append( + [_COLUMN_NAME_OVERRIDES.get(name, name) for name in col_names] + ) + # Compute the maximum width of each column. col_widths = [] for name in col_names: @@ -120,24 +131,12 @@ def _pretty_format_summary( max_width = max(len(str(value)), max_width) col_widths.append(max_width) - col_headers = [ - "| {:<{}} |".format(_COLUMN_NAME_OVERRIDES.get(name, name), width) - for name, width in zip(col_names, col_widths) - ] - col_seperators = ["-" * len(header) for header in col_headers] - - tabular_summary.extend(col_seperators) - tabular_summary.append("\n") - tabular_summary.extend(col_headers) - tabular_summary.append("\n") - tabular_summary.extend(col_seperators) - tabular_summary.append("\n") - for node_id in sorted(summary.nodes, reverse=True): - is_total_processing_time_zero = ( - summary.nodes[node_id].total_processing_time_ns == 0 - ) - for name, width in zip(col_names, col_widths): + row_values = [] + for name in col_names: + is_total_processing_time_zero = ( + summary.nodes[node_id].total_processing_time_ns == 0 + ) if name == _AVG_PROCESSING_TIME_COLUMN_NAME: value = _get_avg_processing_time_ns(summary, node_id) else: @@ -154,19 +153,72 @@ def _pretty_format_summary( # produced an element and processing times & num_produced_elements are # not yet meaningful. if is_total_processing_time_zero: - col_value = f"{f'| N/A':<{width+2}} |" + col_value = "N/A" elif name != "num_produced_elements": - col_value = f"{f'| {_pretty_format_ns(value)}':<{width+2}} |" + col_value = _pretty_format_ns(value) else: - col_value = f"{f'| {value}':<{width+2}} |" + col_value = str(value) else: - col_value = "| {:<{}} |".format(str(value), width) - tabular_summary.append(col_value) - tabular_summary.append("\n") + col_value = str(value) + row_values.append(col_value) + tabular_summary.append(row_values) + table = Table(tabular_summary, col_widths=col_widths) + return table.get_pretty_wrapped_summary() + + +class Table: + """Table class for pretty printing tabular data.""" + + def __init__( + self, + contents, + *, + col_widths, + col_delim="|", + row_delim="-", + ): - for seperator in col_seperators: - tabular_summary.append(seperator) - return "".join(tabular_summary) + self.contents = contents + self._max_col_width = _MAX_COLUMN_WIDTH + self.col_delim = col_delim + self.col_widths = col_widths + self._pretty_summary = [] + self.col_header = [] + + p = len(self.col_delim) * (len(self.contents[0]) - 1) + + self.col_header.append(self.col_delim) + for col_width in self.col_widths: + if col_width > self._max_col_width: + col_width = self._max_col_width + self.col_header.append(row_delim * (col_width + 2)) + self.col_header.append(row_delim * (p)) + self.col_header.append(self.col_delim + "\n") + self._pretty_summary.extend(self.col_header) + + def get_pretty_wrapped_summary(self): + """Wraps the table contents within the max column width and max row lines.""" + + for row in self.contents: + max_wrap = (max([len(i) for i in row]) // self._max_col_width) + 1 + max_wrap = min(max_wrap, _MAX_ROW_LINES) + for r in range(max_wrap): + self._pretty_summary.append(self.col_delim) + for index in range(len(row)): + if self.col_widths[index] > self._max_col_width: + wrap = self._max_col_width + else: + wrap = self.col_widths[index] + start = r * self._max_col_width + end = (r + 1) * self._max_col_width + self._pretty_summary.append(" ") + self._pretty_summary.append(row[index][start:end].ljust(wrap)) + self._pretty_summary.append(" ") + self._pretty_summary.append(self.col_delim) + self._pretty_summary.append("\n") + self._pretty_summary.extend(self.col_header) + + return "".join(self._pretty_summary) class Timer: