Add unit tests of appending records and schema methods
coufon committed Dec 22, 2023
1 parent 243ce7e commit 6122156
Showing 8 changed files with 193 additions and 22 deletions.
10 changes: 5 additions & 5 deletions python/src/space/core/ops/append.py
@@ -84,12 +84,12 @@ def __init__(self, location: str, metadata: meta.StorageMetadata):
StoragePaths.__init__(self, location)

self._metadata = metadata
self._schema = arrow.arrow_schema(self._metadata.schema.fields)

record_fields = set(self._metadata.schema.record_fields)
self._schema = arrow.arrow_schema(self._metadata.schema.fields,
record_fields,
physical=True)
self._index_fields, self._record_fields = arrow.classify_fields(
self._schema,
set(self._metadata.schema.record_fields),
selected_fields=None)
self._schema, record_fields, selected_fields=None)

# Data file writers.
self._index_writer_info: Optional[_IndexWriterInfo] = None
29 changes: 20 additions & 9 deletions python/src/space/core/schema/arrow.py
@@ -20,9 +20,9 @@
import pyarrow as pa
from substrait.type_pb2 import NamedStruct, Type

from space.core.utils.constants import UTF_8
from space.core.schema import constants
from space.core.schema.types import TfFeatures
from space.core.utils.constants import UTF_8

_PARQUET_FIELD_ID_KEY = b"PARQUET:field_id"

@@ -49,31 +49,42 @@ def next(self) -> str:
return name


def arrow_schema(fields: NamedStruct) -> pa.Schema:
def arrow_schema(fields: NamedStruct, record_fields: Set[str],
physical: bool) -> pa.Schema:
"""Return Arrow schema from Substrait fields.
Args:
fields: schema fields in the Substrait format.
record_fields: a set of record field names.
physical: if True, return the physical schema. The physical schema matches
the underlying index (Parquet) file schema; record fields are stored by
their references, e.g., row position in the ArrayRecord file.
"""
return pa.schema(
_arrow_fields(
_NamesVisitor(fields.names), # type: ignore[arg-type]
fields.struct.types)) # type: ignore[arg-type]
fields.struct.types, # type: ignore[arg-type]
record_fields,
physical))


def _arrow_fields(names_visitor: _NamesVisitor,
types: List[Type]) -> List[pa.Field]:
def _arrow_fields(names_visitor: _NamesVisitor, types: List[Type],
record_fields: Set[str], physical: bool) -> List[pa.Field]:
fields: List[pa.Field] = []

for type_ in types:
name = names_visitor.next()
arrow_field = pa.field(name,
_arrow_type(type_, names_visitor),
metadata=field_metadata(_substrait_field_id(type_)))
fields.append(arrow_field)

if physical and name in record_fields:
arrow_type: pa.DataType = pa.struct(
record_address_types()) # type: ignore[arg-type]
else:
arrow_type = _arrow_type(type_, names_visitor)

fields.append(
pa.field(name,
arrow_type,
metadata=field_metadata(_substrait_field_id(type_))))

return fields
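
For context, a minimal sketch (not part of this commit) of the logical vs. physical schemas the new `physical` flag distinguishes; the `_FILE`/`_ROW_ID` address struct mirrors the expected physical schema in the tests below:

```python
import pyarrow as pa

# Logical schema: the record field keeps its user-facing type
# (pa.binary() is a stand-in for a TfFeatures column here).
logical = pa.schema([
    pa.field("int64", pa.int64()),
    pa.field("features", pa.binary()),
])

# Physical schema (physical=True): the record field is replaced by its
# address in the record file, matching the Parquet index file layout.
physical = pa.schema([
    pa.field("int64", pa.int64()),
    pa.field("features",
             pa.struct([("_FILE", pa.string()), ("_ROW_ID", pa.int32())])),
])
```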

11 changes: 7 additions & 4 deletions python/src/space/core/schema/types/tf_features.py
@@ -15,7 +15,7 @@
"""Define a custom Arrow type for Tensorflow Dataset Features."""

from __future__ import annotations
from typing import Any
from typing import Any, Union

import json
import pyarrow as pa
@@ -47,9 +47,12 @@ def __arrow_ext_serialize__(self) -> bytes:
def __arrow_ext_deserialize__(
cls,
storage_type: pa.DataType, # pylint: disable=unused-argument
serialized: bytes) -> TfFeatures:
return TfFeatures(
f.FeaturesDict.from_json(json.loads(serialized.decode(UTF_8))))
serialized: Union[bytes, str]
) -> TfFeatures:
if isinstance(serialized, bytes):
serialized = serialized.decode(UTF_8)

return TfFeatures(f.FeaturesDict.from_json(json.loads(serialized)))
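
A hedged usage sketch (assumed, not part of the diff): with this change, `__arrow_ext_deserialize__` is expected to accept the serialized FeaturesDict as either bytes or str, as the new tests below exercise:

```python
import json

import numpy as np
from tensorflow_datasets import features  # type: ignore[import-untyped]

from space.core.schema.types import TfFeatures

fd = features.FeaturesDict(
    {"images": features.Image(shape=(None, None, 3), dtype=np.uint8)})
payload = json.dumps(fd.to_json())  # str form of the serialized payload

# Both calls should yield an equivalent TfFeatures instance.
tf_from_str = TfFeatures.__arrow_ext_deserialize__(storage_type=None,
                                                   serialized=payload)
tf_from_bytes = TfFeatures.__arrow_ext_deserialize__(
    storage_type=None, serialized=payload.encode("utf-8"))
```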

def serialize(self, value: Any) -> bytes:
"""Serialize value using the provided features_dict."""
1 change: 0 additions & 1 deletion python/tests/core/manifests/test_record.py
@@ -21,7 +21,6 @@
class TestRecordManifestWriter:

def test_write(self, tmp_path):
data_dir = tmp_path / "dataset" / "data"
metadata_dir = tmp_path / "dataset" / "metadata"
metadata_dir.mkdir(parents=True)

94 changes: 94 additions & 0 deletions python/tests/core/ops/test_append.py
@@ -12,16 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
from tensorflow_datasets import features # type: ignore[import-untyped]

from space.core.ops import LocalAppendOp
import space.core.proto.metadata_pb2 as meta
from space.core.schema.types import TfFeatures
from space.core.storage import Storage


class TestLocalAppendOp:

# TODO: add tests using Arrow table input.

def test_write_pydict_all_types(self, tmp_path):
location = tmp_path / "dataset"
schema = pa.schema([
@@ -75,6 +81,82 @@ def test_write_pydict_all_types(self, tmp_path):
assert patch.storage_statistics_update == meta.StorageStatistics(
num_rows=5, index_compressed_bytes=114, index_uncompressed_bytes=126)

def test_write_pydict_with_record_fields(self, tmp_path):
tf_features_images = features.FeaturesDict(
{"images": features.Image(shape=(None, None, 3), dtype=np.uint8)})
tf_features_objects = features.FeaturesDict({
"objects":
features.Sequence({
"bbox": features.BBoxFeature(),
"id": np.int64
}),
})

location = tmp_path / "dataset"
schema = pa.schema([
pa.field("int64", pa.int64()),
pa.field("string", pa.string()),
pa.field("images", TfFeatures(tf_features_images)),
pa.field("objects", TfFeatures(tf_features_objects))
])
storage = Storage.create(location=str(location),
schema=schema,
primary_keys=["int64"],
record_fields=["images", "objects"])

op = LocalAppendOp(str(location), storage.metadata)

op.write({
"int64": [1, 2, 3],
"string": ["a", "b", "c"],
"images": [b"images0", b"images1", b"images2"],
"objects": [b"objects0", b"objects1", b"objects2"]
})
op.write({
"int64": [0, 10],
"string": ["A", "z"],
"images": [b"images3", b"images4"],
"objects": [b"objects3", b"objects4"]
})

patch = op.finish()
assert patch is not None

# Validate index manifest files.
index_manifest = self._read_manifests(
storage, list(patch.added_index_manifest_files))
assert index_manifest == {
"_FILE": index_manifest["_FILE"],
"_INDEX_COMPRESSED_BYTES": [114],
"_INDEX_UNCOMPRESSED_BYTES": [126],
"_NUM_ROWS": [5],
"_STATS_f0": [{
"_MAX": 10,
"_MIN": 0
}]
}

# Validate record manifest files.
record_manifest = self._read_manifests(
storage, list(patch.added_record_manifest_files))
assert record_manifest == {
"_FILE": record_manifest["_FILE"],
"_FIELD_ID": [2, 3],
"_NUM_ROWS": [5, 5],
"_UNCOMPRESSED_BYTES": [55, 60]
}

# Data file exists.
self._check_file_exists(location, index_manifest["_FILE"])
self._check_file_exists(location, record_manifest["_FILE"])

# Validate statistics.
assert patch.storage_statistics_update == meta.StorageStatistics(
num_rows=5,
index_compressed_bytes=114,
index_uncompressed_bytes=126,
record_uncompressed_bytes=115)

def test_empty_op_return_none(self, tmp_path):
location = tmp_path / "dataset"
schema = pa.schema([pa.field("int64", pa.int64())])
@@ -85,3 +167,15 @@

op = LocalAppendOp(str(location), storage.metadata)
assert op.finish() is None

def _read_manifests(self, storage: Storage,
                    file_paths: List[str]) -> dict:
manifests = []
for f in file_paths:
manifests.append(pq.read_table(storage.full_path(f)))

return pa.concat_tables(manifests).to_pydict()

def _check_file_exists(self, location, file_paths: List[str]):
for f in file_paths:
assert (location / f).exists()
34 changes: 33 additions & 1 deletion python/tests/core/schema/conftest.py
@@ -12,12 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import json

import numpy as np
import pyarrow as pa
import pytest
from tensorflow_datasets import features # type: ignore[import-untyped]
from substrait.type_pb2 import NamedStruct, Type

from space.core.schema.arrow import field_metadata
from space.core.schema.types import TfFeatures


@pytest.fixture
@@ -92,3 +96,31 @@ def sample_arrow_schema():
]),
metadata=field_metadata(260))
])


@pytest.fixture
def tf_features():
return features.FeaturesDict(
{"images": features.Image(shape=(None, None, 3), dtype=np.uint8)})


@pytest.fixture
def tf_features_substrait_fields(tf_features): # pylint: disable=redefined-outer-name
return NamedStruct(
names=["int64", "features"],
struct=Type.Struct(types=[
Type(i64=Type.I64(type_variation_reference=0)),
Type(user_defined=Type.UserDefined(type_parameters=[
Type.Parameter(string="TF_FEATURES"),
Type.Parameter(string=json.dumps(tf_features.to_json()))
],
type_variation_reference=1))
]))


@pytest.fixture
def tf_features_arrow_schema(tf_features): # pylint: disable=redefined-outer-name
return pa.schema([
pa.field("int64", pa.int64(), metadata=field_metadata(0)),
pa.field("features", TfFeatures(tf_features), metadata=field_metadata(1))
])
30 changes: 28 additions & 2 deletions python/tests/core/schema/test_arrow.py
@@ -15,6 +15,7 @@
import pyarrow as pa

from space.core.schema import arrow
from space.core.schema.arrow import field_metadata


def test_field_metadata():
@@ -27,8 +28,33 @@ def test_field_id():
b"123"})) == 123


def test_arrow_schema(sample_substrait_fields, sample_arrow_schema):
assert sample_arrow_schema == arrow.arrow_schema(sample_substrait_fields)
def test_arrow_schema_logical_without_records(sample_substrait_fields,
sample_arrow_schema):
assert arrow.arrow_schema(sample_substrait_fields, [],
False) == sample_arrow_schema


def test_arrow_schema_logical_with_records(tf_features_substrait_fields,
tf_features_arrow_schema):
assert arrow.arrow_schema(tf_features_substrait_fields, [],
False) == tf_features_arrow_schema


def test_arrow_schema_physical_without_records(sample_substrait_fields,
sample_arrow_schema):
assert arrow.arrow_schema(sample_substrait_fields, [],
True) == sample_arrow_schema


def test_arrow_schema_physical_with_records(tf_features_substrait_fields):
arrow_schema = pa.schema([
pa.field("int64", pa.int64(), metadata=field_metadata(0)),
pa.field("features",
pa.struct([("_FILE", pa.string()), ("_ROW_ID", pa.int32())]),
metadata=field_metadata(1))
])
assert arrow.arrow_schema(tf_features_substrait_fields, ["features"],
True) == arrow_schema


def test_field_name_to_id_dict(sample_arrow_schema):
6 changes: 6 additions & 0 deletions python/tests/core/schema/types/test_tf_features.py
@@ -53,10 +53,16 @@ def test_arrow_ext_serialize_deserialize(self, tf_features, sample_objects):
"type"] == "tensorflow_datasets.core.features.features_dict.FeaturesDict" # pylint: disable=line-too-long
assert "sequence" in features_dict["content"]["features"]["objects"]

# Bytes input.
tf_features = TfFeatures.__arrow_ext_deserialize__(storage_type=None,
serialized=serialized)
assert len(tf_features.serialize(sample_objects)) > 0

# String input.
tf_features = TfFeatures.__arrow_ext_deserialize__(
storage_type=None, serialized=serialized.decode(UTF_8))
assert len(tf_features.serialize(sample_objects)) > 0

def test_serialize_deserialize(self, tf_features, sample_objects):
value_bytes = tf_features.serialize(sample_objects)
assert len(value_bytes) > 0
