[Data] Update Dataset.count() to avoid unnecessarily keeping `BlockRef`s in-memory (#46369)

Currently, the implementation of `Dataset.count()` retrieves the entire
list of `BlockRef`s associated with the Dataset when calculating the
number of rows per block. This PR is a minor performance improvement:
`count()` now iterates over the `BlockRef`s, so each reference can be
dropped as soon as its block's row count is read, and the full list of
`BlockRef`s never has to be held in memory.

Signed-off-by: sjl <sjl@anyscale.com>
Signed-off-by: Scott Lee <sjl@anyscale.com>
scottjlee authored Jul 10, 2024
1 parent c103330 commit f8ee70a
Showing 6 changed files with 153 additions and 22 deletions.
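
To make the change concrete before the per-file diff, here is a minimal sketch (not part of the commit) of the counting pattern the PR adopts. It assumes only the `iter_internal_ref_bundles()` developer API and the `RefBundle.num_rows()` accessor that appear in the diff below:

```python
import ray

ds = ray.data.range(1000, override_num_blocks=10)

# Stream over RefBundles and read row counts from their metadata, mirroring
# the updated Dataset.count() body below. Each bundle's block references can
# be dropped as soon as its row count is added to the running total, so the
# full list of BlockRefs is never held at once.
total_rows = 0
for ref_bundle in ds.iter_internal_ref_bundles():
    num_rows = ref_bundle.num_rows()
    assert num_rows is not None  # executed datasets report valid row counts
    total_rows += num_rows

assert total_rows == 1000
```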
1 change: 1 addition & 0 deletions doc/source/data/api/dataset.rst
@@ -131,6 +131,7 @@ Inspecting Metadata
Dataset.input_files
Dataset.stats
Dataset.get_internal_block_refs
Dataset.iter_internal_ref_bundles

Execution
---------
49 changes: 48 additions & 1 deletion python/ray/data/_internal/execution/legacy_compat.py
@@ -3,18 +3,20 @@
It should be deleted once we fully move to the new executor backend.
"""

from typing import Iterator, Tuple
from typing import Iterator, Optional, Tuple

from ray.data._internal.block_list import BlockList
from ray.data._internal.execution.interfaces import (
    Executor,
    PhysicalOperator,
    RefBundle,
)
from ray.data._internal.execution.interfaces.executor import OutputIterator
from ray.data._internal.logical.optimizers import get_execution_plan
from ray.data._internal.logical.util import record_operators_usage
from ray.data._internal.plan import ExecutionPlan
from ray.data._internal.stats import DatasetStats
from ray.data._internal.util import unify_block_metadata_schema
from ray.data.block import Block, BlockMetadata
from ray.types import ObjectRef

@@ -59,6 +61,51 @@ def execute_to_legacy_bundle_iterator(
    dag = dag_rewrite(dag)

    bundle_iter = executor.execute(dag, initial_stats=stats)

    class CacheMetadataIterator(OutputIterator):
        """Wrapper for `bundle_iterator` above.
        For a given iterator which yields output RefBundles,
        collect the metadata from each output bundle, and yield the
        original RefBundle. Only after the entire iterator is exhausted,
        we cache the resulting metadata to the execution plan."""

        def __init__(self, base_iterator: OutputIterator):
            # Note: the base_iterator should be of type StreamIterator,
            # defined within `StreamingExecutor.execute()`. It must
            # support the `get_next()` method.
            self._base_iterator = base_iterator
            self._collected_metadata = BlockMetadata(
                num_rows=0,
                size_bytes=0,
                schema=None,
                input_files=None,
                exec_stats=None,
            )

        def get_next(self, output_split_idx: Optional[int] = None) -> RefBundle:
            try:
                bundle = self._base_iterator.get_next(output_split_idx)
                self._collect_metadata(bundle)
                return bundle
            except StopIteration:
                # Once the iterator is completely exhausted, we are done
                # collecting metadata. We can add this cached metadata to the plan.
                plan._snapshot_metadata = self._collected_metadata
                raise

        def _collect_metadata(self, bundle: RefBundle) -> RefBundle:
            """Collect the metadata from each output bundle and accumulate
            results, so we can access important information, such as
            row count, schema, etc., after iteration completes."""
            self._collected_metadata.num_rows += bundle.num_rows()
            self._collected_metadata.size_bytes += bundle.size_bytes()
            self._collected_metadata.schema = unify_block_metadata_schema(
                [self._collected_metadata, *bundle.metadata]
            )
            return bundle

    bundle_iter = CacheMetadataIterator(bundle_iter)
    return bundle_iter


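For orientation between the two files (a hedged sketch, not part of the diff): once the wrapped iterator above is exhausted, the accumulated metadata is stored on the plan as `_snapshot_metadata` (added in `plan.py` below), so the row count and schema are known without retaining any block references. The new `test_count_after_partial_execution` test further down exercises this end to end.

```python
import ray

ds = ray.data.read_csv("example://iris.csv")  # 150 rows, same file as the test below

# Fully consume the streamed RefBundles. CacheMetadataIterator accumulates
# num_rows, size_bytes, and schema, and caches them on the plan when the
# underlying iterator raises StopIteration.
for _ in ds.iter_internal_ref_bundles():
    pass

# The cached metadata now backs the dataset's reported row count.
assert "num_rows=150" in str(ds)
assert ds.count() == 150
```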
15 changes: 15 additions & 0 deletions python/ray/data/_internal/plan.py
@@ -72,6 +72,16 @@ def __init__(
        self._snapshot_operator: Optional[LogicalOperator] = None
        self._snapshot_stats = None
        self._snapshot_bundle = None
        # Snapshot of only metadata corresponding to the final operator's
        # output bundles, used as the source of truth for the Dataset's schema
        # and count. This is calculated and cached when the plan is executed as an
        # iterator (`execute_to_iterator()`), and avoids caching
        # all of the output blocks in memory like in `self._snapshot_bundle`.
        # TODO(scottjlee): To keep the caching logic consistent, update `execute()`
        # to also store the metadata in `_snapshot_metadata` instead of
        # `_snapshot_bundle`. For example, we could store the blocks in
        # `self._snapshot_blocks` and the metadata in `self._snapshot_metadata`.
        self._snapshot_metadata: Optional[BlockMetadata] = None

        # Cached schema.
        self._schema = None
@@ -148,6 +158,9 @@ def generate_logical_plan_string(
            # This plan has executed some but not all operators.
            schema = unify_block_metadata_schema(self._snapshot_bundle.metadata)
            count = self._snapshot_bundle.num_rows()
        elif self._snapshot_metadata is not None:
            schema = self._snapshot_metadata.schema
            count = self._snapshot_metadata.num_rows
        else:
            # This plan hasn't executed any operators.
            sources = self._logical_plan.sources()
@@ -414,6 +427,8 @@ def execute_to_iterator(

        metrics_tag = create_dataset_tag(self._dataset_name, self._dataset_uuid)
        executor = StreamingExecutor(copy.deepcopy(ctx.execution_options), metrics_tag)
        # TODO(scottjlee): replace with `execute_to_legacy_bundle_iterator` and
        # update execute_to_iterator usages to handle RefBundles instead of Blocks
        block_iter = execute_to_legacy_block_iterator(
            executor,
            self,
78 changes: 57 additions & 21 deletions python/ray/data/dataset.py
@@ -12,6 +12,7 @@
    Dict,
    Generic,
    Iterable,
    Iterator,
    List,
    Literal,
    Mapping,
@@ -77,6 +78,7 @@
    VALID_BATCH_FORMATS,
    Block,
    BlockAccessor,
    BlockMetadata,
    DataBatch,
    T,
    U,
@@ -2459,12 +2461,14 @@ def show(self, limit: int = 20) -> None:
    @ConsumptionAPI(
        if_more_than_read=True,
        datasource_metadata="row count",
        pattern="Time complexity:",
        pattern="Examples:",
    )
    def count(self) -> int:
        """Count the number of records in the dataset.
        """Count the number of rows in the dataset.
        Time complexity: O(dataset size / parallelism), O(1) for parquet
        For Datasets which only read Parquet files (created with
        :meth:`~ray.data.read_parquet`), this method reads the file metadata to
        efficiently count the number of rows without reading in the entire data.
        Examples:
            >>> import ray
@@ -2484,13 +2488,15 @@ def count(self) -> int:
        if meta_count is not None:
            return meta_count

        get_num_rows = cached_remote_fn(_get_num_rows)

        return sum(
            ray.get(
                [get_num_rows.remote(block) for block in self.get_internal_block_refs()]
            )
        )
        # Directly loop over the iterator of `RefBundle`s instead of
        # retrieving a full list of `BlockRef`s.
        total_rows = 0
        for ref_bundle in self.iter_internal_ref_bundles():
            num_rows = ref_bundle.num_rows()
            # Executing the dataset always returns blocks with valid `num_rows`.
            assert num_rows is not None
            total_rows += num_rows
        return total_rows

    @ConsumptionAPI(
        if_more_than_read=True,
@@ -4328,14 +4334,15 @@ def to_pandas(self, limit: int = None) -> "pandas.DataFrame":
            ValueError: if the number of rows in the :class:`~ray.data.Dataset` exceeds
                ``limit``.
        """
        count = self.count()
        if limit is not None and count > limit:
            raise ValueError(
                f"the dataset has more than the given limit of {limit} "
                f"rows: {count}. If you are sure that a DataFrame with "
                f"{count} rows will fit in local memory, set ds.to_pandas(limit=None) "
                "to disable limits."
            )
        if limit is not None:
            count = self.count()
            if count > limit:
                raise ValueError(
                    f"the dataset has more than the given limit of {limit} "
                    f"rows: {count}. If you are sure that a DataFrame with "
                    f"{count} rows will fit in local memory, set "
                    "ds.to_pandas(limit=None) to disable limits."
                )
        blocks = self.get_internal_block_refs()
        output = DelegatingBlockBuilder()
        for block in blocks:
@@ -4563,7 +4570,36 @@ def stats(self) -> str:
    def _get_stats_summary(self) -> DatasetStatsSummary:
        return self._plan.stats_summary()

    @ConsumptionAPI(pattern="Time complexity:")
    @ConsumptionAPI(pattern="Examples:")
    @DeveloperAPI
    def iter_internal_ref_bundles(self) -> Iterator[RefBundle]:
        """Get an iterator over ``RefBundles``
        belonging to this Dataset. Calling this function doesn't keep
        the data materialized in-memory.
        Examples:
            >>> import ray
            >>> ds = ray.data.range(1)
            >>> for ref_bundle in ds.iter_internal_ref_bundles():
            ...     for block_ref, block_md in ref_bundle.blocks:
            ...         block = ray.get(block_ref)
        Returns:
            An iterator over this Dataset's ``RefBundles``.
        """

        def _build_ref_bundles(
            iter_blocks: Iterator[Tuple[ObjectRef[Block], BlockMetadata]],
        ) -> Iterator[RefBundle]:
            for block in iter_blocks:
                yield RefBundle((block,), owns_blocks=True)

        iter_block_refs_md, _, _ = self._plan.execute_to_iterator()
        iter_ref_bundles = _build_ref_bundles(iter_block_refs_md)
        self._synchronize_progress_bar()
        return iter_ref_bundles

    @ConsumptionAPI(pattern="Examples:")
    @DeveloperAPI
    def get_internal_block_refs(self) -> List[ObjectRef[Block]]:
        """Get a list of references to the underlying blocks of this dataset.
@@ -4577,11 +4613,11 @@ def get_internal_block_refs(self) -> List[ObjectRef[Block]]:
            >>> ds.get_internal_block_refs()
            [ObjectRef(...)]
        Time complexity: O(1)
        Returns:
            A list of references to this dataset's blocks.
        """
        # TODO(scottjlee): replace get_internal_block_refs() usages with
        # iter_internal_ref_bundles()
        block_refs = self._plan.execute().block_refs
        self._synchronize_progress_bar()
        return block_refs
15 changes: 15 additions & 0 deletions python/ray/data/tests/test_consumption.py
@@ -158,6 +158,21 @@ def test_count_edge_case(ray_start_regular):
    assert actual_count == 5


def test_count_after_partial_execution(ray_start_regular):
    paths = ["example://iris.csv"] * 5
    ds = ray.data.read_csv(paths)
    for batch in ds.iter_batches():
        # Take one batch and break to simulate partial iteration/execution.
        break
    # Row count should be unknown after partial execution.
    assert "num_rows=?" in str(ds)

    # After iterating over bundles and completing execution, row count should be known.
    list(ds.iter_internal_ref_bundles())
    assert f"num_rows={150*5}" in str(ds)
    assert ds.count() == 150 * 5


def test_limit_execution(ray_start_regular):
    last_snapshot = get_initial_core_execution_metrics_snapshot()
    override_num_blocks = 20
17 changes: 17 additions & 0 deletions python/ray/data/tests/test_formats.py
@@ -85,6 +85,23 @@ def test_get_internal_block_refs(ray_start_regular_shared):
    assert out == list(range(10)), out


def test_iter_internal_ref_bundles(ray_start_regular_shared):
    n = 10
    ds = ray.data.range(n, override_num_blocks=n)
    iter_ref_bundles = ds.iter_internal_ref_bundles()

    out = []
    ref_bundle_count = 0
    for ref_bundle in iter_ref_bundles:
        for block_ref, block_md in ref_bundle.blocks:
            b = ray.get(block_ref)
            out.extend(extract_values("id", BlockAccessor.for_block(b).iter_rows(True)))
        ref_bundle_count += 1
    out = sorted(out)
    assert ref_bundle_count == n
    assert out == list(range(n)), out


def test_fsspec_filesystem(ray_start_regular_shared, tmp_path):
"""Same as `test_parquet_write` but using a custom, fsspec filesystem.
