-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add schema support: field ID, substrait, and TfFeatures Arrow field
- Loading branch information
coufon
committed
Dec 18, 2023
1 parent
58bd25c
commit 150fdb4
Showing
12 changed files
with
398 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
"""Utilities for schema.""" | ||
|
||
from space.core.schema.field_ids import FieldIdManager |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
"""Utilities for schemas in the Arrow format.""" | ||
|
||
from typing import Dict | ||
|
||
import pyarrow as pa | ||
|
||
# Key of the Arrow field metadata entry that holds a field's ID. It is written
# by field_metadata() and read back by field_id().
_PARQUET_FIELD_ID_KEY = "PARQUET:field_id"
|
||
|
||
def field_metadata(field_id_: int) -> Dict[str, str]:
  """Build the Arrow field metadata that records a field ID.

  Args:
    field_id_: the field ID to record.

  Returns:
    A single-entry dict mapping the field ID metadata key to `str(field_id_)`.
  """
  metadata: Dict[str, str] = {}
  metadata[_PARQUET_FIELD_ID_KEY] = str(field_id_)
  return metadata
|
||
|
||
def field_id(field: pa.Field) -> int:
  """Return the field ID stored in the metadata of an Arrow field.

  pyarrow exposes field metadata as a dict whose keys and values are `bytes`
  (or as `None` when the field has no metadata), so the metadata key is
  normalized to bytes before the lookup.

  Args:
    field: the Arrow field to read the field ID from.

  Returns:
    The field ID parsed as an integer.

  Raises:
    KeyError: if the field has no metadata, or no field ID entry in it.
    ValueError: if the stored value cannot be parsed as an integer.
  """
  key = _PARQUET_FIELD_ID_KEY
  if isinstance(key, str):
    # Metadata read back from pyarrow is keyed by bytes, never by str.
    key = key.encode("utf-8")

  metadata = field.metadata
  if metadata is None or key not in metadata:
    raise KeyError(f"Field ID is not found in metadata of field: {field.name}")

  # int() accepts the bytes value directly, e.g. int(b"12") == 12.
  return int(metadata[key])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
"""Utilities for schema field IDs.""" | ||
|
||
from typing import List, Optional | ||
import pyarrow as pa | ||
|
||
from space.core.schema import arrow | ||
|
||
# Field ID that FieldIdManager starts from when no explicit start is given.
_INIT_FIELD_ID = 0
|
||
|
||
# pylint: disable=too-few-public-methods
class FieldIdManager:
  """Assign field IDs to schema fields.

  IDs come from a counter that is incremented once per visited field. A nested
  field (list or struct) takes its ID before its children are visited, so IDs
  follow a depth-first, pre-order walk of the schema.
  """

  def __init__(self, next_field_id: Optional[int] = None):
    """Create a manager.

    Args:
      next_field_id: the ID to give to the next visited field; the counter
        starts from `_INIT_FIELD_ID` when this is None.
    """
    self._next_field_id = (next_field_id
                           if next_field_id is not None else _INIT_FIELD_ID)

  def _assign_field_id(self, field: pa.Field) -> pa.Field:
    """Return a copy of `field` with IDs set on it and on its nested fields.

    The returned field's metadata is replaced by the field ID metadata.
    """
    # Reserve this field's ID before recursing, so that a parent always gets a
    # smaller ID than any of its children.
    this_field_id = self._next_field_id
    metadata = arrow.field_metadata(this_field_id)
    self._next_field_id += 1

    name = field.name
    type_ = field.type
    if pa.types.is_list(type_):
      return pa.field(
          name,
          pa.list_(self._assign_field_id(
              type_.value_field)),  # type: ignore[attr-defined]
          # pa.field() defaults to nullable=True; keep the original setting so
          # that rebuilding the field does not silently relax it.
          nullable=field.nullable,
          metadata=metadata)

    if pa.types.is_struct(type_):
      struct_type = pa.struct(
          self._assign_field_ids(
              [type_.field(i) for i in range(type_.num_fields)]))
      return pa.field(name,
                      struct_type,
                      nullable=field.nullable,
                      metadata=metadata)

    # Leaf field: with_metadata() keeps name, type and nullability as is.
    return field.with_metadata(metadata)

  def _assign_field_ids(self, fields: List[pa.Field]) -> List[pa.Field]:
    """Assign IDs to a list of fields, in order."""
    return [self._assign_field_id(f) for f in fields]

  def assign_field_ids(self, schema: pa.Schema) -> pa.Schema:
    """Assign field IDs to schema fields.

    Args:
      schema: the Arrow schema whose fields (including nested list and struct
        fields) should receive IDs.

    Returns:
      A new schema built from the ID-annotated fields.
    """
    return pa.schema(
        self._assign_field_ids([schema.field(i) for i in range(len(schema))]))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
"""Utilities for schemas in the Substrait format.""" | ||
|
||
from __future__ import annotations | ||
from typing import List | ||
|
||
import pyarrow as pa | ||
|
||
from substrait.type_pb2 import NamedStruct, Type | ||
|
||
import space.core.schema.arrow as arrow_schema | ||
from space.core.schema.types import TfFeatures | ||
from space.core.utils import constants | ||
|
||
# Substrait type name of Arrow custom type TfFeatures. It is emitted as the
# first type parameter of the Substrait user defined type built for such fields.
TF_FEATURES_TYPE = "TF_FEATURES"
|
||
|
||
def substrait_fields(schema: pa.Schema) -> NamedStruct:
  """Translate an Arrow schema into the equivalent Substrait NamedStruct.

  Args:
    schema: the Arrow schema to convert.

  Returns:
    A NamedStruct whose `struct` holds one Substrait type per top-level schema
    field, and whose `names` holds the field names recorded while converting.
  """
  names: List[str] = []
  top_level_fields = [schema.field(idx) for idx in range(len(schema))]
  # The helper appends to `names` as a side effect while building the types.
  field_types = _substrait_fields(top_level_fields, names)
  return NamedStruct(names=names, struct=Type.Struct(types=field_types))
|
||
|
||
def _substrait_fields(fields: List[pa.Field],
                      mutable_names: List[str]) -> List[Type]:
  """Convert Arrow fields to Substrait types, one per field and in order.

  Field names are appended to `mutable_names` while the fields are converted.
  """
  converted: List[Type] = []
  for arrow_field in fields:
    converted.append(_substrait_field(arrow_field, mutable_names))
  return converted
|
||
|
||
def _substrait_field(field: pa.Field,
                     mutable_names: List[str],
                     skip_name=False) -> Type:
  """Convert an Arrow field to a Substrait type, and record its field name.

  Args:
    field: the Arrow field to convert; its field ID is read through
      `arrow_schema.field_id`.
    mutable_names: output list that field names are appended to.
    skip_name: when True, the name of `field` itself is not recorded. This is
      used for the value field of a list.

  Returns:
    The Substrait type, with the field ID stored in the
    `type_variation_reference` of the populated type member.

  Raises:
    ValueError: if the Arrow type of the field is not supported.
  """
  if not skip_name:
    mutable_names.append(field.name)

  result = Type()
  id_ = arrow_schema.field_id(field)
  arrow_type = field.type

  # Primitive Arrow types: (type predicate, name of the Substrait Type member).
  primitive_kinds = (
      (pa.types.is_int64, "i64"),
      (pa.types.is_int32, "i32"),
      (pa.types.is_string, "string"),
      (pa.types.is_binary, "binary"),
      (pa.types.is_boolean, "bool"),
      (pa.types.is_float64, "fp64"),
      (pa.types.is_float32, "fp32"),
  )
  for is_kind, member in primitive_kinds:
    if is_kind(arrow_type):
      _set_field_id(getattr(result, member), id_)
      return result

  # TODO: to support fixed_size_list in substrait.
  if pa.types.is_list(arrow_type):
    _set_field_id(result.list, id_)
    # The value field of a list contributes a type but not a name.
    element_type = _substrait_field(
        arrow_type.value_field,  # type: ignore[attr-defined]
        mutable_names,
        skip_name=True)
    result.list.type.CopyFrom(element_type)
    return result

  if pa.types.is_struct(arrow_type):
    _set_field_id(result.struct, id_)
    members = [arrow_type.field(idx) for idx in range(arrow_type.num_fields)]
    result.struct.types.extend(_substrait_fields(members, mutable_names))
    return result

  if isinstance(arrow_type, TfFeatures):
    _set_field_id(result.user_defined, id_)
    # The custom type is described by its type name followed by its
    # serialized form decoded to text.
    serialized_features = arrow_type.__arrow_ext_serialize__().decode(
        constants.UTF_8)
    result.user_defined.type_parameters.extend([
        Type.Parameter(string=TF_FEATURES_TYPE),
        Type.Parameter(string=serialized_features),
    ])
    return result

  raise ValueError(f"Type is not supported: {field.type}")
|
||
|
||
def _set_field_id(msg, field_id: int) -> None:
  """Store `field_id` in the `type_variation_reference` attribute of `msg`."""
  setattr(msg, "type_variation_reference", field_id)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
"""Defines custom types.""" | ||
|
||
from space.core.schema.types.tf_features import TfFeatures |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
"""Define a custom Arrow type for Tensorflow Dataset Features.""" | ||
|
||
from __future__ import annotations | ||
from typing import Any, Union | ||
import json | ||
import pyarrow as pa | ||
import tensorflow_datasets as tfds # type: ignore[import-untyped] | ||
from tensorflow_datasets import features as tf_features | ||
|
||
from space.core.serializers import TypeSerializer | ||
from space.core.utils import constants | ||
|
||
|
||
class TfFeatures(pa.ExtensionType, TypeSerializer):
  """Arrow extension type backed by a Tensorflow Datasets FeaturesDict.

  The storage type is Arrow binary. The type itself is serialized as the JSON
  form of the features, and values are (de)serialized by the features object.
  """

  def __init__(self, features: tf_features.FeaturesDict):
    """Create the type.

    Args:
      features: the features used to serialize and deserialize values.
    """
    self._features = features
    serialized_json = json.dumps(features.to_json())
    self._serialized = serialized_json
    pa.ExtensionType.__init__(self, pa.binary(), serialized_json)

  def __arrow_ext_serialize__(self) -> bytes:
    """Return the JSON form of the features, encoded as UTF-8 bytes."""
    return self._serialized.encode(constants.UTF_8)

  @classmethod
  def __arrow_ext_deserialize__(
      cls,
      storage_type: pa.DataType,  # pylint: disable=unused-argument
      serialized: Union[bytes, str]
  ) -> TfFeatures:
    """Rebuild the type from its serialized JSON form (bytes or text)."""
    text = (serialized.decode(constants.UTF_8)
            if isinstance(serialized, bytes) else serialized)
    restored = tf_features.FeaturesDict.from_json(json.loads(text))
    return TfFeatures(restored)

  def serialize(self, value: Any) -> bytes:
    """Serialize one value with the wrapped features."""
    example_bytes = self._features.serialize_example(value)
    return example_bytes

  def deserialize(self, value_bytes: bytes) -> Any:
    """Deserialize bytes with the wrapped features, converted by as_numpy."""
    example = self._features.deserialize_example(value_bytes)
    return tfds.as_numpy(example)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
"""Serializers (and deserializers) for unstructured record fields.""" | ||
|
||
from space.core.serializers.base import DictSerializer, TypeSerializer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
"""Serializers (and deserializers) for unstructured record fields.""" | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import Any, Dict, List | ||
|
||
import pyarrow as pa | ||
|
||
# A batch of rows in PyDict format: field name -> list of values of that field.
DictData = Dict[str, List[Any]]
|
||
|
||
class TypeSerializer(ABC):
  """Abstract serializer of a type.

  Concrete subclasses must implement both `serialize` and `deserialize`; the
  class cannot be instantiated otherwise.
  """

  @abstractmethod
  def serialize(self, value: Any) -> bytes:
    """Serialize a value.

    Args:
      value: the value to serialize.

    Returns:
      The serialized bytes.

    Raises:
      NotImplementedError: always, in this abstract base implementation.
    """
    # NotImplemented is the sentinel for binary operator methods, not a valid
    # `bytes` result; fail loudly if a subclass delegates here.
    raise NotImplementedError

  @abstractmethod
  def deserialize(self, value_bytes: bytes) -> Any:
    """Deserialize bytes to a value.

    Args:
      value_bytes: the serialized bytes.

    Returns:
      The deserialized value.

    Raises:
      NotImplementedError: always, in this abstract base implementation.
    """
    raise NotImplementedError
|
||
|
||
class DictSerializer:
  """Serializer for batches of rows in PyDict format.

  PyDict format has the layout {"field": [values...], ...}.
  Only fields whose Arrow type is a TypeSerializer are converted; all other
  fields of a batch are left untouched.
  """

  def __init__(self, schema: pa.Schema):
    """Collect the serializers of a schema, keyed by field name.

    Args:
      schema: the Arrow schema to scan for fields with a TypeSerializer type.
    """
    arrow_fields = (schema.field(idx) for idx in range(len(schema)))
    self._serializers: Dict[str, TypeSerializer] = {
        arrow_field.name: arrow_field.type
        for arrow_field in arrow_fields
        if isinstance(arrow_field.type, TypeSerializer)
    }

  def field_serializer(self, field: str) -> TypeSerializer:
    """Return the serializer for a given field.

    Raises:
      KeyError: if the field has no serializer.
    """
    serializer = self._serializers[field]
    return serializer

  def serialize(self, batch: DictData) -> DictData:
    """Serialize a batch of rows.

    The batch is modified in place and also returned.
    """
    for field_name, serializer in self._serializers.items():
      if field_name not in batch:
        continue
      batch[field_name] = list(map(serializer.serialize, batch[field_name]))

    return batch

  def deserialize(self, batch: DictData) -> DictData:
    """Deserialize a batch of rows.

    The batch is modified in place and also returned.
    """
    for field_name, serializer in self._serializers.items():
      if field_name not in batch:
        continue
      batch[field_name] = list(map(serializer.deserialize, batch[field_name]))

    return batch
Oops, something went wrong.