[Datasets] Add optional tf_schema parameter to read_tfrecords / write_tfrecords methods #32857

Merged
merged 15 commits into from
Mar 15, 2023
3 changes: 3 additions & 0 deletions python/ray/data/dataset.py
@@ -150,6 +150,7 @@
from ray.data.grouped_dataset import GroupedDataset
from ray.data._internal.execution.interfaces import Executor, NodeIdStr
from ray.data._internal.torch_iterable_dataset import TorchTensorBatchType
from tensorflow_metadata.proto.v0 import schema_pb2


logger = logging.getLogger(__name__)
@@ -2568,6 +2569,7 @@ def write_tfrecords(
self,
path: str,
*,
tf_schema: Optional["schema_pb2.Schema"] = None,
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
try_create_dir: bool = True,
arrow_open_stream_args: Optional[Dict[str, Any]] = None,
@@ -2626,6 +2628,7 @@ def write_tfrecords(
try_create_dir=try_create_dir,
open_stream_args=arrow_open_stream_args,
block_path_provider=block_path_provider,
tf_schema=tf_schema,
)

@ConsumptionAPI
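For context, a minimal usage sketch of the new write-side argument; the feature names, values, and output path below are illustrative assumptions rather than part of this PR:

```python
import ray
from tensorflow_metadata.proto.v0 import schema_pb2

# Declare an explicit schema: one int feature and one bytes feature.
schema = schema_pb2.Schema()
label = schema.feature.add()
label.name = "label"
label.type = schema_pb2.FeatureType.INT
payload = schema.feature.add()
payload.name = "payload"
payload.type = schema_pb2.FeatureType.BYTES

ds = ray.data.from_items(
    [{"label": 1, "payload": b"a"}, {"label": 2, "payload": b"b"}]
)

# Feature types in the written TFRecords strictly follow the schema above.
ds.write_tfrecords("/tmp/tfrecords_out", tf_schema=schema)
```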
148 changes: 123 additions & 25 deletions python/ray/data/datasource/tfrecords_datasource.py
@@ -1,4 +1,13 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, Union, Iterable, Iterator
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Optional,
Union,
Iterable,
Iterator,
)
import struct

import numpy as np
@@ -11,6 +20,7 @@
if TYPE_CHECKING:
import pyarrow
import tensorflow as tf
from tensorflow_metadata.proto.v0 import schema_pb2


@PublicAPI(stability="alpha")
@@ -25,6 +35,8 @@ def _read_stream(
import pyarrow as pa
import tensorflow as tf

tf_schema: Optional["schema_pb2.Schema"] = reader_args.get("tf_schema", None)

for record in _read_records(f, path):
example = tf.train.Example()
try:
@@ -36,13 +48,14 @@ def _read_stream(
f"file contains a message type other than `tf.train.Example`: {e}"
)

yield pa.Table.from_pydict(_convert_example_to_dict(example))
yield pa.Table.from_pydict(_convert_example_to_dict(example, tf_schema))

def _write_block(
self,
f: "pyarrow.NativeFile",
block: BlockAccessor,
writer_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
tf_schema: Optional["schema_pb2.Schema"] = None,
**writer_args,
) -> None:

@@ -55,7 +68,7 @@ def _write_block(
# so we must iterate through the rows of the block,
# serialize to tf.train.Example proto, and write to file.

examples = _convert_arrow_table_to_examples(arrow_table)
examples = _convert_arrow_table_to_examples(arrow_table, tf_schema)

# Write each example to the arrow file in the TFRecord format.
for example in examples:
@@ -64,79 +77,140 @@

def _convert_example_to_dict(
example: "tf.train.Example",
tf_schema: Optional["schema_pb2.Schema"],
) -> Dict[str, "pyarrow.Array"]:
record = {}
schema_dict = {}
# Convert user-specified schema into dict for convenient mapping
if tf_schema is not None:
for schema_feature in tf_schema.feature:
schema_dict[schema_feature.name] = schema_feature.type

for feature_name, feature in example.features.feature.items():
record[feature_name] = _get_feature_value(feature)
if tf_schema is not None and feature_name not in schema_dict:
raise ValueError(
f"Found extra unexpected feature {feature_name} "
f"not in specified schema: {tf_schema}"
)
schema_feature_type = schema_dict.get(feature_name)
record[feature_name] = _get_feature_value(feature, schema_feature_type)
return record


def _convert_arrow_table_to_examples(
arrow_table: "pyarrow.Table",
tf_schema: Optional["schema_pb2.Schema"] = None,
) -> Iterable["tf.train.Example"]:
import tensorflow as tf

schema_dict = {}
# Convert user-specified schema into dict for convenient mapping
if tf_schema is not None:
for schema_feature in tf_schema.feature:
schema_dict[schema_feature.name] = schema_feature.type

# Serialize each row[i] of the block to a tf.train.Example and yield it.
for i in range(arrow_table.num_rows):

# First, convert row[i] to a dictionary.
features: Dict[str, "tf.train.Feature"] = {}
for name in arrow_table.column_names:
features[name] = _value_to_feature(arrow_table[name][i])
if tf_schema is not None and name not in schema_dict:
raise ValueError(
f"Found extra unexpected feature {name} "
f"not in specified schema: {tf_schema}"
)
schema_feature_type = schema_dict.get(name)
features[name] = _value_to_feature(
arrow_table[name][i],
schema_feature_type,
)

# Convert the dictionary to an Example proto.
proto = tf.train.Example(features=tf.train.Features(feature=features))

yield proto


def _get_single_true_type(dct) -> str:
"""Utility function for getting the single key which has a `True` value in
a dict. Used to filter a dict of `{field_type: is_valid}` to get
the field type from a schema or data source."""
filtered_types = iter([_type for _type in dct if dct[_type]])
# In the case where there are no keys with a `True` value, return `None`
return next(filtered_types, None)


def _get_feature_value(
feature: "tf.train.Feature",
schema_feature_type: Optional["schema_pb2.FeatureType"] = None,
) -> "pyarrow.Array":
import pyarrow as pa
from tensorflow_metadata.proto.v0 import schema_pb2

underlying_feature_type = {
"bytes": feature.HasField("bytes_list"),
"float": feature.HasField("float_list"),
"int": feature.HasField("int64_list"),
}
# At most one of `bytes_list`, `float_list`, and `int64_list`
# should contain values. If none contain data, this indicates
# an empty feature value.
assert sum(bool(value) for value in underlying_feature_type.values()) <= 1

if schema_feature_type is not None:
# If a schema is specified, compare the feature's underlying type
# to the type declared in the user-provided schema.
specified_feature_type = {
"bytes": schema_feature_type == schema_pb2.FeatureType.BYTES,
"float": schema_feature_type == schema_pb2.FeatureType.FLOAT,
"int": schema_feature_type == schema_pb2.FeatureType.INT,
}
und_type = _get_single_true_type(underlying_feature_type)
spec_type = _get_single_true_type(specified_feature_type)
if und_type is not None and und_type != spec_type:
raise ValueError(
"Schema field type mismatch during read: specified type is "
f"{spec_type}, but underlying type is {und_type}",
)
# Override the underlying value type with the type in the user-specified schema.
underlying_feature_type = specified_feature_type

Contributor: If `schema_feature_type` is not None, should we also change the logic of lines 189-193 so that the list is not unboxed to a scalar when the user specifies a schema, and vice versa? We should check with the user on this behavior.

Contributor: Yes! If `schema_feature_type` is provided, we don't want the auto type conversion. The type should strictly follow the ones defined in the schema. Thanks!
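To make the agreed-upon behavior concrete, here is a small sketch of what the read-side conversion does after this change; it calls the private `_get_feature_value` helper directly, purely for illustration:

```python
import pyarrow as pa
import tensorflow as tf
from tensorflow_metadata.proto.v0 import schema_pb2

from ray.data.datasource.tfrecords_datasource import _get_feature_value

feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[7]))

# Without a schema, a single-element list is unboxed to a scalar column.
assert _get_feature_value(feature).equals(pa.array([7], type=pa.int64()))

# With an explicit INT type, no auto conversion happens: the list form is kept.
assert _get_feature_value(feature, schema_pb2.FeatureType.INT).equals(
    pa.array([[7]], type=pa.list_(pa.int64()))
)

# A mismatching schema type (e.g. FLOAT for an int64_list) raises ValueError.
```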


values = (
feature.HasField("int64_list"),
feature.HasField("float_list"),
feature.HasField("bytes_list"),
)
# At most one of `bytes_list`, `float_list`, and `int64_list` contains data.
# If none contain data, this indicates an empty feature value.
assert sum(bool(value) for value in values) <= 1

if feature.HasField("bytes_list"):
if underlying_feature_type["bytes"]:
value = feature.bytes_list.value
type_ = pa.binary()
elif feature.HasField("float_list"):
elif underlying_feature_type["float"]:
value = feature.float_list.value
type_ = pa.float32()
elif feature.HasField("int64_list"):
elif underlying_feature_type["int"]:
value = feature.int64_list.value
type_ = pa.int64()
else:
value = []
type_ = pa.null()
value = list(value)
if len(value) == 1:
if len(value) == 1 and schema_feature_type is None:
# Use the value itself if the features contains a single value.
# This is to give better user experience when writing preprocessing UDF on
# these single-value lists.
value = value[0]
else:
# If the feature value is empty, set the type to null for now
# to allow pyarrow to construct a valid Array; later, infer the
# type from other records which have non-empty values for the feature.
# If the feature value is empty and no type is specified in the user-provided
# schema, set the type to null for now to allow pyarrow to construct a valid
# Array; later, infer the type from other records which have non-empty values
# for the feature.
if len(value) == 0:
type_ = pa.null()
type_ = pa.list_(type_)
return pa.array([value], type=type_)


def _value_to_feature(
value: Union["pyarrow.Scalar", "pyarrow.Array"]
value: Union["pyarrow.Scalar", "pyarrow.Array"],
schema_feature_type: Optional["schema_pb2.FeatureType"] = None,
) -> "tf.train.Feature":
import tensorflow as tf
import pyarrow as pa
from tensorflow_metadata.proto.v0 import schema_pb2

if isinstance(value, pa.ListScalar):
# Use the underlying type of the ListScalar's value in
@@ -151,11 +225,35 @@ def _value_to_feature(
else:
value = [value]

if pa.types.is_integer(value_type):
underlying_value_type = {
"bytes": pa.types.is_binary(value_type),
"float": pa.types.is_floating(value_type),
"int": pa.types.is_integer(value_type),
}
assert sum(bool(value) for value in underlying_value_type.values()) <= 1

if schema_feature_type is not None:
specified_feature_type = {
"bytes": schema_feature_type == schema_pb2.FeatureType.BYTES,
"float": schema_feature_type == schema_pb2.FeatureType.FLOAT,
"int": schema_feature_type == schema_pb2.FeatureType.INT,
}

und_type = _get_single_true_type(underlying_value_type)
spec_type = _get_single_true_type(specified_feature_type)
if und_type is not None and und_type != spec_type:
raise ValueError(
"Schema field type mismatch during write: specified type is "
f"{spec_type}, but underlying type is {und_type}",
)
# Override the underlying value type with the type in the user-specified schema.
underlying_value_type = specified_feature_type

if underlying_value_type["int"]:
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
if pa.types.is_floating(value_type):
if underlying_value_type["float"]:
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
if pa.types.is_binary(value_type):
if underlying_value_type["bytes"]:
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
if pa.types.is_null(value_type):
raise ValueError(
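For reference, a tiny illustration of the new `_get_single_true_type` helper added above; again this calls a private helper directly and only demonstrates the intended behavior:

```python
from ray.data.datasource.tfrecords_datasource import _get_single_true_type

# Exactly one underlying list is populated for a non-empty feature,
# so the helper resolves it to a single field type...
assert _get_single_true_type({"bytes": False, "float": True, "int": False}) == "float"

# ...and returns None for an empty feature, where no list holds data.
assert _get_single_true_type({"bytes": False, "float": False, "int": False}) is None
```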
5 changes: 5 additions & 0 deletions python/ray/data/read_api.py
@@ -65,6 +65,7 @@
import pymongoarrow.api
import tensorflow as tf
import torch
from tensorflow_metadata.proto.v0 import schema_pb2


T = TypeVar("T")
@@ -1052,6 +1053,7 @@ def read_tfrecords(
arrow_open_stream_args: Optional[Dict[str, Any]] = None,
meta_provider: BaseFileMetadataProvider = DefaultFileMetadataProvider(),
partition_filter: Optional[PathPartitionFilter] = None,
tf_schema: Optional["schema_pb2.Schema"] = None,
) -> Dataset[PandasRow]:
"""Create a dataset from TFRecord files that contain
`tf.train.Example <https://www.tensorflow.org/api_docs/python/tf/train/Example>`_
@@ -1119,6 +1121,8 @@ def read_tfrecords(
with a custom callback to read only selected partitions of a dataset.
By default, this filters out any file paths whose file extension does not
match ``"*.tfrecords*"``.
tf_schema: Optional TensorFlow Schema which is used to explicitly set the schema
of the underlying Dataset.

Returns:
A :class:`~ray.data.Dataset` that contains the example features.
@@ -1134,6 +1138,7 @@ def read_tfrecords(
open_stream_args=arrow_open_stream_args,
meta_provider=meta_provider,
partition_filter=partition_filter,
tf_schema=tf_schema,
)


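A corresponding read-side sketch; the path and feature names are illustrative and assume files like those produced by the write sketch earlier:

```python
import ray
from tensorflow_metadata.proto.v0 import schema_pb2

# Declare the expected features; a feature found in the files but missing
# from the schema raises a ValueError during the read.
schema = schema_pb2.Schema()
for name, ftype in [
    ("label", schema_pb2.FeatureType.INT),
    ("payload", schema_pb2.FeatureType.BYTES),
]:
    feature = schema.feature.add()
    feature.name = name
    feature.type = ftype

ds = ray.data.read_tfrecords("/tmp/tfrecords_out", tf_schema=schema)
print(ds.schema())
```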