diff --git a/python/src/space/core/ops/append.py b/python/src/space/core/ops/append.py
index 6c0f84c..761cce6 100644
--- a/python/src/space/core/ops/append.py
+++ b/python/src/space/core/ops/append.py
@@ -84,12 +84,12 @@ def __init__(self, location: str, metadata: meta.StorageMetadata):
     StoragePaths.__init__(self, location)
     self._metadata = metadata
-    self._schema = arrow.arrow_schema(self._metadata.schema.fields)
-
+    record_fields = set(self._metadata.schema.record_fields)
+    self._schema = arrow.arrow_schema(self._metadata.schema.fields,
+                                      record_fields,
+                                      physical=True)
     self._index_fields, self._record_fields = arrow.classify_fields(
-        self._schema,
-        set(self._metadata.schema.record_fields),
-        selected_fields=None)
+        self._schema, record_fields, selected_fields=None)
 
     # Data file writers.
     self._index_writer_info: Optional[_IndexWriterInfo] = None
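# ---------------------------------------------------------------------------
# Illustration (not part of the patch): the append op above now requests the
# *physical* schema, in which each record field is replaced by an address
# struct rather than its logical (extension) type. The struct layout below is
# taken from the tests later in this change; the snippet is a sketch runnable
# with pyarrow alone.
import pyarrow as pa

# Physical stand-in for a record field: a reference into record storage
# (an ArrayRecord file path plus a row position), not the payload itself.
record_address = pa.struct([("_FILE", pa.string()), ("_ROW_ID", pa.int32())])
physical_field = pa.field("images", record_address)

assert physical_field.type.field("_FILE").type == pa.string()
assert physical_field.type.field("_ROW_ID").type == pa.int32()
# ---------------------------------------------------------------------------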
diff --git a/python/src/space/core/schema/arrow.py b/python/src/space/core/schema/arrow.py
index 898c240..8f7961b 100644
--- a/python/src/space/core/schema/arrow.py
+++ b/python/src/space/core/schema/arrow.py
@@ -20,9 +20,9 @@
 import pyarrow as pa
 from substrait.type_pb2 import NamedStruct, Type
 
-from space.core.utils.constants import UTF_8
 from space.core.schema import constants
 from space.core.schema.types import TfFeatures
+from space.core.utils.constants import UTF_8
 
 _PARQUET_FIELD_ID_KEY = b"PARQUET:field_id"
@@ -49,11 +49,13 @@ def next(self) -> str:
     return name
 
 
-def arrow_schema(fields: NamedStruct) -> pa.Schema:
+def arrow_schema(fields: NamedStruct, record_fields: Set[str],
+                 physical: bool) -> pa.Schema:
   """Return Arrow schema from Substrait fields.
 
   Args:
     fields: schema fields in the Substrait format.
+    record_fields: a set of record field names.
+    physical: if true, return the physical schema, which matches the
+      underlying index (Parquet) file schema; record fields are stored by
+      reference, e.g., a row position in an ArrayRecord file.
@@ -61,19 +63,28 @@ def arrow_schema(fields: NamedStruct) -> pa.Schema:
   return pa.schema(
       _arrow_fields(
           _NamesVisitor(fields.names),  # type: ignore[arg-type]
-          fields.struct.types))  # type: ignore[arg-type]
+          fields.struct.types,  # type: ignore[arg-type]
+          record_fields,
+          physical))
 
 
-def _arrow_fields(names_visitor: _NamesVisitor,
-                  types: List[Type]) -> List[pa.Field]:
+def _arrow_fields(names_visitor: _NamesVisitor, types: List[Type],
+                  record_fields: Set[str], physical: bool) -> List[pa.Field]:
   fields: List[pa.Field] = []
 
   for type_ in types:
     name = names_visitor.next()
-    arrow_field = pa.field(name,
-                           _arrow_type(type_, names_visitor),
-                           metadata=field_metadata(_substrait_field_id(type_)))
-    fields.append(arrow_field)
+
+    if physical and name in record_fields:
+      arrow_type: pa.DataType = pa.struct(
+          record_address_types())  # type: ignore[arg-type]
+    else:
+      arrow_type = _arrow_type(type_, names_visitor)
+
+    fields.append(
+        pa.field(name,
+                 arrow_type,
+                 metadata=field_metadata(_substrait_field_id(type_))))
 
   return fields
diff --git a/python/src/space/core/schema/types/tf_features.py b/python/src/space/core/schema/types/tf_features.py
index 1e3dcbd..c183000 100644
--- a/python/src/space/core/schema/types/tf_features.py
+++ b/python/src/space/core/schema/types/tf_features.py
@@ -15,7 +15,7 @@
 """Define a custom Arrow type for Tensorflow Dataset Features."""
 from __future__ import annotations
-from typing import Any
+from typing import Any, Union
 import json
 
 import pyarrow as pa
@@ -47,9 +47,12 @@ def __arrow_ext_serialize__(self) -> bytes:
   def __arrow_ext_deserialize__(
       cls,
       storage_type: pa.DataType,  # pylint: disable=unused-argument
-      serialized: bytes) -> TfFeatures:
-    return TfFeatures(
-        f.FeaturesDict.from_json(json.loads(serialized.decode(UTF_8))))
+      serialized: Union[bytes, str]
+  ) -> TfFeatures:
+    if isinstance(serialized, bytes):
+      serialized = serialized.decode(UTF_8)
+
+    return TfFeatures(f.FeaturesDict.from_json(json.loads(serialized)))
 
   def serialize(self, value: Any) -> bytes:
     """Serialize value using the provided features_dict."""
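# ---------------------------------------------------------------------------
# Illustration (not part of the patch): the deserialize change above widens
# the signature to Union[bytes, str], presumably because the serialized
# extension metadata can surface as either type; the method now normalizes to
# str before parsing. The same pattern in isolation, runnable without
# tensorflow_datasets:
import json

UTF_8 = "utf-8"  # mirrors space.core.utils.constants.UTF_8

def _parse(serialized) -> dict:
  # Decode only when given bytes, as in the new __arrow_ext_deserialize__.
  if isinstance(serialized, bytes):
    serialized = serialized.decode(UTF_8)
  return json.loads(serialized)

payload = json.dumps({"type": "FeaturesDict"})
assert _parse(payload) == _parse(payload.encode(UTF_8))
# ---------------------------------------------------------------------------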
diff --git a/python/tests/core/manifests/test_record.py b/python/tests/core/manifests/test_record.py
index b254915..a77be0f 100644
--- a/python/tests/core/manifests/test_record.py
+++ b/python/tests/core/manifests/test_record.py
@@ -21,7 +21,6 @@ class TestRecordManifestWriter:
 
   def test_write(self, tmp_path):
-    data_dir = tmp_path / "dataset" / "data"
     metadata_dir = tmp_path / "dataset" / "metadata"
     metadata_dir.mkdir(parents=True)
diff --git a/python/tests/core/ops/test_append.py b/python/tests/core/ops/test_append.py
index bd4b67b..8e4ab95 100644
--- a/python/tests/core/ops/test_append.py
+++ b/python/tests/core/ops/test_append.py
@@ -12,16 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import List
+import numpy as np
 import pyarrow as pa
 import pyarrow.parquet as pq
+from tensorflow_datasets import features  # type: ignore[import-untyped]
 
 from space.core.ops import LocalAppendOp
 import space.core.proto.metadata_pb2 as meta
+from space.core.schema.types import TfFeatures
 from space.core.storage import Storage
 
 
 class TestLocalAppendOp:
 
+  # TODO: add tests using Arrow table input.
+
   def test_write_pydict_all_types(self, tmp_path):
     location = tmp_path / "dataset"
     schema = pa.schema([
@@ -75,6 +81,82 @@ def test_write_pydict_all_types(self, tmp_path):
     assert patch.storage_statistics_update == meta.StorageStatistics(
         num_rows=5, index_compressed_bytes=114, index_uncompressed_bytes=126)
 
+  def test_write_pydict_with_record_fields(self, tmp_path):
+    tf_features_images = features.FeaturesDict(
+        {"images": features.Image(shape=(None, None, 3), dtype=np.uint8)})
+    tf_features_objects = features.FeaturesDict({
+        "objects":
+            features.Sequence({
+                "bbox": features.BBoxFeature(),
+                "id": np.int64
+            }),
+    })
+
+    location = tmp_path / "dataset"
+    schema = pa.schema([
+        pa.field("int64", pa.int64()),
+        pa.field("string", pa.string()),
+        pa.field("images", TfFeatures(tf_features_images)),
+        pa.field("objects", TfFeatures(tf_features_objects))
+    ])
+    storage = Storage.create(location=str(location),
+                             schema=schema,
+                             primary_keys=["int64"],
+                             record_fields=["images", "objects"])
+
+    op = LocalAppendOp(str(location), storage.metadata)
+
+    op.write({
+        "int64": [1, 2, 3],
+        "string": ["a", "b", "c"],
+        "images": [b"images0", b"images1", b"images2"],
+        "objects": [b"objects0", b"objects1", b"objects2"]
+    })
+    op.write({
+        "int64": [0, 10],
+        "string": ["A", "z"],
+        "images": [b"images3", b"images4"],
+        "objects": [b"objects3", b"objects4"]
+    })
+
+    patch = op.finish()
+    assert patch is not None
+
+    # Validate index manifest files.
+    index_manifest = self._read_manifests(
+        storage, list(patch.added_index_manifest_files))
+    assert index_manifest == {
+        "_FILE": index_manifest["_FILE"],
+        "_INDEX_COMPRESSED_BYTES": [114],
+        "_INDEX_UNCOMPRESSED_BYTES": [126],
+        "_NUM_ROWS": [5],
+        "_STATS_f0": [{
+            "_MAX": 10,
+            "_MIN": 0
+        }]
+    }
+
+    # Validate record manifest files.
+    record_manifest = self._read_manifests(
+        storage, list(patch.added_record_manifest_files))
+    assert record_manifest == {
+        "_FILE": record_manifest["_FILE"],
+        "_FIELD_ID": [2, 3],
+        "_NUM_ROWS": [5, 5],
+        "_UNCOMPRESSED_BYTES": [55, 60]
+    }
+
+    # Data files exist.
+    self._check_file_exists(location, index_manifest["_FILE"])
+    self._check_file_exists(location, record_manifest["_FILE"])
+
+    # Validate statistics.
+    assert patch.storage_statistics_update == meta.StorageStatistics(
+        num_rows=5,
+        index_compressed_bytes=114,
+        index_uncompressed_bytes=126,
+        record_uncompressed_bytes=115)
+
   def test_empty_op_return_none(self, tmp_path):
     location = tmp_path / "dataset"
     schema = pa.schema([pa.field("int64", pa.int64())])
     storage = Storage.create(location=str(location),
                              schema=schema,
                              primary_keys=["int64"],
                              record_fields=[])
@@ -85,3 +167,15 @@ def test_empty_op_return_none(self, tmp_path):
     op = LocalAppendOp(str(location), storage.metadata)
 
     assert op.finish() is None
+
+  def _read_manifests(self, storage: Storage,
+                      file_paths: List[str]) -> dict:
+    manifests = []
+    for f in file_paths:
+      manifests.append(pq.read_table(storage.full_path(f)))
+
+    return pa.concat_tables(manifests).to_pydict()
+
+  def _check_file_exists(self, location, file_paths: List[str]):
+    for f in file_paths:
+      assert (location / f).exists()
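# ---------------------------------------------------------------------------
# Illustration (not part of the patch): a cross-check of the expected numbers
# in test_write_pydict_with_record_fields above. The storage-level statistics
# aggregate the per-field record manifests: one manifest row per record field
# (field ids 2 and 3), each covering all five appended rows.
index_uncompressed_bytes = 126             # single index (Parquet) file
record_uncompressed_bytes = sum([55, 60])  # images + objects record files
assert record_uncompressed_bytes == 115    # matches StorageStatistics above
# ---------------------------------------------------------------------------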
diff --git a/python/tests/core/schema/conftest.py b/python/tests/core/schema/conftest.py
index 4205696..58d4f5f 100644
--- a/python/tests/core/schema/conftest.py
+++ b/python/tests/core/schema/conftest.py
@@ -12,12 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import pytest
+import json
+import numpy as np
 import pyarrow as pa
+import pytest
+from tensorflow_datasets import features  # type: ignore[import-untyped]
 from substrait.type_pb2 import NamedStruct, Type
 
 from space.core.schema.arrow import field_metadata
+from space.core.schema.types import TfFeatures
@@ -92,3 +96,31 @@ def sample_arrow_schema():
         ]),
                  metadata=field_metadata(260))
     ])
+
+
+@pytest.fixture
+def tf_features():
+  return features.FeaturesDict(
+      {"images": features.Image(shape=(None, None, 3), dtype=np.uint8)})
+
+
+@pytest.fixture
+def tf_features_substrait_fields(tf_features):  # pylint: disable=redefined-outer-name
+  return NamedStruct(
+      names=["int64", "features"],
+      struct=Type.Struct(types=[
+          Type(i64=Type.I64(type_variation_reference=0)),
+          Type(user_defined=Type.UserDefined(type_parameters=[
+              Type.Parameter(string="TF_FEATURES"),
+              Type.Parameter(string=json.dumps(tf_features.to_json()))
+          ],
+                                             type_variation_reference=1))
+      ]))
+
+
+@pytest.fixture
+def tf_features_arrow_schema(tf_features):  # pylint: disable=redefined-outer-name
+  return pa.schema([
+      pa.field("int64", pa.int64(), metadata=field_metadata(0)),
+      pa.field("features", TfFeatures(tf_features), metadata=field_metadata(1))
+  ])
diff --git a/python/tests/core/schema/test_arrow.py b/python/tests/core/schema/test_arrow.py
index 420c220..9d7c409 100644
--- a/python/tests/core/schema/test_arrow.py
+++ b/python/tests/core/schema/test_arrow.py
@@ -15,6 +15,7 @@
 import pyarrow as pa
 
 from space.core.schema import arrow
+from space.core.schema.arrow import field_metadata
 
 
 def test_field_metadata():
@@ -27,8 +28,33 @@ def test_field_id():
                              b"123"})) == 123
 
 
-def test_arrow_schema(sample_substrait_fields, sample_arrow_schema):
-  assert sample_arrow_schema == arrow.arrow_schema(sample_substrait_fields)
+def test_arrow_schema_logical_without_records(sample_substrait_fields,
+                                              sample_arrow_schema):
+  assert arrow.arrow_schema(sample_substrait_fields, [],
+                            False) == sample_arrow_schema
+
+
+def test_arrow_schema_logical_with_records(tf_features_substrait_fields,
+                                           tf_features_arrow_schema):
+  assert arrow.arrow_schema(tf_features_substrait_fields, [],
+                            False) == tf_features_arrow_schema
+
+
+def test_arrow_schema_physical_without_records(sample_substrait_fields,
+                                               sample_arrow_schema):
+  assert arrow.arrow_schema(sample_substrait_fields, [],
+                            True) == sample_arrow_schema
+
+
+def test_arrow_schema_physical_with_records(tf_features_substrait_fields):
+  arrow_schema = pa.schema([
+      pa.field("int64", pa.int64(), metadata=field_metadata(0)),
+      pa.field("features",
+               pa.struct([("_FILE", pa.string()), ("_ROW_ID", pa.int32())]),
+               metadata=field_metadata(1))
+  ])
+  assert arrow.arrow_schema(tf_features_substrait_fields, ["features"],
+                            True) == arrow_schema
 
 
 def test_field_name_to_id_dict(sample_arrow_schema):
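# ---------------------------------------------------------------------------
# Illustration (not part of the patch): how the conftest fixtures above encode
# a TfFeatures field in Substrait. Parameter 0 carries the logical type name
# ("TF_FEATURES") and parameter 1 carries the FeaturesDict spec as JSON;
# arrow_schema() maps this back to a TfFeatures extension type (logical) or to
# the address struct (physical). A sketch runnable with substrait-python; the
# spec dict is a stand-in for tf_features.to_json().
import json
from substrait.type_pb2 import Type

spec = {"stand_in": "for tf_features.to_json()"}
udt = Type.UserDefined(type_parameters=[
    Type.Parameter(string="TF_FEATURES"),
    Type.Parameter(string=json.dumps(spec)),
],
                       type_variation_reference=1)

assert udt.type_parameters[0].string == "TF_FEATURES"
assert json.loads(udt.type_parameters[1].string) == spec
# ---------------------------------------------------------------------------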
diff --git a/python/tests/core/schema/types/test_tf_features.py b/python/tests/core/schema/types/test_tf_features.py
index 1f988d6..5d0d019 100644
--- a/python/tests/core/schema/types/test_tf_features.py
+++ b/python/tests/core/schema/types/test_tf_features.py
@@ -53,10 +53,16 @@ def test_arrow_ext_serialize_deserialize(self, tf_features, sample_objects):
         "type"] == "tensorflow_datasets.core.features.features_dict.FeaturesDict"  # pylint: disable=line-too-long
     assert "sequence" in features_dict["content"]["features"]["objects"]
 
+    # Bytes input.
     tf_features = TfFeatures.__arrow_ext_deserialize__(storage_type=None,
                                                        serialized=serialized)
     assert len(tf_features.serialize(sample_objects)) > 0
 
+    # String input.
+    tf_features = TfFeatures.__arrow_ext_deserialize__(
+        storage_type=None, serialized=serialized.decode(UTF_8))
+    assert len(tf_features.serialize(sample_objects)) > 0
+
   def test_serialize_deserialize(self, tf_features, sample_objects):
     value_bytes = tf_features.serialize(sample_objects)
     assert len(value_bytes) > 0
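# ---------------------------------------------------------------------------
# Illustration (not part of the patch): the invariant the two deserialize
# calls above establish, written out. Deserializing from bytes and from the
# decoded string should yield interchangeable TfFeatures instances. Sketch
# only; it assumes the tf_features fixture from this test module:
#
#   a = TfFeatures.__arrow_ext_deserialize__(None, serialized)
#   b = TfFeatures.__arrow_ext_deserialize__(None, serialized.decode(UTF_8))
#   assert a.__arrow_ext_serialize__() == b.__arrow_ext_serialize__()
# ---------------------------------------------------------------------------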