From 5a6a839a42f1df9ff1f56fde3a14f34f3bb0ba6b Mon Sep 17 00:00:00 2001 From: coufon Date: Tue, 26 Dec 2023 04:23:47 +0000 Subject: [PATCH] Add a TFDS data source for Space dataset --- python/src/space/core/datasets.py | 10 ++ python/src/space/core/ops/__init__.py | 2 +- python/src/space/core/ops/read.py | 39 ++++- python/src/space/core/runners.py | 28 +++- python/src/space/core/serializers/base.py | 35 +++-- python/src/space/core/storage.py | 15 +- python/src/space/tf/data_sources.py | 136 ++++++++++++++++++ python/tests/core/ops/test_read.py | 4 +- .../core/schema/types/test_tf_features.py | 4 +- python/tests/core/test_runners.py | 2 +- python/tests/tf/test_data_sources.py | 85 +++++++++++ 11 files changed, 330 insertions(+), 30 deletions(-) create mode 100644 python/src/space/tf/data_sources.py create mode 100644 python/tests/tf/test_data_sources.py diff --git a/python/src/space/core/datasets.py b/python/src/space/core/datasets.py index 21d7498..868226b 100644 --- a/python/src/space/core/datasets.py +++ b/python/src/space/core/datasets.py @@ -20,6 +20,7 @@ import pyarrow as pa from space.core.runners import LocalRunner +from space.core.serializers.base import DictSerializer from space.core.storage import Storage @@ -48,6 +49,15 @@ def load(cls, location: str) -> Dataset: """Load an existing dataset from the given location.""" return Dataset(Storage.load(location)) + @property + def schema(self) -> pa.Schema: + """Return the dataset schema.""" + return self._storage.logical_schema + + def serializer(self) -> DictSerializer: + """Return a serializer (deserializer) for the dataset.""" + return DictSerializer(self.schema) + def local(self) -> LocalRunner: """Get a runner that runs operations locally.""" return LocalRunner(self._storage) diff --git a/python/src/space/core/ops/__init__.py b/python/src/space/core/ops/__init__.py index 6d7f9bf..93f46aa 100644 --- a/python/src/space/core/ops/__init__.py +++ b/python/src/space/core/ops/__init__.py @@ -16,4 +16,4 @@ from space.core.ops.append import LocalAppendOp from space.core.ops.delete import FileSetDeleteOp -from space.core.ops.read import FileSetReadOp +from space.core.ops.read import FileSetReadOp, ReadOptions diff --git a/python/src/space/core/ops/read.py b/python/src/space/core/ops/read.py index 529559b..6915657 100644 --- a/python/src/space/core/ops/read.py +++ b/python/src/space/core/ops/read.py @@ -18,6 +18,7 @@ from abc import abstractmethod from typing import Iterator, Dict, List, Tuple, Optional +from dataclasses import dataclass import numpy as np import pyarrow as pa import pyarrow.parquet as pq @@ -35,6 +36,18 @@ _RECORD_KEY_FIELD = "__RECORD_KEY" +@dataclass +class ReadOptions: + """Options of reading data.""" + # Filters on index fields. + filter_: Optional[pc.Expression] = None + # When specified, only read the given fields instead of all fields. + fields: Optional[List[str]] = None + # If true, read the references (e.g., address) of read record fields instead + # of values. + reference_read: bool = False + + class BaseReadOp(BaseOp): """Abstract base read operation class.""" @@ -56,7 +69,7 @@ def __init__(self, location: str, metadata: meta.StorageMetadata, file_set: runtime.FileSet, - filter_: Optional[pc.Expression] = None): + options: Optional[ReadOptions] = None): StoragePaths.__init__(self, location) # TODO: to validate that filter_ does not contain record files. @@ -64,12 +77,23 @@ def __init__(self, self._metadata = metadata self._file_set = file_set + # TODO: to validate options, e.g., fields are valid. 
+ self._options = ReadOptions() if options is None else options + record_fields = set(self._metadata.schema.record_fields) self._physical_schema = arrow.arrow_schema(self._metadata.schema.fields, record_fields, physical=True) + + if self._options.fields is None: + self._selected_fields = [f.name for f in self._physical_schema] + else: + self._selected_fields = self._options.fields + self._index_fields, self._record_fields = arrow.classify_fields( - self._physical_schema, record_fields, selected_fields=None) + self._physical_schema, + record_fields, + selected_fields=set(self._selected_fields)) self._index_field_ids = set(schema_utils.field_ids(self._index_fields)) @@ -77,15 +101,18 @@ def __init__(self, for f in self._record_fields: self._record_fields_dict[f.field_id] = f - self._filter = filter_ - def __iter__(self) -> Iterator[pa.Table]: for file in self._file_set.index_files: yield self._read_index_and_record(file.path) def _read_index_and_record(self, index_path: str) -> pa.Table: - index_data = pq.read_table(self.full_path(index_path), - filters=self._filter) # type: ignore[arg-type] + index_data = pq.read_table( + self.full_path(index_path), + columns=self._selected_fields, + filters=self._options.filter_) # type: ignore[arg-type] + + if self._options.reference_read: + return index_data index_column_ids: List[int] = [] record_columns: List[Tuple[int, pa.Field]] = [] diff --git a/python/src/space/core/runners.py b/python/src/space/core/runners.py index 181cdfc..46fe4a9 100644 --- a/python/src/space/core/runners.py +++ b/python/src/space/core/runners.py @@ -16,13 +16,16 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Iterator, Optional +from typing import Iterator, List, Optional from absl import logging # type: ignore[import-untyped] import pyarrow as pa import pyarrow.compute as pc -from space.core.ops import FileSetDeleteOp, FileSetReadOp, LocalAppendOp +from space.core.ops import FileSetDeleteOp +from space.core.ops import FileSetReadOp +from space.core.ops import LocalAppendOp +from space.core.ops import ReadOptions from space.core.ops.base import InputData import space.core.proto.runtime_pb2 as runtime from space.core.storage import Storage @@ -38,9 +41,20 @@ def __init__(self, storage: Storage): @abstractmethod def read(self, filter_: Optional[pc.Expression] = None, - snapshot_id: Optional[int] = None) -> Iterator[pa.Table]: + fields: Optional[List[str]] = None, + snapshot_id: Optional[int] = None, + reference_read: bool = False) -> Iterator[pa.Table]: """Read data from the dataset as an iterator.""" + def read_all(self, + filter_: Optional[pc.Expression] = None, + fields: Optional[List[str]] = None, + snapshot_id: Optional[int] = None, + reference_read: bool = False) -> pa.Table: + """Read data from the dataset as an Arrow table.""" + return pa.concat_tables( + list(self.read(filter_, fields, snapshot_id, reference_read))) + @abstractmethod def append(self, data: InputData) -> runtime.JobResult: """Append data into the dataset.""" @@ -72,12 +86,16 @@ class LocalRunner(BaseRunner): def read(self, filter_: Optional[pc.Expression] = None, - snapshot_id: Optional[int] = None) -> Iterator[pa.Table]: + fields: Optional[List[str]] = None, + snapshot_id: Optional[int] = None, + reference_read: bool = False) -> Iterator[pa.Table]: return iter( FileSetReadOp( self._storage.location, self._storage.metadata, self._storage.data_files(filter_, snapshot_id=snapshot_id), - filter_)) + ReadOptions(filter_=filter_, + fields=fields, + 
reference_read=reference_read))) def append(self, data: InputData) -> runtime.JobResult: op = LocalAppendOp(self._storage.location, self._storage.metadata) diff --git a/python/src/space/core/serializers/base.py b/python/src/space/core/serializers/base.py index fc76c10..46d9956 100644 --- a/python/src/space/core/serializers/base.py +++ b/python/src/space/core/serializers/base.py @@ -15,7 +15,7 @@ """Serializers (and deserializers) for unstructured record fields.""" from abc import ABC, abstractmethod -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from typing_extensions import TypeAlias import pyarrow as pa @@ -63,17 +63,28 @@ def __init__(self, logical_schema: pa.Schema): if isinstance(field.type, FieldSerializer): self._serializers[field.name] = field.type + def field_serializer(self, field: str) -> Optional[FieldSerializer]: + """Return the FieldSerializer of a given field, or None if not found.""" + if field not in self._serializers: + return None + + return self._serializers[field] + def serialize(self, value: DictData) -> DictData: """Serialize a value. Args: value: a dict of numpy-like nested dicts. """ - for name, ser in self._serializers.items(): - if name in value: - value[name] = [ser.serialize(d) for d in value[name]] + result = {} + for field_name, value_batch in value.items(): + if field_name in self._serializers: + ser = self._serializers[field_name] + result[field_name] = [ser.serialize(v) for v in value_batch] + else: + result[field_name] = value_batch - return value + return result def deserialize(self, value_bytes: DictData) -> DictData: """Deserialize a dict of bytes to a dict of values. @@ -81,8 +92,12 @@ def deserialize(self, value_bytes: DictData) -> DictData: Returns: A dict of numpy-like nested dicts. """ - for name, ser in self._serializers.items(): - if name in value_bytes: - value_bytes[name] = [ser.deserialize(d) for d in value_bytes[name]] - - return value_bytes + result = {} + for field_name, value_batch in value_bytes.items(): + if field_name in self._serializers: + ser = self._serializers[field_name] + result[field_name] = [ser.deserialize(v) for v in value_batch] + else: + result[field_name] = value_batch + + return result diff --git a/python/src/space/core/storage.py b/python/src/space/core/storage.py index b75bda9..3f90da3 100644 --- a/python/src/space/core/storage.py +++ b/python/src/space/core/storage.py @@ -46,9 +46,12 @@ def __init__(self, location: str, metadata: meta.StorageMetadata): self._fs = create_fs(location) record_fields = set(self._metadata.schema.record_fields) - self._physical_schema = arrow.arrow_schema(self._metadata.schema.fields, - record_fields, - physical=True) + self._logical_schema = arrow.arrow_schema(self._metadata.schema.fields, + record_fields, + physical=False) + self._physical_schema = arrow.logical_to_physical_schema( + self._logical_schema, record_fields) + self._field_name_ids: Dict[str, int] = arrow.field_name_to_id_dict( self._physical_schema) @@ -57,6 +60,11 @@ def metadata(self) -> meta.StorageMetadata: """Return the storage metadata.""" return self._metadata + @property + def logical_schema(self) -> pa.Schema: + """Return the user specified schema.""" + return self._logical_schema + @property def physical_schema(self) -> pa.Schema: """Return the physcal schema that uses reference for record fields.""" @@ -89,6 +97,7 @@ def create( # TODO: to verify that location is an empty directory. # TODO: to verify primary key fields and record_fields (and types) are # valid. 
+    # TODO: to auto infer record_fields.
     field_id_mgr = FieldIdManager()
     schema = field_id_mgr.assign_field_ids(schema)
 
diff --git a/python/src/space/tf/data_sources.py b/python/src/space/tf/data_sources.py
new file mode 100644
index 0000000..994e3eb
--- /dev/null
+++ b/python/src/space/tf/data_sources.py
@@ -0,0 +1,136 @@
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""TFDS random access data source using Space dataset."""
+
+from collections.abc import Sequence as AbcSequence
+from typing import Any, List, Sequence, Union
+
+from absl import logging  # type: ignore[import-untyped]
+import pyarrow as pa
+# pylint: disable=line-too-long
+from tensorflow_datasets.core.utils import shard_utils  # type: ignore[import-untyped]
+from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_data_source as ards  # type: ignore[import-untyped]
+
+from space import Dataset
+from space.core.schema import constants
+
+
+class SpaceDataSource(AbcSequence):
+  """TFDS random access data source using Space dataset."""
+
+  def __init__(self, ds_or_location: Union[Dataset, str],
+               feature_fields: List[str]):
+    """
+    TODO: to support a filter on index fields.
+
+    Args:
+      ds_or_location: a dataset object or location.
+      feature_fields: the record fields containing data to read.
+    """
+    # TODO: to auto infer feature_fields.
+    if isinstance(ds_or_location, str):
+      self._ds = Dataset.load(ds_or_location)
+    else:
+      self._ds = ds_or_location
+
+    assert len(feature_fields) == 1, "Only one feature field is supported"
+    self._feature_field: str = feature_fields[0]
+    # TODO: to verify the feature field is a record field?
+
+    self._serializer = self._ds.serializer().field_serializer(
+        self._feature_field)
+    # pylint: disable=line-too-long
+    assert self._serializer is not None, f"Feature field {self._feature_field} must be a record field"
+
+    self._data_source = ards.ArrayRecordDataSource(
+        self._file_instructions(self._read_index_and_address()))
+    self._length = len(self._data_source)
+
+  def __len__(self) -> int:
+    return self._length
+
+  def __iter__(self):
+    for i in range(self._length):
+      yield self[i]
+
+  def __getitem__(self, record_key: int) -> Any:  # type: ignore[override]
+    if not isinstance(record_key, int):
+      logging.error(
+          "Calling SpaceDataSource.__getitem__() with a sequence "
+          "of record keys (%s) is deprecated. Either pass a single "
+          "integer or switch to __getitems__().",
+          record_key,
+      )
+      return self.__getitems__(record_key)
+
+    record = self._data_source[record_key]
+    return self._serializer.deserialize(record)
+
+  def __getitems__(self, record_keys: Sequence[int]) -> Sequence[Any]:
+    records = self._data_source.__getitems__(record_keys)
+    if len(record_keys) != len(records):
+      raise IndexError(f"Requested {len(record_keys)} records but got"
+                       f" {len(records)} records."
+ f"{record_keys=}, {records=}") + + return [self._serializer.deserialize(record) for record in records] + + def _read_index_and_address(self) -> pa.Table: + """Read index and record address columns. + + TODO: to read index for filters. + """ + return self._ds.local().read_all(fields=[self._feature_field], + reference_read=True) + + def _file_instructions( + self, record_addresses: pa.Table) -> List[shard_utils.FileInstruction]: + """Convert record addresses to ArrayRecord file instructions.""" + record_address_table = pa.Table.from_arrays( + record_addresses.column( + self._feature_field).flatten(), # type: ignore[arg-type] + [constants.FILE_PATH_FIELD, constants.ROW_ID_FIELD]) + aggregated = record_address_table.group_by( + constants.FILE_PATH_FIELD).aggregate([(constants.ROW_ID_FIELD, "list") + ]).to_pydict() + + file_instructions = [] + for file_path, indexes in zip( + aggregated[constants.FILE_PATH_FIELD], + aggregated[f"{constants.ROW_ID_FIELD}_list"]): + full_file_path = self._ds._storage.full_path(file_path) # pylint: disable=protected-access + if not indexes: + continue + + indexes.sort() + + previous_idx = indexes[0] + start, end = indexes[0], indexes[-1] + for idx in indexes: + if idx != previous_idx + 1 or idx == end: + if idx == end: + previous_idx = idx + + # TODO: to populate FileInstruction.examples_in_shard. + file_instructions.append( + shard_utils.FileInstruction(filename=full_file_path, + skip=start, + take=previous_idx - start + 1, + examples_in_shard=end + 1)) + start = idx + previous_idx = idx + + return file_instructions diff --git a/python/tests/core/ops/test_read.py b/python/tests/core/ops/test_read.py index ce3e0f9..cde7f54 100644 --- a/python/tests/core/ops/test_read.py +++ b/python/tests/core/ops/test_read.py @@ -16,7 +16,7 @@ import pyarrow.compute as pc from space.core.ops import LocalAppendOp -from space.core.ops import FileSetReadOp +from space.core.ops import FileSetReadOp, ReadOptions from space.core.storage import Storage @@ -51,7 +51,7 @@ def test_read_all_types(self, tmp_path, all_types_schema, storage.metadata, storage.data_files(), # pylint: disable=singleton-comparison - filter_=pc.field("bool") == True) + ReadOptions(filter_=pc.field("bool") == True)) results = list(iter(read_op)) assert len(results) == 1 assert list(iter(read_op))[0] == pa.Table.from_pydict({ diff --git a/python/tests/core/schema/types/test_tf_features.py b/python/tests/core/schema/types/test_tf_features.py index cc00685..9d4ff67 100644 --- a/python/tests/core/schema/types/test_tf_features.py +++ b/python/tests/core/schema/types/test_tf_features.py @@ -14,7 +14,7 @@ import json import numpy as np -from numpy.testing import assert_array_equal +from numpy.testing import assert_array_equal, assert_equal import pyarrow as pa import pytest import tensorflow_datasets as tfds # type: ignore[import-untyped] @@ -102,4 +102,4 @@ def test_dict_serialize_deserialize(self, tf_features): np.array([[0.3, 0.8, 0.5, 1.]], dtype=np.float32)) assert_array_equal(objects["id"], np.array([123])) - assert serializer.deserialize(serialized_data) == data + assert_equal(serializer.deserialize(serialized_data), data) diff --git a/python/tests/core/test_runners.py b/python/tests/core/test_runners.py index c5d4349..af3cc3e 100644 --- a/python/tests/core/test_runners.py +++ b/python/tests/core/test_runners.py @@ -54,7 +54,7 @@ def test_data_mutation_and_read(self, tmp_path): def _read_pyarrow(runner: LocalRunner, filter_: Optional[pc.Expression] = None) -> pa.Table: - return 
pa.concat_tables((list(runner.read(filter_)))) + return runner.read_all(filter_) def _generate_data(ids: Iterable[int]) -> pa.Table: diff --git a/python/tests/tf/test_data_sources.py b/python/tests/tf/test_data_sources.py new file mode 100644 index 0000000..f7df212 --- /dev/null +++ b/python/tests/tf/test_data_sources.py @@ -0,0 +1,85 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import numpy as np +from numpy.testing import assert_equal +import pyarrow as pa +import pytest +from tensorflow_datasets import features as f # type: ignore[import-untyped] + +from space import Dataset, TfFeatures +import space.core.proto.metadata_pb2 as meta +from space.core.utils.lazy_imports_utils import array_record_module as ar +from space.core.utils.uuids import uuid_ + +from space.tf.data_sources import SpaceDataSource + + +class TestSpaceDataSource: + + @pytest.fixture + def tf_features(self): + features_dict = f.FeaturesDict({ + "image_id": + np.int64, + "objects": + f.Sequence({"bbox": f.BBoxFeature()}), + }) + return TfFeatures(features_dict) + + def test_read_space_data_source(self, tmp_path, tf_features): + schema = pa.schema([("id", pa.int64()), ("features", tf_features)]) + ds = Dataset.create(str(tmp_path / "dataset"), + schema, + primary_keys=["id"], + record_fields=["features"]) + + # TODO: to test more records per file. + input_data = [{ + "id": [123], + "features": [{ + "image_id": 123, + "objects": { + "bbox": np.array([[0.3, 0.8, 0.5, 1.0]], np.float32) + } + }] + }, { + "id": [456, 789], + "features": [{ + "image_id": 456, + "objects": { + "bbox": + np.array([[0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.3, 0.4]], + np.float32) + } + }, { + "image_id": 789, + "objects": { + "bbox": + np.array([[0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.3, 0.4]], + np.float32) + } + }] + }] + + runner = ds.local() + serializer = ds.serializer() + for data in input_data: + runner.append(serializer.serialize(data)) + + data_source = SpaceDataSource(ds, ["features"]) + assert_equal(data_source[0], input_data[0]["features"][0]) + assert_equal(data_source[1], input_data[1]["features"][0])
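
Usage note: the read path in this patch adds ReadOptions (index filter, field selection, reference reads) and a read_all() convenience wrapper on the runner. Below is a minimal sketch against a dataset shaped like the test schema above ("id" index field, "features" record field); the location path is hypothetical.

import pyarrow.compute as pc

from space import Dataset

# Hypothetical location of an existing dataset with an "id" index field and
# a "features" record field.
ds = Dataset.load("/path/to/dataset")
runner = ds.local()

# Read only the "id" index column, filtered on index data.
id_table = runner.read_all(filter_=pc.field("id") > 100, fields=["id"])

# Reference read: return record-field addresses (file path and row id)
# instead of deserialized values; SpaceDataSource builds ArrayRecord file
# instructions from these addresses.
addresses = runner.read_all(fields=["features"], reference_read=True)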
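
Usage note: DictSerializer.serialize() and deserialize() now build and return new dicts, pass non-record fields through unchanged, and expose per-field serializers via field_serializer(). A sketch under the same assumptions as above (test schema, hypothetical location).

import numpy as np

from space import Dataset

ds = Dataset.load("/path/to/dataset")  # hypothetical location
serializer = ds.serializer()

# Only record fields have a FieldSerializer; index fields return None.
assert serializer.field_serializer("features") is not None
assert serializer.field_serializer("id") is None

row = {
    "id": [123],
    "features": [{
        "image_id": 123,
        "objects": {"bbox": np.array([[0.3, 0.8, 0.5, 1.0]], np.float32)},
    }],
}
serialized = serializer.serialize(row)         # "features" becomes bytes, "id" passes through
restored = serializer.deserialize(serialized)  # numpy-like nested dicts again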
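
Usage note: an end-to-end sketch of the new SpaceDataSource, mirroring the test above; the dataset location is hypothetical. The data source behaves as a random-access Sequence that deserializes the chosen record field on read.

import numpy as np
import pyarrow as pa
from tensorflow_datasets import features as f

from space import Dataset, TfFeatures
from space.tf.data_sources import SpaceDataSource

tf_features = TfFeatures(
    f.FeaturesDict({
        "image_id": np.int64,
        "objects": f.Sequence({"bbox": f.BBoxFeature()}),
    }))
schema = pa.schema([("id", pa.int64()), ("features", tf_features)])

# Hypothetical location; primary keys index the data, record fields hold
# the serialized feature payloads.
ds = Dataset.create("/path/to/dataset", schema,
                    primary_keys=["id"], record_fields=["features"])

serializer = ds.serializer()
ds.local().append(serializer.serialize({
    "id": [123],
    "features": [{
        "image_id": 123,
        "objects": {"bbox": np.array([[0.3, 0.8, 0.5, 1.0]], np.float32)},
    }],
}))

# Random access over the "features" record field, deserialized on read.
source = SpaceDataSource(ds, ["features"])
assert len(source) == 1
example = source[0]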
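
Implementation note: _file_instructions() groups record addresses by file and merges consecutive row ids into contiguous skip/take ranges. The helper below is only an illustration of that run-merging idea (it is not the patch's code); each (start, count) pair corresponds to a FileInstruction's skip and take.

from itertools import groupby
from typing import Iterator, List, Tuple

def contiguous_ranges(row_ids: List[int]) -> Iterator[Tuple[int, int]]:
  """Yield (start, count) for each run of consecutive row ids."""
  row_ids = sorted(row_ids)
  # Consecutive ids share the same (value - position) key, so each group
  # below is one contiguous run.
  for _, run in groupby(enumerate(row_ids), lambda pair: pair[1] - pair[0]):
    ids = [row_id for _, row_id in run]
    yield ids[0], len(ids)

assert list(contiguous_ranges([4, 5, 6, 10, 11, 20])) == [(4, 3), (10, 2), (20, 1)]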