[Datasets] Change read_tfrecords output from Pandas to Arrow format (ray-project#30390)

This changes the read_tfrecords output from Pandas to Arrow format. Benchmark ray-project#30389 found that read_tfrecords is significantly slower than write_tfrecords; a sketch of the per-record change is shown below.
c21 authored Nov 21, 2022
1 parent 43a3028 commit 5f54406
Showing 3 changed files with 21 additions and 7 deletions.
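A minimal sketch (not part of the commit) of the per-record block construction this change swaps out: the reader previously built a one-row pandas.DataFrame per decoded tf.train.Example and now builds a one-row pyarrow.Table instead. The example dict below is hypothetical.

import pandas as pd
import pyarrow as pa

# Hypothetical decoded tf.train.Example, as a plain dict of feature values.
example_dict = {"int64": 1, "float_list": [1.0, 2.0, 3.0, 4.0], "bytes": b"abc"}

# Before this commit: one single-row Pandas block per record.
pandas_block = pd.DataFrame([example_dict])

# After this commit: one single-row Arrow block per record; each value is
# wrapped in a list so every dict entry becomes a one-element Arrow column.
arrow_block = pa.Table.from_pydict({k: [v] for k, v in example_dict.items()})

print(type(pandas_block), type(arrow_block))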
1 change: 1 addition & 0 deletions doc/source/data/dataset-internals.rst
@@ -139,6 +139,7 @@ Different ways of creating Datasets leads to a different starting internal forma
* Reading tabular files (Parquet, CSV, JSON) creates Arrow blocks initially.
* Converting from Pandas, Dask, Modin, and Mars creates Pandas blocks initially.
* Reading NumPy files or converting from NumPy ndarrays creates Arrow blocks.
* Reading TFRecord files creates Arrow blocks.

However, this internal format is not exposed to the user. Datasets converts between formats
as needed internally depending on the specified ``batch_format`` of transformations.
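A hedged sketch of how the new Arrow blocks are consumed through the public API, assuming ray.data.read_tfrecords and Dataset.map_batches; the file path and pass-through batch function are placeholders.

import ray

# Hypothetical path; read_tfrecords now creates Arrow blocks internally.
ds = ray.data.read_tfrecords("/tmp/data.tfrecords")

# The internal block format stays hidden: asking for pandas batches makes
# Datasets convert the Arrow blocks to pandas.DataFrame batches as needed.
ds = ds.map_batches(lambda batch: batch, batch_format="pandas")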
21 changes: 17 additions & 4 deletions python/ray/data/datasource/tfrecords_datasource.py
@@ -22,7 +22,7 @@ def _read_stream(
self, f: "pyarrow.NativeFile", path: str, **reader_args
) -> Iterator[Block]:
from google.protobuf.message import DecodeError
import pandas as pd
import pyarrow as pa
import tensorflow as tf

for record in _read_records(f):
@@ -36,7 +36,7 @@ def _read_stream(
f"file contains a message type other than `tf.train.Example`: {e}"
)

yield pd.DataFrame([_convert_example_to_dict(example)])
yield pa.Table.from_pydict(_convert_example_to_dict(example))

def _write_block(
self,
@@ -64,13 +64,26 @@ def _write_block(

def _convert_example_to_dict(
example: "tf.train.Example",
) -> Dict[str, Union[bytes, List[bytes], float, List[float], int, List[int]]]:
) -> Dict[
str,
Union[
List[bytes],
List[List[bytes]],
List[float],
List[List[float]],
List[int],
List[List[int]],
],
]:
record = {}
for feature_name, feature in example.features.feature.items():
value = _get_feature_value(feature)
# Return value itself if the list has single value.
# This is to give better user experience when writing preprocessing UDF on
# these single-value lists.
if len(value) == 1:
value = value[0]
record[feature_name] = value
record[feature_name] = [value]
return record


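A minimal sketch of why _convert_example_to_dict now wraps every value in a list: pyarrow.Table.from_pydict expects one list of column values per feature, and a one-element list yields a one-row block. The feature values below are hypothetical and the tf.train.Example decoding step is skipped.

import pyarrow as pa

# Hypothetical output of _convert_example_to_dict after this change.
record = {
    "int64": [1],                  # single-value feature (unwrapped to a scalar, then listed)
    "int64_list": [[1, 2, 3, 4]],  # multi-value feature kept as a list
}

table = pa.Table.from_pydict(record)
assert table.num_rows == 1
assert table.column("int64").to_pylist() == [1]
assert table.column("int64_list").to_pylist() == [[1, 2, 3, 4]]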
6 changes: 3 additions & 3 deletions python/ray/data/tests/test_dataset_tfrecords.py
@@ -45,11 +45,11 @@ def test_read_tfrecords(ray_start_regular_shared, tmp_path):
"bytes_list": object,
}
assert list(df["int64"]) == [1]
assert list(df["int64_list"]) == [[1, 2, 3, 4]]
assert np.array_equal(df["int64_list"][0], np.array([1, 2, 3, 4]))
assert list(df["float"]) == [1.0]
assert list(df["float_list"]) == [[1.0, 2.0, 3.0, 4.0]]
assert np.array_equal(df["float_list"][0], np.array([1.0, 2.0, 3.0, 4.0]))
assert list(df["bytes"]) == [b"abc"]
assert list(df["bytes_list"]) == [[b"abc", b"1234"]]
assert np.array_equal(df["bytes_list"][0], np.array([b"abc", b"1234"]))


def test_write_tfrecords(ray_start_regular_shared, tmp_path):
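The assertions above switch to np.array_equal because Arrow list columns come back as NumPy arrays once the dataset is converted to pandas. A standalone sketch of that behavior using only pyarrow and numpy, without the Ray test fixtures:

import numpy as np
import pyarrow as pa

# One-row Arrow table with a list<int64> column, like the TFRecord reader now produces.
table = pa.Table.from_pydict({"int64_list": [[1, 2, 3, 4]]})

# to_pandas turns the list cell into a NumPy array, so plain list equality
# no longer applies and np.array_equal is used instead.
df = table.to_pandas()
assert isinstance(df["int64_list"][0], np.ndarray)
assert np.array_equal(df["int64_list"][0], np.array([1, 2, 3, 4]))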
