Use Ray iter_batches instead of to_arrow_refs #84

Merged: 1 commit, Jan 28, 2024
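This PR replaces the to_arrow_refs() + ray.get() read pattern with Ray Dataset's iter_batches() across the Ray read paths. A minimal sketch contrasting the two patterns, assuming an illustrative in-memory dataset (not taken from this repo):

import ray
import pyarrow as pa

ds = ray.data.from_arrow(pa.table({"id": [1, 2, 3]}))

# Old pattern: materialize one object ref per block, then fetch each table
# with ray.get and skip empty tables by hand.
for ref in ds.to_arrow_refs():
  table = ray.get(ref)

# New pattern: stream blocks directly as PyArrow tables; batch_size=None
# yields one batch per underlying block.
for table in ds.iter_batches(batch_size=None, batch_format="pyarrow"):
  ...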
11 changes: 2 additions & 9 deletions python/src/space/ray/ops/join.py
@@ -29,7 +29,7 @@
                                row_id_field_name)
 import space.core.transform.utils as transform_utils
 from space.core.utils import errors
-from space.ray.ops.utils import singleton_storage
+from space.ray.ops.utils import iter_batches, singleton_storage
 from space.ray.options import RayOptions
 
 if TYPE_CHECKING:
@@ -138,14 +138,7 @@ def _join(left: _JoinInputInternal, right: _JoinInputInternal, join_key: str,
 
 
 def _read_all(ds: ray.data.Dataset) -> Optional[pa.Table]:
-  results = []
-  for ref in ds.to_arrow_refs():
-    data = ray.get(ref)
-    if data is None or data.num_rows == 0:
-      continue
-
-    results.append(data)
-
+  results = list(iter_batches(ds))
   if not results:
     return None
 
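For reference, a sketch of the complete simplified _read_all; everything past "return None" is elided above, so the concatenation tail is an assumption inferred from the Optional[pa.Table] return type:

from typing import Optional

import pyarrow as pa
import ray

from space.ray.ops.utils import iter_batches

def _read_all_sketch(ds: ray.data.Dataset) -> Optional[pa.Table]:
  # Drain the dataset into a list of non-empty PyArrow tables.
  results = list(iter_batches(ds))
  if not results:
    return None
  # Assumed tail: merge the per-block tables into a single table.
  return pa.concat_tables(results)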
15 changes: 14 additions & 1 deletion python/src/space/ray/ops/utils.py
@@ -15,7 +15,10 @@
 """Utilities for Ray operations."""
 
 from __future__ import annotations
-from typing import TYPE_CHECKING
+from typing import Iterator, TYPE_CHECKING
+
+import ray
+import pyarrow as pa
 
 from space.core.storage import Storage
 from space.core.utils import errors
@@ -35,3 +38,13 @@ def singleton_storage(view: View) -> Storage:
     raise errors.UserInputError("Joining results of joins is not supported")
 
   return list(view.sources.values())[0].storage
+
+
+def iter_batches(ds: ray.data.Dataset) -> Iterator[pa.Table]:
+  """Return an iterator of PyArrow tables from a Ray dataset."""
+  # batch_size is None to use entire Ray blocks.
+  for data in ds.iter_batches(batch_size=None,
+                              batch_format="pyarrow",
+                              drop_last=False):
+    if data.num_rows > 0:
+      yield data
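A minimal usage sketch of the new helper (the dataset construction is illustrative):

import ray
import pyarrow as pa

from space.ray.ops.utils import iter_batches

ds = ray.data.from_arrow(pa.table({"id": [1, 2, 3]}))
for table in iter_batches(ds):
  assert table.num_rows > 0  # empty blocks are filtered out by the helper

With batch_size=None, Ray yields one batch per underlying block, avoiding the re-batching copy a fixed batch_size can incur; drop_last=False should be a no-op at full-block granularity and is set only to make the intent explicit.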
21 changes: 7 additions & 14 deletions python/src/space/ray/runners.py
@@ -21,7 +20,6 @@
 
 import pyarrow as pa
 import pyarrow.compute as pc
-import ray
 
 from space.core.jobs import JobResult
 from space.core.loaders.array_record import ArrayRecordIndexFn
@@ -41,7 +40,7 @@
 from space.ray.ops.change_data import read_change_data
 from space.ray.ops.delete import RayDeleteOp
 from space.ray.ops.insert import RayInsertOp
-from space.ray.ops.utils import singleton_storage
+from space.ray.ops.utils import iter_batches, singleton_storage
 from space.ray.options import RayOptions
 
 if TYPE_CHECKING:
@@ -83,9 +82,8 @@ def read(
     read_options = ReadOptions(filter_, fields, snapshot_id, reference_read,
                                batch_size)
 
-    for ref in self._view.ray_dataset(self._ray_options, read_options,
-                                      join_options).to_arrow_refs():
-      yield ray.get(ref)
+    return iter_batches(
+        self._view.ray_dataset(self._ray_options, read_options, join_options))
 
   def diff(self,
            start_version: Union[Version],
@@ -102,11 +100,8 @@ def diff(self,
     # TODO: skip processing the data for deletions; the caller is usually
     # only interested at deleted primary keys.
     # TODO: to split change data into chunks for parallel processing.
-    processed_remote_data = self._view.process_source(change.data)
-    for ref in processed_remote_data.to_arrow_refs():
-      data = ray.get(ref)
-      if data.num_rows > 0:
-        yield ChangeData(change.snapshot_id, change.type_, data)
+    for data in iter_batches(self._view.process_source(change.data)):
+      yield ChangeData(change.snapshot_id, change.type_, data)
 
   @property
   def _source_storage(self) -> Storage:
@@ -150,10 +145,8 @@ def read(
         self._storage.version_to_snapshot_id(version))
     read_options = ReadOptions(filter_, fields, snapshot_id, reference_read,
                                batch_size)
 
-    for ref in self._storage.ray_dataset(self._ray_options,
-                                         read_options).to_arrow_refs():
-      yield ray.get(ref)
+    return iter_batches(
+        self._storage.ray_dataset(self._ray_options, read_options))
 
   def refresh(self,
               target_version: Optional[Version] = None,
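One behavioral nuance of the rewritten read() methods: the old bodies contained yield, so they were generator functions whose code ran only when iteration started; the new bodies are plain functions, so building ReadOptions and calling ray_dataset(...) now happen eagerly at call time, while iteration through the returned iterator stays lazy. A standalone illustration of the difference (not repo code):

def as_generator():
  print("runs on first next(), not at call time")
  yield 1

def as_plain_function():
  print("runs immediately when called")
  return iter([1, 2])

# Call sites consume both the same way, so existing loops are unaffected:
for x in as_plain_function():
  print(x)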