diff --git a/grain/_src/python/dataset/BUILD b/grain/_src/python/dataset/BUILD index fc3245d8..7aaa2a26 100644 --- a/grain/_src/python/dataset/BUILD +++ b/grain/_src/python/dataset/BUILD @@ -31,6 +31,7 @@ py_library( deps = [ ":base", ":stats", + "//grain/_src/core:config", "//grain/_src/core:exceptions", "//grain/_src/core:monitoring", "//grain/_src/core:transforms", diff --git a/grain/_src/python/dataset/stats.py b/grain/_src/python/dataset/stats.py index 7d7f43b4..64b9e0c6 100644 --- a/grain/_src/python/dataset/stats.py +++ b/grain/_src/python/dataset/stats.py @@ -18,6 +18,7 @@ import abc import contextlib import dataclasses +import os import pprint import sys import threading @@ -71,6 +72,15 @@ "output_spec": "output spec", }) +_MAX_COLUMN_WIDTH = 30 +_MAX_ROW_LINES = 5 + +# We use a fixed-size length encoding for sending variable-length messages. +_ENCODED_LENGTH_SIZE = 10 + +# Maximum size of the buffer for large reads from fifo. +_READ_BUFFER_SIZE = 1024000 + def _pretty_format_ns(value: int) -> str: """Pretty formats a time value in nanoseconds to human readable value.""" @@ -103,11 +113,19 @@ def _pretty_format_summary( """Returns Execution Stats Summary for the dataset pipeline in tabular format.""" tabular_summary = [] col_names = [key for key in summary.nodes[0].DESCRIPTOR.fields_by_name.keys()] + # Remove the columns `output_spec` and `is_output` as they are available in + # the visualization graph. + col_names.remove("output_spec") + col_names.remove("is_output") # Insert the average processing time column after the max processing time # column. index = col_names.index("max_processing_time_ns") col_names.insert(index + 1, _AVG_PROCESSING_TIME_COLUMN_NAME) + tabular_summary.append( + [_COLUMN_NAME_OVERRIDES.get(name, name) for name in col_names] + ) + # Compute the maximum width of each column. col_widths = [] for name in col_names: @@ -120,24 +138,12 @@ def _pretty_format_summary( max_width = max(len(str(value)), max_width) col_widths.append(max_width) - col_headers = [ - "| {:<{}} |".format(_COLUMN_NAME_OVERRIDES.get(name, name), width) - for name, width in zip(col_names, col_widths) - ] - col_seperators = ["-" * len(header) for header in col_headers] - - tabular_summary.extend(col_seperators) - tabular_summary.append("\n") - tabular_summary.extend(col_headers) - tabular_summary.append("\n") - tabular_summary.extend(col_seperators) - tabular_summary.append("\n") - for node_id in sorted(summary.nodes, reverse=True): - is_total_processing_time_zero = ( - summary.nodes[node_id].total_processing_time_ns == 0 - ) - for name, width in zip(col_names, col_widths): + row_values = [] + for name in col_names: + is_total_processing_time_zero = ( + summary.nodes[node_id].total_processing_time_ns == 0 + ) if name == _AVG_PROCESSING_TIME_COLUMN_NAME: value = _get_avg_processing_time_ns(summary, node_id) else: @@ -154,19 +160,251 @@ def _pretty_format_summary( # produced an element and processing times & num_produced_elements are # not yet meaningful. 
if is_total_processing_time_zero: - col_value = f"{f'| N/A':<{width+2}} |" + col_value = "N/A" elif name != "num_produced_elements": - col_value = f"{f'| {_pretty_format_ns(value)}':<{width+2}} |" + col_value = _pretty_format_ns(value) else: - col_value = f"{f'| {value}':<{width+2}} |" + col_value = str(value) else: - col_value = "| {:<{}} |".format(str(value), width) - tabular_summary.append(col_value) - tabular_summary.append("\n") + col_value = str(value) + row_values.append(col_value) + tabular_summary.append(row_values) + table = _Table(tabular_summary, col_widths=col_widths) + return table._get_pretty_wrapped_summary() # pylint: disable=protected-access + + +class _Table: + """Table class for pretty printing tabular data.""" + + def __init__( + self, + contents, + *, + col_widths, + col_delim="|", + row_delim="-", + ): + + self.contents = contents + self._max_col_width = _MAX_COLUMN_WIDTH + self.col_delim = col_delim + self.col_widths = col_widths + self._pretty_summary = [] + self.col_header = [] + + # Determine the number of row_delim characters to fill the space used by + # col_delim characters in a column header. + col_delim_space_fill = len(self.col_delim) * (len(self.contents[0]) - 1) + + self.col_header.append(self.col_delim) + for col_width in self.col_widths: + if col_width > self._max_col_width: + col_width = self._max_col_width + self.col_header.append(row_delim * (col_width + 2)) + self.col_header.append(row_delim * (col_delim_space_fill)) + self.col_header.append(self.col_delim + "\n") + self._pretty_summary.extend(self.col_header) + + def _get_pretty_wrapped_summary(self): + """Wraps the table contents within the max column width and max row lines.""" + + for row in self.contents: + max_wrap = (max([len(i) for i in row]) // self._max_col_width) + 1 + max_wrap = min(max_wrap, _MAX_ROW_LINES) + for r in range(max_wrap): + self._pretty_summary.append(self.col_delim) + for index in range(len(row)): + if self.col_widths[index] > self._max_col_width: + wrap = self._max_col_width + else: + wrap = self.col_widths[index] + start = r * self._max_col_width + end = (r + 1) * self._max_col_width + self._pretty_summary.append(" ") + self._pretty_summary.append(row[index][start:end].ljust(wrap)) + self._pretty_summary.append(" ") + self._pretty_summary.append(self.col_delim) + self._pretty_summary.append("\n") + self._pretty_summary.extend(self.col_header) + + return "".join(self._pretty_summary) + + +def _merge_execution_summaries( + aggregated_summary: execution_summary_pb2.ExecutionSummary, + summary_from_worker: execution_summary_pb2.ExecutionSummary, +): + """Merges the execution summary from the worker into the aggregated summary.""" + # we cannot use MergeFrom here because singular fields like + # `max_processing_time_ns` will be overriden. 
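+  # For example, if the aggregated summary already holds
+  # max_processing_time_ns=50 and a worker reports 30, MergeFrom would
+  # overwrite the field with 30 rather than keep max(50, 30), so each field
+  # is combined explicitly below.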
+ for node_id in summary_from_worker.nodes: + aggregated_summary.nodes[node_id].id = summary_from_worker.nodes[node_id].id + aggregated_summary.nodes[node_id].name = summary_from_worker.nodes[ + node_id + ].name + aggregated_summary.nodes[node_id].output_spec = summary_from_worker.nodes[ + node_id + ].output_spec + aggregated_summary.nodes[node_id].is_output = summary_from_worker.nodes[ + node_id + ].is_output + aggregated_summary.nodes[node_id].ClearField("inputs") + aggregated_summary.nodes[node_id].inputs.extend( + summary_from_worker.nodes[node_id].inputs + ) + if aggregated_summary.nodes[node_id].min_processing_time_ns == 0: + aggregated_summary.nodes[node_id].min_processing_time_ns = ( + summary_from_worker.nodes[node_id].min_processing_time_ns + ) + else: + aggregated_summary.nodes[node_id].min_processing_time_ns = min( + aggregated_summary.nodes[node_id].min_processing_time_ns, + summary_from_worker.nodes[node_id].min_processing_time_ns, + ) + aggregated_summary.nodes[node_id].max_processing_time_ns = max( + aggregated_summary.nodes[node_id].max_processing_time_ns, + summary_from_worker.nodes[node_id].max_processing_time_ns, + ) + aggregated_summary.nodes[ + node_id + ].total_processing_time_ns += summary_from_worker.nodes[ + node_id + ].total_processing_time_ns + aggregated_summary.nodes[ + node_id + ].num_produced_elements += summary_from_worker.nodes[ + node_id + ].num_produced_elements + return aggregated_summary + + +def _update_execution_summary_in_main_process( + summary_in_main_process: execution_summary_pb2.ExecutionSummary, + summary_from_workers: execution_summary_pb2.ExecutionSummary, +): + """Updates the execution summary in the main process by merging the summary from the workers.""" + num_nodes_in_main_process = len(summary_in_main_process.nodes) + # The pipeline's output node within the workers becomes the input for the root + # node in the main process. Therefore, the node IDs (and thus input IDs) in + # the worker summary should be updated before merging it with the main + # process's summary. + + for node_id in summary_from_workers.nodes: + updated_node_id = node_id + num_nodes_in_main_process + summary_from_workers.nodes[node_id].id = updated_node_id + current_input_ids = summary_from_workers.nodes[node_id].inputs + summary_from_workers.nodes[node_id].ClearField("inputs") + for input_id in current_input_ids: + summary_from_workers.nodes[node_id].inputs.append( + input_id + num_nodes_in_main_process + ) + input_ids = [] + root_node_in_main = None + # Find the root node in the main process to update its inputs. + for node_id in summary_in_main_process.nodes: + if not getattr(summary_in_main_process.nodes[node_id], "inputs"): + root_node_in_main = node_id + for node_id in summary_from_workers.nodes: + # If the node is an output node in the worker summary, it becomes the input + # for the root node in the main process. + if getattr(summary_from_workers.nodes[node_id], "is_output"): + input_ids.append(summary_from_workers.nodes[node_id].id) + summary_from_workers.nodes[node_id].is_output = False + worker_node_id = summary_from_workers.nodes[node_id].id + summary_in_main_process.nodes[worker_node_id].CopyFrom( + summary_from_workers.nodes[node_id] + ) + summary_in_main_process.nodes[root_node_in_main].inputs.extend(input_ids) + return summary_in_main_process + - for seperator in col_seperators: - tabular_summary.append(seperator) - return "".join(tabular_summary) +class WorkerConnection: + """A simplex connection for a worker to send data to the main process. 
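+
+  Messages are framed with a fixed-size, hex-encoded length prefix of
+  _ENCODED_LENGTH_SIZE bytes followed by the payload; for example, a 300-byte
+  message is preceded by b"0x0000012c".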
+ + Relies on already created fifo. + + Attributes: + send_fifo: Path to a fifo for sending data. Must be opened for reading by + the main process. + """ + + __slots__ = "send_fifo", "_send_fd" + + def __init__(self, send_fifo: str): + self.send_fifo = send_fifo + self._send_fd = -1 + + def open(self) -> None: + """Opens the connection. + + Blocks the caller until the main process opens the connection for read-only. + """ + self._send_fd = os.open(self.send_fifo, os.O_WRONLY) # pylint: disable=protected-access + + def send(self, data: bytes) -> None: + """Sends data to the connection. + + Blocks the caller until the main process reads the data from the sending + fifo. + + Args: + data: bytes to send. + """ + data_len = len(data) + # Make fixed-size length encoding and send it over along with the data. + encoded_len = f"{data_len:#0{_ENCODED_LENGTH_SIZE}x}".encode() + os.write(self._send_fd, encoded_len) + os.write(self._send_fd, data) + + +class MainConnection: + """A simplex connection for the main process to receive data from workers. + + Relies on already created fifo. + + Attributes: + recv_fifo: Path to a fifo for receiving data. Must be opened for writing by + a worker. + """ + + __slots__ = "recv_fifo", "_recv_fd" + + def __init__(self, recv_fifo: str): + self.recv_fifo = recv_fifo + self._recv_fd = -1 + + def open(self) -> None: + """Opens the connection. + + Blocks the caller until the client end of the connection is opened. + """ + self._recv_fd = os.open(self.recv_fifo, os.O_RDONLY) # pylint: disable=protected-access + + def recv(self) -> bytes: + """Reads data from the connection. + + Blocks the caller until the worker end sends data through the receiving + fifo. + + Returns: + bytes read from the connection. + """ + # Reading the fixed number of leading bytes containing hex-encoded message + # length. + input_length = os.read(self._recv_fd, _ENCODED_LENGTH_SIZE) + if not input_length: + # Reached EOF. This means that the client end of the pipe was closed. + return b"" + # Decode the length and read the necessary number of bytes. + input_length = int(input_length, 16) + input_parts = [] + read_length = 0 + while read_length < input_length: + read_buffer_size = min(_READ_BUFFER_SIZE, input_length - read_length) + buffer = os.read(self._recv_fd, read_buffer_size) + input_parts.append(buffer) + read_length += len(buffer) + return b"".join(input_parts) class Timer: @@ -205,7 +443,7 @@ def reset(self): self._last = 0 -@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +@dataclasses.dataclass(slots=True, kw_only=True) class StatsConfig: """Statistics recording condiguration.""" @@ -215,6 +453,9 @@ class StatsConfig: # Whether this transformation mutates the element spec. This is used to # determine element spec of the current transformation. transform_mutates_spec: bool = True + main_stats_connections: tuple[MainConnection, ...] | None = None + worker_stats_fifo: str | None = None + send_stats_to_main_process: bool = False class Stats(abc.ABC): @@ -425,13 +666,65 @@ def _reporting_loop(self): time.sleep(_REPORTING_PERIOD_SEC) self.report() + def _send_summary_to_main_process_loop(self): + """Sends the execution summary to the main process periodically.""" + connection = WorkerConnection(self._config.worker_stats_fifo) + connection.open() + while True: + time.sleep(_LOG_EXECTION_SUMMARY_PERIOD_SEC) + try: + # Blocks until the main process reads the data. 
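+        # If the main process has already closed its read end of the fifo
+        # (e.g. it has shut down), os.write raises BrokenPipeError and the
+        # loop exits via the `break` below.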
+ connection.send(self._get_execution_summary().SerializeToString()) + except BrokenPipeError: + break + + def _get_combined_summary_from_workers( + self, + ) -> execution_summary_pb2.ExecutionSummary: + """Returns the aggregated execution summary from all workers.""" + aggregated_summary_from_workers = execution_summary_pb2.ExecutionSummary() + if self._config.main_stats_connections is not None: + for connection in self._config.main_stats_connections: + try: + summary_from_worker = connection.recv() + if not summary_from_worker: + break + summary_from_worker = ( + execution_summary_pb2.ExecutionSummary.FromString( + summary_from_worker + ) + ) + # Combine the summary from all workers into a single summary. + aggregated_summary_from_workers = _merge_execution_summaries( + aggregated_summary_from_workers, summary_from_worker + ) + except (BrokenPipeError, EOFError): + break + except Exception as e: # pylint: disable=broad-except + logging.exception("Failed to deserialize summary from worker: %s", e) + break + return aggregated_summary_from_workers + + def _get_aggregated_execution_summary( + self, + ) -> execution_summary_pb2.ExecutionSummary: + """Returns the aggregated execution summary from all workers.""" + summary_in_main_process = self._get_execution_summary() + summary_from_workers = self._get_combined_summary_from_workers() + # Update the nodes in the main process with the aggregated summary from + # all workers. + aggregated_summary = _update_execution_summary_in_main_process( + summary_in_main_process, summary_from_workers + ) + return aggregated_summary + def _logging_execution_summary_loop(self): + """Logs the aggregated execution summary to the main process periodically.""" while True: time.sleep(_LOG_EXECTION_SUMMARY_PERIOD_SEC) - summary = self._get_execution_summary() logging.info( "Grain Dataset Execution Summary:\n\n%s", - _pretty_format_summary(summary), + _pretty_format_summary(self._get_aggregated_execution_summary()), ) def _build_execution_summary( @@ -488,9 +781,14 @@ def record_self_time(self, offset_ns: int = 0): if self._logging_thread is None: with self._logging_thread_init_lock: if self._logging_thread is None: - self._logging_thread = threading.Thread( - target=self._logging_execution_summary_loop, daemon=True - ) + if self._config.send_stats_to_main_process: + self._logging_thread = threading.Thread( + target=self._send_summary_to_main_process_loop, daemon=True + ) + else: + self._logging_thread = threading.Thread( + target=self._logging_execution_summary_loop, daemon=True + ) self._logging_thread.start() def report(self): diff --git a/grain/_src/python/dataset/stats_test.py b/grain/_src/python/dataset/stats_test.py index 94bc8bce..d4d5cdd1 100644 --- a/grain/_src/python/dataset/stats_test.py +++ b/grain/_src/python/dataset/stats_test.py @@ -100,7 +100,7 @@ "[]" ││ - ││ MapDatasetIterator(transform= @ .../python/dataset/stats_test.py:441) + ││ MapDatasetIterator(transform= @ .../python/dataset/stats_test.py:462) ││ ╲╱ {'data': "[]", diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py index 7424e7b9..921f7405 100644 --- a/grain/_src/python/dataset/transformations/prefetch.py +++ b/grain/_src/python/dataset/transformations/prefetch.py @@ -20,14 +20,17 @@ import contextlib import copy import functools +import os import queue import sys +import tempfile import threading import time import typing from typing import Any, Generic, Mapping, Optional, Protocol, TypeVar import cloudpickle +from 
grain._src.core import config as grain_config from concurrent import futures from grain._src.core import tree import multiprocessing as mp @@ -357,10 +360,14 @@ class GetElementProducerFn(grain_pool.GetElementProducerFn, Generic[T]): """ def __init__( - self, state: dict[str, dict[str, Any] | int], ds: dataset.IterDataset[T] + self, + state: dict[str, dict[str, Any] | int], + ds: dataset.IterDataset[T], + stats_fifos: list[str], ): self._state = state self._ds = ds + self._stats_fifos = stats_fifos def __call__( self, *, worker_index: int, worker_count: int @@ -407,8 +414,6 @@ def serialize(self) -> bytes: class MultiprocessPrefetchDatasetIterator(dataset.DatasetIterator[T]): """Iterator that performs prefetching using a multiprocessing pool.""" - _MUTATES_ELEMENT_SPEC = False - def __init__( self, parent: dataset.IterDataset[T], @@ -427,6 +432,9 @@ def __init__( # last worker. workers_state: dict[str, Any] = {} iterations_to_skip: dict[str, int] = {} + self._create_main_connection_thread = None + self._tmp_dir = tempfile.mkdtemp() + self._stats_fifos = [] for i in range(multiprocessing_options.num_workers): workers_state[str(i)] = iter( self._iter_parent @@ -484,11 +492,22 @@ def _ensure_iterator_initialized(self) -> None: self._raw_iterator.start_prefetch() self._iterator = _iterator_with_context(self._raw_iterator) + def _create_main_connections(self) -> None: + """Creates a `MainConnection` for each worker to receive data from.""" + if not grain_config.config.py_debug_mode: + return + connections = [] + for stats_fifo in self._stats_fifos: + connection = dataset_stats.MainConnection(stats_fifo) + connection.open() + connections.append(connection) + self._stats._config.main_stats_connections = connections # pylint: disable=protected-access + def _create_iterator_context(self) -> grain_pool.MultiProcessIterator[T]: """Creates a `MultiProcessIterator`.""" get_element_producer_fn = GetElementProducerFn( - self._state, self._iter_parent + self._state, self._iter_parent, self._stats_fifos ) return grain_pool.MultiProcessIterator(