Skip to content

Commit

Permalink
Add schema support: field ID, substrait, and TfFeatures Arrow field
Browse files Browse the repository at this point in the history
  • Loading branch information
coufon committed Dec 18, 2023
1 parent 58bd25c commit 150fdb4
Show file tree
Hide file tree
Showing 12 changed files with 398 additions and 7 deletions.
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ disable = ['fixme']

[tool.pylint.MAIN]
ignore = 'space/core/proto'
ignored-modules = ['space.core.proto', 'google.protobuf']
ignored-modules = ['space.core.proto', 'google.protobuf', 'substrait']
17 changes: 17 additions & 0 deletions python/src/space/core/schema/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Utilities for schema."""

from space.core.schema.field_ids import FieldIdManager
31 changes: 31 additions & 0 deletions python/src/space/core/schema/arrow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Utilities for schemas in the Arrow format."""

from typing import Dict

import pyarrow as pa

_PARQUET_FIELD_ID_KEY = "PARQUET:field_id"


def field_metadata(field_id_: int) -> Dict[str, str]:
"""Return Arrow field metadata for a field."""
return {_PARQUET_FIELD_ID_KEY: str(field_id_)}


def field_id(field: pa.Field) -> int:
"""Return field ID of an Arrow field."""
return int(field.metadata[_PARQUET_FIELD_ID_KEY])
61 changes: 61 additions & 0 deletions python/src/space/core/schema/field_ids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Utilities for schema field IDs."""

from typing import List, Optional
import pyarrow as pa

from space.core.schema import arrow

_INIT_FIELD_ID = 0


# pylint: disable=too-few-public-methods
class FieldIdManager:
"""Assign field IDs to schema fields."""

def __init__(self, next_field_id: Optional[int] = None):
self._next_field_id = (next_field_id
if next_field_id is not None else _INIT_FIELD_ID)

def _assign_field_id(self, field: pa.Field) -> pa.Field:
this_field_id = self._next_field_id
metadata = arrow.field_metadata(this_field_id)
self._next_field_id += 1

name = field.name
type_ = field.type
if pa.types.is_list(type_):
return pa.field(
name,
pa.list_(self._assign_field_id(
type_.value_field)), # type: ignore[attr-defined]
metadata=metadata)

if pa.types.is_struct(type_):
struct_type = pa.struct(
self._assign_field_ids(
[type_.field(i) for i in range(type_.num_fields)]))
return pa.field(name, struct_type, metadata=metadata)

return field.with_metadata(metadata)

def _assign_field_ids(self, fields: List[pa.Field]) -> List[pa.Field]:
return [self._assign_field_id(f) for f in fields]

def assign_field_ids(self, schema: pa.Schema) -> pa.Schema:
"""Assign field IDs to schema fields."""
return pa.schema(
self._assign_field_ids([schema.field(i) for i in range(len(schema))]))
97 changes: 97 additions & 0 deletions python/src/space/core/schema/substrait.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Utilities for schemas in the Substrait format."""

from __future__ import annotations
from typing import List

import pyarrow as pa

from substrait.type_pb2 import NamedStruct, Type

import space.core.schema.arrow as arrow_schema
from space.core.schema.types import TfFeatures
from space.core.utils import constants

# Substrait type name of Arrow custom type TfFeatures.
TF_FEATURES_TYPE = "TF_FEATURES"


def substrait_fields(schema: pa.Schema) -> NamedStruct:
"""Convert Arrow schema to equivalent Substrait fields."""
mutable_names: List[str] = []
types = _substrait_fields([schema.field(i) for i in range(len(schema))],
mutable_names)
return NamedStruct(names=mutable_names, struct=Type.Struct(types=types))


def _substrait_fields(fields: List[pa.Field],
mutable_names: List[str]) -> List[Type]:
"""Convert a list of Arrow fields to Substrait types, and record field names.
"""
return [_substrait_field(f, mutable_names) for f in fields]


def _substrait_field(field: pa.Field,
mutable_names: List[str],
skip_name=False) -> Type:
"""Convert an Arrow fields to a Substrait type, and record its field name."""
if not skip_name:
mutable_names.append(field.name)

type_ = Type()
field_id = arrow_schema.field_id(field)

if pa.types.is_int64(field.type):
_set_field_id(type_.i64, field_id)
elif pa.types.is_int32(field.type):
_set_field_id(type_.i32, field_id)
elif pa.types.is_string(field.type):
_set_field_id(type_.string, field_id)
elif pa.types.is_binary(field.type):
_set_field_id(type_.binary, field_id)
elif pa.types.is_boolean(field.type):
_set_field_id(type_.bool, field_id)
elif pa.types.is_float64(field.type):
_set_field_id(type_.fp64, field_id)
elif pa.types.is_float32(field.type):
_set_field_id(type_.fp32, field_id)
elif pa.types.is_list(field.type):
_set_field_id(type_.list, field_id)
type_.list.type.CopyFrom(
_substrait_field(
field.type.value_field, # type: ignore[attr-defined]
mutable_names,
skip_name=True))
# TODO: to support fixed_size_list in substrait.
elif pa.types.is_struct(field.type):
_set_field_id(type_.struct, field_id)
subfields = [field.type.field(i) for i in range(field.type.num_fields)]
type_.struct.types.extend(_substrait_fields(subfields, mutable_names))
elif isinstance(field.type, TfFeatures):
_set_field_id(type_.user_defined, field_id)
type_.user_defined.type_parameters.extend([
Type.Parameter(string=TF_FEATURES_TYPE),
Type.Parameter(string=field.type.__arrow_ext_serialize__().decode(
constants.UTF_8))
])
else:
raise ValueError(f"Type is not supported: {field.type}")

return type_


def _set_field_id(msg, field_id: int) -> None:
msg.type_variation_reference = field_id
17 changes: 17 additions & 0 deletions python/src/space/core/schema/types/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Defines custom types."""

from space.core.schema.types.tf_features import TfFeatures
55 changes: 55 additions & 0 deletions python/src/space/core/schema/types/tf_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Define a custom Arrow type for Tensorflow Dataset Features."""

from __future__ import annotations
from typing import Any, Union
import json
import pyarrow as pa
import tensorflow_datasets as tfds # type: ignore[import-untyped]
from tensorflow_datasets import features as tf_features

from space.core.serializers import TypeSerializer
from space.core.utils import constants


class TfFeatures(pa.ExtensionType, TypeSerializer):
"""A custom Arrow type for Tensorflow Dataset Features."""

def __init__(self, features: tf_features.FeaturesDict):
self._features = features
self._serialized = json.dumps(features.to_json())
pa.ExtensionType.__init__(self, pa.binary(), self._serialized)

def __arrow_ext_serialize__(self) -> bytes:
return self._serialized.encode(constants.UTF_8)

@classmethod
def __arrow_ext_deserialize__(
cls,
storage_type: pa.DataType, # pylint: disable=unused-argument
serialized: Union[bytes, str]
) -> TfFeatures:
if isinstance(serialized, bytes):
serialized = serialized.decode(constants.UTF_8)

features = tf_features.FeaturesDict.from_json(json.loads(serialized))
return TfFeatures(features)

def serialize(self, value: Any) -> bytes:
return self._features.serialize_example(value)

def deserialize(self, value_bytes: bytes) -> Any:
return tfds.as_numpy(self._features.deserialize_example(value_bytes))
17 changes: 17 additions & 0 deletions python/src/space/core/serializers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Serializers (and deserializers) for unstructured record fields."""

from space.core.serializers.base import DictSerializer, TypeSerializer
71 changes: 71 additions & 0 deletions python/src/space/core/serializers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Serializers (and deserializers) for unstructured record fields."""

from abc import ABC, abstractmethod
from typing import Any, Dict, List

import pyarrow as pa

DictData = Dict[str, List[Any]]


class TypeSerializer(ABC):
"""Abstract serializer of a type."""

@abstractmethod
def serialize(self, value: Any) -> bytes:
"""Serialize a value."""
return NotImplemented

@abstractmethod
def deserialize(self, value_bytes: bytes) -> Any:
"""Deserialize bytes to a value."""
return NotImplemented


class DictSerializer:
"""A serializer for rows in PyDict format.
PyDict format has the layout {"field": [values...], ...}.
"""

def __init__(self, schema: pa.Schema):
self._serializers: Dict[str, TypeSerializer] = {}

for i in range(len(schema)):
field = schema.field(i)
if isinstance(field.type, TypeSerializer):
self._serializers[field.name] = field.type

def field_serializer(self, field: str) -> TypeSerializer:
"""Return the serializer for a given field."""
return self._serializers[field]

def serialize(self, batch: DictData) -> DictData:
"""Serialize a batch of rows."""
for name, ser in self._serializers.items():
if name in batch:
batch[name] = [ser.serialize(d) for d in batch[name]]

return batch

def deserialize(self, batch: DictData) -> DictData:
"""Deserialize a batch of rows."""
for name, ser in self._serializers.items():
if name in batch:
batch[name] = [ser.deserialize(d) for d in batch[name]]

return batch
Loading

0 comments on commit 150fdb4

Please sign in to comment.