Skip to content

Commit

Permalink
[Datasets] Streaming executor fixes ray-project#3 (ray-project#32836)
Browse files Browse the repository at this point in the history
Signed-off-by: Jack He <jackhe2345@gmail.com>
  • Loading branch information
jianoaix authored and ProjectsByJackHe committed Mar 21, 2023
1 parent 09ae0ba commit 916e318
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .buildkite/pipeline.ml.yml
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@
- sudo service mongodb stop
- sudo apt-get purge -y mongodb*

- label: "[unstable] Dataset tests (streaming executor)"
- label: "Dataset tests (streaming executor)"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_DATA_AFFECTED"]
instance_size: medium
commands:
Expand Down
5 changes: 4 additions & 1 deletion python/ray/data/_internal/execution/legacy_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def _blocks_to_input_buffer(blocks: BlockList, owns_blocks: bool) -> PhysicalOpe

if hasattr(blocks, "_tasks"):
read_tasks = blocks._tasks
remote_args = blocks._remote_args
assert all(isinstance(t, ReadTask) for t in read_tasks), read_tasks
inputs = InputDataBuffer(
[
Expand Down Expand Up @@ -157,7 +158,9 @@ def do_read(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
for read_task in blocks:
yield from read_task()

return MapOperator.create(do_read, inputs, name="DoRead")
return MapOperator.create(
do_read, inputs, name="DoRead", ray_remote_args=remote_args
)
else:
output = _block_list_to_bundles(blocks, owns_blocks=owns_blocks)
for i in output:
Expand Down
6 changes: 5 additions & 1 deletion python/ray/data/_internal/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,10 @@ def clear_block_refs(self) -> None:
This will render the plan un-executable unless the root is a LazyBlockList."""
self._in_blocks.clear()
self._clear_snapshot()

def _clear_snapshot(self) -> None:
"""Clear the snapshot kept in the plan to the beginning state."""
self._snapshot_blocks = None
self._snapshot_stats = None
# We're erasing the snapshot, so put all stages into the "after snapshot"
Expand Down Expand Up @@ -691,7 +695,7 @@ def _get_source_blocks_and_stages(
stats = self._snapshot_stats
# Unlink the snapshot blocks from the plan so we can eagerly reclaim the
# snapshot block memory after the first stage is done executing.
self._snapshot_blocks = None
self._clear_snapshot()
else:
# Snapshot exists but has been cleared, so we need to recompute from the
# source (input blocks).
Expand Down
9 changes: 8 additions & 1 deletion python/ray/data/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1591,7 +1591,14 @@ def test_convert_types(ray_start_regular_shared):

arrow_ds = ray.data.range_table(1)
assert arrow_ds.map(lambda x: "plain_{}".format(x["value"])).take() == ["plain_0"]
assert arrow_ds.map(lambda x: {"a": (x["value"],)}).take() == [{"a": [0]}]
# In streaming, we set batch_format to "default" (because calling
# ds.dataset_format() will still invoke bulk execution and we want
# to avoid that). As a result, it's receiving PandasRow (the defaut
# batch format), which unwraps [0] to plain 0.
if ray.data.context.DatasetContext.get_current().use_streaming_executor:
assert arrow_ds.map(lambda x: {"a": (x["value"],)}).take() == [{"a": 0}]
else:
assert arrow_ds.map(lambda x: {"a": (x["value"],)}).take() == [{"a": [0]}]


def test_from_items(ray_start_regular_shared):
Expand Down
22 changes: 19 additions & 3 deletions python/ray/data/tests/test_dataset_tfrecords.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,29 @@ def test_readback_tfrecords(ray_start_regular_shared, tmp_path):
# for type inference involving partially missing columns.
parallelism=1,
)

# Write the TFRecords.
ds.write_tfrecords(tmp_path)

# Read the TFRecords.
readback_ds = ray.data.read_tfrecords(tmp_path)
assert ds.take() == readback_ds.take()
if not ray.data.context.DatasetContext.get_current().use_streaming_executor:
assert ds.take() == readback_ds.take()
else:
# In streaming, we set batch_format to "default" (because calling
# ds.dataset_format() will still invoke bulk execution and we want
# to avoid that). As a result, it's receiving PandasRow (the defaut
# batch format), which doesn't have the same ordering of columns as
# the ArrowRow.
from ray.data.block import BlockAccessor

def get_rows(ds):
rows = []
for batch in ds.iter_batches(batch_size=None, batch_format="pyarrow"):
batch = BlockAccessor.for_block(BlockAccessor.batch_to_block(batch))
for row in batch.iter_rows():
rows.append(row)
return rows

assert get_rows(ds) == get_rows(readback_ds)


def test_write_invalid_tfrecords(ray_start_regular_shared, tmp_path):
Expand Down
53 changes: 43 additions & 10 deletions python/ray/data/tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,18 @@ def test_dataset_stats_basic(ray_start_regular_shared, enable_auto_log_stats):
context.optimize_fuse_stages = True

if context.new_execution_backend:
logger = DatasetLogger("ray.data._internal.execution.bulk_executor").get_logger(
log_to_stdout=enable_auto_log_stats,
)
if context.use_streaming_executor:
logger = DatasetLogger(
"ray.data._internal.execution.streaming_executor"
).get_logger(
log_to_stdout=enable_auto_log_stats,
)
else:
logger = DatasetLogger(
"ray.data._internal.execution.bulk_executor"
).get_logger(
log_to_stdout=enable_auto_log_stats,
)
else:
logger = DatasetLogger("ray.data._internal.plan").get_logger(
log_to_stdout=enable_auto_log_stats,
Expand Down Expand Up @@ -111,9 +120,24 @@ def test_dataset_stats_basic(ray_start_regular_shared, enable_auto_log_stats):
stats = canonicalize(ds.fully_executed().stats())

if context.new_execution_backend:
assert (
stats
== """Stage N read->MapBatches(dummy_map_batches): N/N blocks executed in T
if context.use_streaming_executor:
assert (
stats
== """Stage N read->MapBatches(dummy_map_batches)->map: N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): N min, N max, N mean
* Output num rows: N min, N max, N mean, N total
* Output size bytes: N min, N max, N mean, N total
* Tasks per node: N min, N max, N mean; N nodes used
* Extra metrics: {'obj_store_mem_alloc': N, 'obj_store_mem_freed': N, \
'obj_store_mem_peak': N}
"""
)
else:
assert (
stats
== """Stage N read->MapBatches(dummy_map_batches): N/N blocks executed in T
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): N min, N max, N mean
Expand Down Expand Up @@ -141,7 +165,7 @@ def test_dataset_stats_basic(ray_start_regular_shared, enable_auto_log_stats):
* In user code: T
* Total time: T
"""
)
)
else:
assert (
stats
Expand Down Expand Up @@ -364,9 +388,18 @@ def test_dataset_pipeline_stats_basic(ray_start_regular_shared, enable_auto_log_
context.optimize_fuse_stages = True

if context.new_execution_backend:
logger = DatasetLogger("ray.data._internal.execution.bulk_executor").get_logger(
log_to_stdout=enable_auto_log_stats,
)
if context.use_streaming_executor:
logger = DatasetLogger(
"ray.data._internal.execution.streaming_executor"
).get_logger(
log_to_stdout=enable_auto_log_stats,
)
else:
logger = DatasetLogger(
"ray.data._internal.execution.bulk_executor"
).get_logger(
log_to_stdout=enable_auto_log_stats,
)
else:
logger = DatasetLogger("ray.data._internal.plan").get_logger(
log_to_stdout=enable_auto_log_stats,
Expand Down

0 comments on commit 916e318

Please sign in to comment.