diff --git a/python/src/space/ray/ops/join.py b/python/src/space/ray/ops/join.py index afc248d..867fe64 100644 --- a/python/src/space/ray/ops/join.py +++ b/python/src/space/ray/ops/join.py @@ -29,7 +29,7 @@ row_id_field_name) import space.core.transform.utils as transform_utils from space.core.utils import errors -from space.ray.ops.utils import singleton_storage +from space.ray.ops.utils import iter_batches, singleton_storage from space.ray.options import RayOptions if TYPE_CHECKING: @@ -138,14 +138,7 @@ def _join(left: _JoinInputInternal, right: _JoinInputInternal, join_key: str, def _read_all(ds: ray.data.Dataset) -> Optional[pa.Table]: - results = [] - for ref in ds.to_arrow_refs(): - data = ray.get(ref) - if data is None or data.num_rows == 0: - continue - - results.append(data) - + results = list(iter_batches(ds)) if not results: return None diff --git a/python/src/space/ray/ops/utils.py b/python/src/space/ray/ops/utils.py index 2938d5c..745ed32 100644 --- a/python/src/space/ray/ops/utils.py +++ b/python/src/space/ray/ops/utils.py @@ -15,7 +15,10 @@ """Utilities for Ray operations.""" from __future__ import annotations -from typing import TYPE_CHECKING +from typing import Iterator, TYPE_CHECKING + +import ray +import pyarrow as pa from space.core.storage import Storage from space.core.utils import errors @@ -35,3 +38,13 @@ def singleton_storage(view: View) -> Storage: raise errors.UserInputError("Joining results of joins is not supported") return list(view.sources.values())[0].storage + + +def iter_batches(ds: ray.data.Dataset) -> Iterator[pa.Table]: + """Return an iterator of PyArrow tables from a Ray dataset.""" + # batch_size is None to use entire Ray blocks. + for data in ds.iter_batches(batch_size=None, + batch_format="pyarrow", + drop_last=False): + if data.num_rows > 0: + yield data diff --git a/python/src/space/ray/runners.py b/python/src/space/ray/runners.py index 2650e20..b5aecc6 100644 --- a/python/src/space/ray/runners.py +++ b/python/src/space/ray/runners.py @@ -21,7 +21,6 @@ import pyarrow as pa import pyarrow.compute as pc -import ray from space.core.jobs import JobResult from space.core.loaders.array_record import ArrayRecordIndexFn @@ -41,7 +40,7 @@ from space.ray.ops.change_data import read_change_data from space.ray.ops.delete import RayDeleteOp from space.ray.ops.insert import RayInsertOp -from space.ray.ops.utils import singleton_storage +from space.ray.ops.utils import iter_batches, singleton_storage from space.ray.options import RayOptions if TYPE_CHECKING: @@ -83,9 +82,8 @@ def read( read_options = ReadOptions(filter_, fields, snapshot_id, reference_read, batch_size) - for ref in self._view.ray_dataset(self._ray_options, read_options, - join_options).to_arrow_refs(): - yield ray.get(ref) + return iter_batches( + self._view.ray_dataset(self._ray_options, read_options, join_options)) def diff(self, start_version: Union[Version], @@ -102,11 +100,8 @@ def diff(self, # TODO: skip processing the data for deletions; the caller is usually # only interested at deleted primary keys. # TODO: to split change data into chunks for parallel processing. - processed_remote_data = self._view.process_source(change.data) - for ref in processed_remote_data.to_arrow_refs(): - data = ray.get(ref) - if data.num_rows > 0: - yield ChangeData(change.snapshot_id, change.type_, data) + for data in iter_batches(self._view.process_source(change.data)): + yield ChangeData(change.snapshot_id, change.type_, data) @property def _source_storage(self) -> Storage: @@ -150,10 +145,8 @@ def read( self._storage.version_to_snapshot_id(version)) read_options = ReadOptions(filter_, fields, snapshot_id, reference_read, batch_size) - - for ref in self._storage.ray_dataset(self._ray_options, - read_options).to_arrow_refs(): - yield ray.get(ref) + return iter_batches( + self._storage.ray_dataset(self._ray_options, read_options)) def refresh(self, target_version: Optional[Version] = None,