Commit
Enforce read batch size when refreshing materialized views (#82)
* Enforce read batch size when refreshing materialized views

* update doc
Zhou Fang authored Jan 26, 2024
1 parent 87681f4 commit de825aa
Showing 14 changed files with 173 additions and 67 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -239,7 +239,8 @@ mv = view.materialize("/path/to/<mybucket>/example_mv")

mv_runner = mv.ray()
# Refresh the MV up to version tag `after_add` of the source.
mv_runner.refresh("after_add") # mv_runner.refresh() refresh to the latest version
mv_runner.refresh("after_add", batch_size=64) # Reading batch size
# Or, mv_runner.refresh() refresh to the latest version

# Use the MV runner instead of view runner to directly read from materialized
# view files, no data processing any more.
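For context, a minimal usage sketch of the batched read path this commit introduces; it assumes `ds` is an existing Space dataset and `mv_runner` is the materialized-view runner from the README snippet above:

```python
# Assumptions: `ds` is an existing Space dataset with at least two snapshots,
# and `mv_runner` was created via mv.ray() as in the README snippet above.

# Stream source changes between two versions in batches of up to 64 rows each.
for change in ds.local().diff(0, 1, batch_size=64):
    print(change.type_, change.data.num_rows)

# Refresh the materialized view using the same read batch size.
mv_runner.refresh("after_add", batch_size=64)
```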
2 changes: 1 addition & 1 deletion python/README.md
@@ -4,7 +4,7 @@ Unify data in your entire machine learning lifecycle with **Space**, a comprehen

**Key Features:**
- **Ground Truth Database**
- Store and manage data in open source file formats, locally or in the cloud.
- Store and manage multimodal data in open source file formats, row-oriented or columnar, locally or in the cloud.
- Ingest from various sources, including ML datasets, files, and labeling tools.
- Support data manipulation (append, insert, update, delete) and version control.
- **OLAP Database and Lakehouse**
49 changes: 24 additions & 25 deletions python/src/space/core/ops/change_data.py
@@ -19,11 +19,11 @@
from typing import Iterator, List

import pyarrow as pa
from pyroaring import BitMap # type: ignore[import-not-found]

from space.core.fs.base import BaseFileSystem
from space.core.fs.factory import create_fs
from space.core.ops.read import FileSetReadOp
from space.core.options import ReadOptions
import space.core.proto.metadata_pb2 as meta
import space.core.proto.runtime_pb2 as rt
from space.core.storage import Storage
@@ -35,6 +35,7 @@ class ChangeType(Enum):
"""Type of data changes."""
# For added rows.
ADD = 1

# For deleted rows.
DELETE = 2
# TODO: to support UPDATE. UPDATE is currently described as an ADD after a
@@ -46,14 +47,17 @@ class ChangeData:
"""Information and data of a change."""
# Snapshot ID that the change was committed to.
snapshot_id: int

# The change type.
type_: ChangeType

# The change data.
data: pa.Table


def read_change_data(storage: Storage, start_snapshot_id: int,
end_snapshot_id: int) -> Iterator[ChangeData]:
end_snapshot_id: int,
read_options: ReadOptions) -> Iterator[ChangeData]:
"""Read change data from a start to an end snapshot.
start_snapshot_id is excluded; end_snapshot_id is included.
@@ -78,19 +82,22 @@ def read_change_data(storage: Storage, start_snapshot_id: int,
f"snapshot {end_snapshot_id}")

for snapshot_id in all_snapshot_ids[1:]:
for result in iter(_LocalChangeDataReadOp(storage, snapshot_id)):
for result in iter(
_LocalChangeDataReadOp(storage, snapshot_id, read_options)):
yield result


class _LocalChangeDataReadOp(StoragePathsMixin):
"""Read changes of data from a given snapshot of a dataset."""

def __init__(self, storage: Storage, snapshot_id: int):
def __init__(self, storage: Storage, snapshot_id: int,
read_options: ReadOptions):
StoragePathsMixin.__init__(self, storage.location)

self._storage = storage
self._metadata = self._storage.metadata
self._snapshot_id = snapshot_id
self._read_options = read_options

if snapshot_id not in self._metadata.snapshots:
raise errors.VersionNotFoundError(
@@ -108,35 +115,27 @@ def __iter__(self) -> Iterator[ChangeData]:
# TODO: to enforce this check upstream, or merge deletion+addition as an
# update.
for bitmap in self._change_log.deleted_rows:
yield self._read_bitmap_rows(ChangeType.DELETE, bitmap)
for change in self._read_bitmap_rows(ChangeType.DELETE, bitmap):
yield change

for bitmap in self._change_log.added_rows:
yield self._read_bitmap_rows(ChangeType.ADD, bitmap)
for change in self._read_bitmap_rows(ChangeType.ADD, bitmap):
yield change

def _read_bitmap_rows(self, change_type: ChangeType,
bitmap: meta.RowBitmap) -> ChangeData:
bitmap: meta.RowBitmap) -> Iterator[ChangeData]:
file_set = rt.FileSet(index_files=[rt.DataFile(path=bitmap.file)])
read_op = FileSetReadOp(self._storage.location, self._metadata, file_set)

data = pa.concat_tables(list(iter(read_op)))
# TODO: to read index fields first, apply mask, then read record fields.
if not bitmap.all_rows:
data = data.filter(
mask=_bitmap_mask(bitmap.roaring_bitmap, data.num_rows))
read_op = FileSetReadOp(
self._storage.location,
self._metadata,
file_set,
row_bitmap=(None if bitmap.all_rows else bitmap.roaring_bitmap),
options=self._read_options)

return ChangeData(self._snapshot_id, change_type, data)
for data in read_op:
yield ChangeData(self._snapshot_id, change_type, data)


def _read_change_log_proto(fs: BaseFileSystem,
file_path: str) -> meta.ChangeLog:
return fs.read_proto(file_path, meta.ChangeLog())


def _bitmap_mask(serialized_bitmap: bytes, num_rows: int) -> List[bool]:
bitmap = BitMap.deserialize(serialized_bitmap)

mask = [False] * num_rows
for row_id in bitmap.to_array():
mask[row_id] = True

return mask
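A hedged sketch of driving the lower-level change-data API directly, assuming a `storage` handle for an existing dataset; with batched reads, a single added or deleted row bitmap can now yield several `ChangeData` items:

```python
from space.core.ops.change_data import ChangeType, read_change_data
from space.core.options import ReadOptions

# Assumption: `storage` is a space.core.storage.Storage for an existing
# dataset in which snapshot ids 0 and 2 both exist.
changes = read_change_data(storage,
                           start_snapshot_id=0,
                           end_snapshot_id=2,
                           read_options=ReadOptions(batch_size=64))
for change in changes:
    if change.type_ == ChangeType.DELETE:
        ...  # apply deletions before additions of the same snapshot
    else:
        ...  # each ADD arrives as a pa.Table of at most batch_size rows
```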
6 changes: 4 additions & 2 deletions python/src/space/core/ops/insert.py
@@ -121,8 +121,10 @@ def filter_matched(location: str, metadata: meta.StorageMetadata,
data_files: rt.FileSet, filter_: pc.Expression,
primary_keys: List[str]) -> bool:
"""Return True if there are data matching the provided filter."""
op = FileSetReadOp(location, metadata, data_files,
ReadOptions(filter_=filter_, fields=primary_keys))
op = FileSetReadOp(location,
metadata,
data_files,
options=ReadOptions(filter_=filter_, fields=primary_keys))

for data in iter(op):
if data.num_rows > 0:
25 changes: 25 additions & 0 deletions python/src/space/core/ops/read.py
@@ -21,6 +21,7 @@
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
from pyroaring import BitMap # type: ignore[import-not-found]

from space.core.options import ReadOptions
from space.core.fs.array_record import read_record_file
@@ -30,6 +31,7 @@
from space.core.schema import arrow
from space.core.schema.constants import FILE_PATH_FIELD, ROW_ID_FIELD
from space.core.schema import utils as schema_utils
from space.core.utils import errors
from space.core.utils.paths import StoragePathsMixin

_RECORD_KEY_FIELD = "__RECORD_KEY"
@@ -52,17 +54,20 @@ class FileSetReadOp(BaseReadOp, StoragePathsMixin):
Not thread safe.
"""

# pylint: disable=too-many-arguments
def __init__(self,
location: str,
metadata: meta.StorageMetadata,
file_set: rt.FileSet,
row_bitmap: Optional[bytes] = None,
options: Optional[ReadOptions] = None):
StoragePathsMixin.__init__(self, location)

# TODO: to validate that filter_ does not contain record fields.

self._metadata = metadata
self._file_set = file_set
self._row_bitmap = row_bitmap

# TODO: to validate options, e.g., fields are valid.
self._options = options or ReadOptions()
@@ -92,13 +97,23 @@ def __iter__(self) -> Iterator[pa.Table]:
for file in self._file_set.index_files:
row_range_read = file.selected_rows.end > 0

# row_range_read is used by Ray SpaceDataSource. row_bitmap is used by Ray
# diff/refresh, which does not use Ray SpaceDataSource.
if row_range_read and self._row_bitmap is not None:
raise errors.SpaceRuntimeError(
"Row mask is not supported when row range read is enabled")

# TODO: always loading the whole table is inefficient, to only load the
# required row groups.
index_data = pq.read_table(
self.full_path(file.path),
columns=self._selected_fields,
filters=self._options.filter_) # type: ignore[arg-type]

if self._row_bitmap is not None:
index_data = index_data.filter(
mask=_bitmap_mask(self._row_bitmap, index_data.num_rows))

if row_range_read:
length = file.selected_rows.end - file.selected_rows.start
index_data = index_data.slice(file.selected_rows.start, length)
@@ -197,3 +212,13 @@ def read_record_column(paths: StoragePathsMixin,
sorted_values[key.as_py()] = value

return pa.array(sorted_values, pa.binary()) # type: ignore[return-value]


def _bitmap_mask(serialized_bitmap: bytes, num_rows: int) -> List[bool]:
bitmap = BitMap.deserialize(serialized_bitmap)

mask = [False] * num_rows
for row_id in bitmap.to_array():
mask[row_id] = True

return mask
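A standalone illustration of what the `_bitmap_mask` helper computes: a serialized roaring bitmap of selected row ids becomes a boolean mask that pyarrow can filter with (the table and row ids below are made up):

```python
import pyarrow as pa
from pyroaring import BitMap

# As if running inside read.py, where _bitmap_mask is defined.
table = pa.table({"id": [10, 11, 12, 13]})

# Rows 1 and 3 are the selected rows in the serialized bitmap.
serialized = BitMap([1, 3]).serialize()

mask = _bitmap_mask(serialized, table.num_rows)  # [False, True, False, True]
filtered = table.filter(mask=mask)               # keeps ids 11 and 13
```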
5 changes: 4 additions & 1 deletion python/src/space/core/options.py
@@ -47,7 +47,10 @@ class ReadOptions:
#
# TODO: currently a batch can be smaller than batch_size (e.g., at boundary
# of row groups), to enforce size to be equal to batch_size.
batch_size: Optional[int] = 16
batch_size: Optional[int] = None

def __post_init__(self):
self.batch_size = self.batch_size or 16


@dataclass
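A short sketch of the new default behavior: an unset `batch_size` falls back to 16 in `__post_init__`, while an explicit value is kept, so callers such as `diff(..., batch_size=None)` retain the previous default:

```python
from space.core.options import ReadOptions

assert ReadOptions().batch_size == 16                 # default filled in __post_init__
assert ReadOptions(batch_size=None).batch_size == 16  # None also falls back to 16
assert ReadOptions(batch_size=64).batch_size == 64    # explicit value preserved
```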
2 changes: 2 additions & 0 deletions python/src/space/core/proto/metadata.proto
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

// Proto messages used by Space metadata persistence.

syntax = "proto3";

package space.proto;
10 changes: 10 additions & 0 deletions python/src/space/core/proto/runtime.proto
@@ -12,6 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.

// Proto messages used by Space runtime.
//
// Different from metadata.proto, protos here are not persisted in metadata
// files. We use proto instead of Python classes for the capabilities of
// serialization to bytes for cross-machine/cross-language messaging. For
// example, `FileSet` is sent to a worker machine for processing, and `Patch`
// is sent back for the coordinator machine to commit to storage. Pickling
// Python classes may work, but it has more restrictions, especially when
// crossing languages.

syntax = "proto3";

import "space/core/proto/metadata.proto";
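A hedged illustration of the serialization capability the new comment describes, using the runtime `FileSet` message; the file path is made up for the example:

```python
import space.core.proto.runtime_pb2 as rt

# Coordinator side: describe the files a worker should read (path is hypothetical).
file_set = rt.FileSet(index_files=[rt.DataFile(path="data/index_000.parquet")])
payload = file_set.SerializeToString()  # bytes, safe to ship across machines

# Worker side: rebuild the message from bytes.
received = rt.FileSet()
received.ParseFromString(payload)
assert received.index_files[0].path == "data/index_000.parquet"
```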
18 changes: 11 additions & 7 deletions python/src/space/core/runners.py
@@ -210,17 +210,21 @@ def read(
batch_size)

return iter(
FileSetReadOp(
self._storage.location, self._storage.metadata,
self._storage.data_files(filter_, snapshot_id=snapshot_id),
read_options))
FileSetReadOp(self._storage.location,
self._storage.metadata,
self._storage.data_files(filter_,
snapshot_id=snapshot_id),
options=read_options))

@StorageMixin.reload
def diff(self, start_version: Version,
end_version: Version) -> Iterator[ChangeData]:
def diff(self,
start_version: Version,
end_version: Version,
batch_size: Optional[int] = None) -> Iterator[ChangeData]:
return read_change_data(self._storage,
self._storage.version_to_snapshot_id(start_version),
self._storage.version_to_snapshot_id(end_version))
self._storage.version_to_snapshot_id(end_version),
ReadOptions(batch_size=batch_size))

@StorageMixin.transactional
def append(self, data: InputData) -> Optional[rt.Patch]:
5 changes: 3 additions & 2 deletions python/src/space/ray/data_sources.py
@@ -104,8 +104,9 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
return read_tasks

def _read_fn(self, index_file: rt.DataFile) -> Callable[..., Iterator[Block]]:
return partial(FileSetReadOp, self._storage.location,
self._storage.metadata, rt.FileSet(index_files=[index_file]),
return partial(FileSetReadOp,
self._storage.location, self._storage.metadata,
rt.FileSet(index_files=[index_file]), None,
self._read_options) # type: ignore[return-value]


14 changes: 9 additions & 5 deletions python/src/space/ray/runners.py
@@ -86,13 +86,16 @@ def read(
join_options).to_arrow_refs():
yield ray.get(ref)

def diff(self, start_version: Union[Version],
end_version: Union[Version]) -> Iterator[ChangeData]:
def diff(self,
start_version: Union[Version],
end_version: Union[Version],
batch_size: Optional[int] = None) -> Iterator[ChangeData]:
self._source_storage.reload()
source_changes = read_change_data(
self._source_storage,
self._source_storage.version_to_snapshot_id(start_version),
self._source_storage.version_to_snapshot_id(end_version))
self._source_storage.version_to_snapshot_id(end_version),
ReadOptions(batch_size=batch_size))

for change in source_changes:
# TODO: skip processing the data for deletions; the caller is usually
@@ -151,7 +154,8 @@ def read(
yield ray.get(ref)

def refresh(self,
target_version: Optional[Version] = None) -> List[JobResult]:
target_version: Optional[Version] = None,
batch_size: Optional[int] = None) -> List[JobResult]:
"""Refresh the materialized view by synchronizing from source dataset."""
source_snapshot_id = self._source_storage.metadata.current_snapshot_id
if target_version is None:
Expand All @@ -171,7 +175,7 @@ def refresh(self,
previous_snapshot_id: Optional[int] = None

txn = self._start_txn()
for change in self.diff(start_snapshot_id, end_snapshot_id):
for change in self.diff(start_snapshot_id, end_snapshot_id, batch_size):
# Commit when changes from the same snapshot end.
if (previous_snapshot_id is not None and
change.snapshot_id != previous_snapshot_id):
2 changes: 1 addition & 1 deletion python/tests/core/ops/test_read.py
@@ -56,7 +56,7 @@ def test_read_all_types(self, tmp_path, all_types_schema,
storage.metadata,
storage.data_files(),
# pylint: disable=singleton-comparison
ReadOptions(filter_=pc.field("bool") == True))
options=ReadOptions(filter_=pc.field("bool") == True))
results = list(iter(read_op))
assert len(results) == 1
assert list(iter(read_op))[0] == pa.Table.from_pydict({
12 changes: 12 additions & 0 deletions python/tests/core/test_runners.py
@@ -152,6 +152,18 @@ def test_read_and_write_should_reload_storage(self, sample_dataset):
assert local_runner1.read_all() == pa.concat_tables(
[sample_data1, sample_data2, sample_data3, sample_data4])

def test_diff_batch_size(self, sample_dataset):
ds = sample_dataset

ds.local().append(_generate_data(range(5)))

assert list(ds.local().diff(0, 1, batch_size=3)) == [
ChangeData(ds.storage.metadata.current_snapshot_id, ChangeType.ADD,
_generate_data(range(3))),
ChangeData(ds.storage.metadata.current_snapshot_id, ChangeType.ADD,
_generate_data(range(3, 5)))
]

def test_add_read_remove_tag(self, sample_dataset):
ds = sample_dataset
local_runner = ds.local()