Skip to content

Commit

Permalink
Add unit tests for field IDs, substrait, and TfFeatures
Browse files Browse the repository at this point in the history
  • Loading branch information
coufon committed Dec 19, 2023
1 parent 150fdb4 commit 2a60e92
Show file tree
Hide file tree
Showing 13 changed files with 475 additions and 113 deletions.
4 changes: 4 additions & 0 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@ classifiers = [
]
requires-python = ">=3.8"
dependencies = [
"numpy",
"protobuf",
"pyarrow >= 14.0.0",
"tensorflow_datasets",
"typing_extensions"
]

[project.optional-dependencies]
dev = [
"pyarrow-stubs",
"tensorflow",
"types-protobuf"
]

Expand Down
9 changes: 5 additions & 4 deletions python/src/space/core/schema/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@
"""Utilities for schemas in the Arrow format."""

from typing import Dict

import pyarrow as pa

_PARQUET_FIELD_ID_KEY = "PARQUET:field_id"
from space.core.utils.constants import UTF_8

_PARQUET_FIELD_ID_KEY = b"PARQUET:field_id"


def field_metadata(field_id_: int) -> Dict[str, str]:
def field_metadata(field_id_: int) -> Dict[bytes, bytes]:
"""Return Arrow field metadata for a field."""
return {_PARQUET_FIELD_ID_KEY: str(field_id_)}
return {_PARQUET_FIELD_ID_KEY: str(field_id_).encode(UTF_8)}


def field_id(field: pa.Field) -> int:
Expand Down
43 changes: 27 additions & 16 deletions python/src/space/core/schema/field_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,41 @@

from space.core.schema import arrow

_INIT_FIELD_ID = 0
# The start value of field IDs.
_START_FIELD_ID = 0


# pylint: disable=too-few-public-methods
class FieldIdManager:
"""Assign field IDs to schema fields."""
"""Assign field IDs to schema fields using Depth First Search.
Rules for nested fields:
- For a list field with ID i, its element is assigned i+1.
- For a struct field with ID i, its fields are assigned starting from i+1.
Not thread safe.
"""

def __init__(self, next_field_id: Optional[int] = None):
self._next_field_id = (next_field_id
if next_field_id is not None else _INIT_FIELD_ID)
if next_field_id is None:
self._next_field_id = _START_FIELD_ID
else:
assert next_field_id >= _START_FIELD_ID
self._next_field_id = next_field_id

def assign_field_ids(self, schema: pa.Schema) -> pa.Schema:
  """Return a copy of `schema` with field ID metadata attached to each field."""
  top_level_fields = list(schema)
  return pa.schema(self._assign_field_ids(top_level_fields))

def _assign_field_ids(self, fields: List[pa.Field]) -> List[pa.Field]:
  """Assign an ID to each field in order, delegating to _assign_field_id."""
  return [self._assign_field_id(f) for f in fields]

def _assign_field_id(self, field: pa.Field) -> pa.Field:
this_field_id = self._next_field_id
metadata = arrow.field_metadata(this_field_id)
metadata = arrow.field_metadata(self._next_field_id)
self._next_field_id += 1

name = field.name
type_ = field.type
name, type_ = field.name, field.type

if pa.types.is_list(type_):
return pa.field(
name,
Expand All @@ -50,12 +67,6 @@ def _assign_field_id(self, field: pa.Field) -> pa.Field:
[type_.field(i) for i in range(type_.num_fields)]))
return pa.field(name, struct_type, metadata=metadata)

return field.with_metadata(metadata)

def _assign_field_ids(self, fields: List[pa.Field]) -> List[pa.Field]:
return [self._assign_field_id(f) for f in fields]
# TODO: to support more types, e.g., fixed_size_list, map.

def assign_field_ids(self, schema: pa.Schema) -> pa.Schema:
"""Assign field IDs to schema fields."""
return pa.schema(
self._assign_field_ids([schema.field(i) for i in range(len(schema))]))
return field.with_metadata(metadata) # type: ignore[arg-type]
39 changes: 21 additions & 18 deletions python/src/space/core/schema/substrait.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,40 +15,40 @@
"""Utilities for schemas in the Substrait format."""

from __future__ import annotations
from typing import List
from typing import Any, List

import pyarrow as pa

from substrait.type_pb2 import NamedStruct, Type

import space.core.schema.arrow as arrow_schema
from space.core.schema.types import TfFeatures
from space.core.utils import constants
from space.core.utils.constants import UTF_8

# Substrait type name of Arrow custom type TfFeatures.
TF_FEATURES_TYPE = "TF_FEATURES"


def substrait_fields(schema: pa.Schema) -> NamedStruct:
"""Convert Arrow schema to equivalent Substrait fields."""
"""Convert Arrow schema to equivalent Substrait fields.
According to the Substrait spec, traverse schema fields in the Depth First
Search order. The field names are persisted in `mutable_names` in the same
order.
"""
mutable_names: List[str] = []
types = _substrait_fields([schema.field(i) for i in range(len(schema))],
mutable_names)
types = _substrait_fields(list(schema), mutable_names)
return NamedStruct(names=mutable_names, struct=Type.Struct(types=types))


def _substrait_fields(fields: List[pa.Field],
mutable_names: List[str]) -> List[Type]:
"""Convert a list of Arrow fields to Substrait types, and record field names.
"""
return [_substrait_field(f, mutable_names) for f in fields]


def _substrait_field(field: pa.Field,
mutable_names: List[str],
skip_name=False) -> Type:
"""Convert an Arrow fields to a Substrait type, and record its field name."""
if not skip_name:
is_list_item=False) -> Type:
if not is_list_item:
mutable_names.append(field.name)

type_ = Type()
Expand All @@ -74,24 +74,27 @@ def _substrait_field(field: pa.Field,
_substrait_field(
field.type.value_field, # type: ignore[attr-defined]
mutable_names,
skip_name=True))
# TODO: to support fixed_size_list in substrait.
is_list_item=True))
# TODO: to support more types in Substrait, e.g., fixed_size_list, map.
elif pa.types.is_struct(field.type):
_set_field_id(type_.struct, field_id)
subfields = [field.type.field(i) for i in range(field.type.num_fields)]
subfields = list(field.type) # type: ignore[call-overload]
type_.struct.types.extend(_substrait_fields(subfields, mutable_names))
elif isinstance(field.type, TfFeatures):
# TfFeatures is persisted in Substrait as a user defined type, with
# parameters [TF_FEATURES_TYPE, __arrow_ext_serialize__()].
_set_field_id(type_.user_defined, field_id)
type_.user_defined.type_parameters.extend([
Type.Parameter(string=TF_FEATURES_TYPE),
Type.Parameter(string=field.type.__arrow_ext_serialize__().decode(
constants.UTF_8))
Type.Parameter(
string=field.type.__arrow_ext_serialize__().decode(UTF_8))
])
else:
raise ValueError(f"Type is not supported: {field.type}")
raise ValueError(
f"Type {field.type} of field {field.name} is not supported")

return type_


def _set_field_id(msg, field_id: int) -> None:
def _set_field_id(msg: Any, field_id: int) -> None:
msg.type_variation_reference = field_id
43 changes: 24 additions & 19 deletions python/src/space/core/schema/types/tf_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,41 +15,46 @@
"""Define a custom Arrow type for Tensorflow Dataset Features."""

from __future__ import annotations
from typing import Any, Union
from typing import Any

import json
import pyarrow as pa
import tensorflow_datasets as tfds # type: ignore[import-untyped]
from tensorflow_datasets import features as tf_features
from tensorflow_datasets import features as f # type: ignore[import-untyped]

from space.core.serializers import TypeSerializer
from space.core.utils import constants
from space.core.serializers import DeserializedData, FieldSerializer
from space.core.utils.constants import UTF_8


class TfFeatures(pa.ExtensionType, TypeSerializer):
class TfFeatures(pa.ExtensionType, FieldSerializer):
"""A custom Arrow type for Tensorflow Dataset Features."""

def __init__(self, features: tf_features.FeaturesDict):
self._features = features
self._serialized = json.dumps(features.to_json())
def __init__(self, features_dict: f.FeaturesDict):
"""
Args:
features_dict: a Tensorflow Dataset features dict providing serializers
for a nested dict of Tensors or Numpy arrays, see
https://www.tensorflow.org/datasets/api_docs/python/tfds/features/FeaturesDict
"""
self._features_dict = features_dict
self._serialized = json.dumps(features_dict.to_json())
pa.ExtensionType.__init__(self, pa.binary(), self._serialized)

def __arrow_ext_serialize__(self) -> bytes:
return self._serialized.encode(constants.UTF_8)
return self._serialized.encode(UTF_8)

@classmethod
def __arrow_ext_deserialize__(
cls,
storage_type: pa.DataType, # pylint: disable=unused-argument
serialized: Union[bytes, str]
) -> TfFeatures:
if isinstance(serialized, bytes):
serialized = serialized.decode(constants.UTF_8)

features = tf_features.FeaturesDict.from_json(json.loads(serialized))
return TfFeatures(features)
serialized: bytes) -> TfFeatures:
return TfFeatures(
f.FeaturesDict.from_json(json.loads(serialized.decode(UTF_8))))

def serialize(self, value: Any) -> bytes:
return self._features.serialize_example(value)
"""Serialize value using the provided features_dict."""
return self._features_dict.serialize_example(value)

def deserialize(self, value_bytes: bytes) -> Any:
return tfds.as_numpy(self._features.deserialize_example(value_bytes))
def deserialize(self, value_bytes: bytes) -> DeserializedData:
"""Deserialize value using the provided features_dict."""
return tfds.as_numpy(self._features_dict.deserialize_example(value_bytes))
2 changes: 1 addition & 1 deletion python/src/space/core/serializers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
#
"""Serializers (and deserializers) for unstructured record fields."""

from space.core.serializers.base import DictSerializer, TypeSerializer
from space.core.serializers.base import DeserializedData, FieldSerializer
66 changes: 21 additions & 45 deletions python/src/space/core/serializers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,57 +15,33 @@
"""Serializers (and deserializers) for unstructured record fields."""

from abc import ABC, abstractmethod
from typing import Any, Dict, List
from typing import Any
from typing_extensions import TypeAlias

import pyarrow as pa
# pylint: disable=line-too-long
from tensorflow_datasets.core.dataset_utils import NumpyElem, Tree # type: ignore[import-untyped]

DictData = Dict[str, List[Any]]
DeserializedData: TypeAlias = Tree[NumpyElem]


class TypeSerializer(ABC):
"""Abstract serializer of a type."""

@abstractmethod
def serialize(self, value: Any) -> bytes:
"""Serialize a value."""
return NotImplemented

@abstractmethod
def deserialize(self, value_bytes: bytes) -> Any:
"""Deserialize bytes to a value."""
return NotImplemented


class DictSerializer:
"""A serializer for rows in PyDict format.
class FieldSerializer(ABC):
"""Abstract serializer of a field.
PyDict format has the layout {"field": [values...], ...}.
Used for serializing record fields into bytes to be stored in Space.
"""

def __init__(self, schema: pa.Schema):
self._serializers: Dict[str, TypeSerializer] = {}

for i in range(len(schema)):
field = schema.field(i)
if isinstance(field.type, TypeSerializer):
self._serializers[field.name] = field.type

def field_serializer(self, field: str) -> TypeSerializer:
"""Return the serializer for a given field."""
return self._serializers[field]

def serialize(self, batch: DictData) -> DictData:
"""Serialize a batch of rows."""
for name, ser in self._serializers.items():
if name in batch:
batch[name] = [ser.serialize(d) for d in batch[name]]

return batch
@abstractmethod
def serialize(self, value: Any) -> bytes:
"""Serialize a value.
def deserialize(self, batch: DictData) -> DictData:
"""Deserialize a batch of rows."""
for name, ser in self._serializers.items():
if name in batch:
batch[name] = [ser.deserialize(d) for d in batch[name]]
Args:
value: numpy-like nested dict.
"""

return batch
@abstractmethod
def deserialize(self, value_bytes: bytes) -> DeserializedData:
"""Deserialize bytes to a value.
Returns:
Numpy-like nested dict.
"""
5 changes: 3 additions & 2 deletions python/src/space/core/utils/protos.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
from google.protobuf import message
from google.protobuf import text_format
from google.protobuf.timestamp_pb2 import Timestamp
from space.core.utils import constants

from space.core.utils.constants import UTF_8


def proto_to_text(msg: message.Message) -> bytes:
"""Return the text format of a proto."""
return text_format.MessageToString(msg).encode(constants.UTF_8)
return text_format.MessageToString(msg).encode(UTF_8)


def proto_now() -> Timestamp:
Expand Down
27 changes: 27 additions & 0 deletions python/tests/core/schema/test_arrow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pyarrow as pa

import space.core.schema.arrow as arrow_schema


def test_field_metadata():
  """field_metadata returns bytes-keyed Parquet field-ID metadata."""
  expected = {b"PARQUET:field_id": b"123"}
  assert arrow_schema.field_metadata(123) == expected


def test_field_id():
  """field_id reads the integer ID back from a field's Parquet metadata."""
  field = pa.field("name",
                   pa.int64(),
                   metadata={b"PARQUET:field_id": b"123"})
  assert arrow_schema.field_id(field) == 123
Loading

0 comments on commit 2a60e92

Please sign in to comment.