Skip to content

Commit

Permalink
Use Ray iter_batches instead of to_arrow_refs
Browse files Browse the repository at this point in the history
  • Loading branch information
coufon committed Jan 28, 2024
1 parent 8fc5ca5 commit a3f7887
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 24 deletions.
11 changes: 2 additions & 9 deletions python/src/space/ray/ops/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
row_id_field_name)
import space.core.transform.utils as transform_utils
from space.core.utils import errors
from space.ray.ops.utils import singleton_storage
from space.ray.ops.utils import iter_batches, singleton_storage
from space.ray.options import RayOptions

if TYPE_CHECKING:
Expand Down Expand Up @@ -138,14 +138,7 @@ def _join(left: _JoinInputInternal, right: _JoinInputInternal, join_key: str,


def _read_all(ds: ray.data.Dataset) -> Optional[pa.Table]:
results = []
for ref in ds.to_arrow_refs():
data = ray.get(ref)
if data is None or data.num_rows == 0:
continue

results.append(data)

results = list(iter_batches(ds))
if not results:
return None

Expand Down
15 changes: 14 additions & 1 deletion python/src/space/ray/ops/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
"""Utilities for Ray operations."""

from __future__ import annotations
from typing import TYPE_CHECKING
from typing import Iterator, TYPE_CHECKING

import ray
import pyarrow as pa

from space.core.storage import Storage
from space.core.utils import errors
Expand All @@ -35,3 +38,13 @@ def singleton_storage(view: View) -> Storage:
raise errors.UserInputError("Joining results of joins is not supported")

return list(view.sources.values())[0].storage


def iter_batches(ds: ray.data.Dataset) -> Iterator[pa.Table]:
"""Return an iterator of PyArrow tables from a Ray dataset."""
# batch_size is None to use entire Ray blocks.
for data in ds.iter_batches(batch_size=None,
batch_format="pyarrow",
drop_last=False):
if data.num_rows > 0:
yield data
21 changes: 7 additions & 14 deletions python/src/space/ray/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import pyarrow as pa
import pyarrow.compute as pc
import ray

from space.core.jobs import JobResult
from space.core.loaders.array_record import ArrayRecordIndexFn
Expand All @@ -41,7 +40,7 @@
from space.ray.ops.change_data import read_change_data
from space.ray.ops.delete import RayDeleteOp
from space.ray.ops.insert import RayInsertOp
from space.ray.ops.utils import singleton_storage
from space.ray.ops.utils import iter_batches, singleton_storage
from space.ray.options import RayOptions

if TYPE_CHECKING:
Expand Down Expand Up @@ -83,9 +82,8 @@ def read(
read_options = ReadOptions(filter_, fields, snapshot_id, reference_read,
batch_size)

for ref in self._view.ray_dataset(self._ray_options, read_options,
join_options).to_arrow_refs():
yield ray.get(ref)
return iter_batches(
self._view.ray_dataset(self._ray_options, read_options, join_options))

def diff(self,
start_version: Union[Version],
Expand All @@ -102,11 +100,8 @@ def diff(self,
# TODO: skip processing the data for deletions; the caller is usually
# only interested at deleted primary keys.
# TODO: to split change data into chunks for parallel processing.
processed_remote_data = self._view.process_source(change.data)
for ref in processed_remote_data.to_arrow_refs():
data = ray.get(ref)
if data.num_rows > 0:
yield ChangeData(change.snapshot_id, change.type_, data)
for data in iter_batches(self._view.process_source(change.data)):
yield ChangeData(change.snapshot_id, change.type_, data)

@property
def _source_storage(self) -> Storage:
Expand Down Expand Up @@ -150,10 +145,8 @@ def read(
self._storage.version_to_snapshot_id(version))
read_options = ReadOptions(filter_, fields, snapshot_id, reference_read,
batch_size)

for ref in self._storage.ray_dataset(self._ray_options,
read_options).to_arrow_refs():
yield ray.get(ref)
return iter_batches(
self._storage.ray_dataset(self._ray_options, read_options))

def refresh(self,
target_version: Optional[Version] = None,
Expand Down

0 comments on commit a3f7887

Please sign in to comment.