Image & Audio formatting for numpy/torch/tf/jax #5072

Merged
merged 26 commits on Oct 10, 2022
54 changes: 40 additions & 14 deletions docs/source/use_with_pytorch.mdx
@@ -78,28 +78,54 @@ To get a single tensor, you must explicitly use the [`Array`] feature type and s

```py
>>> from datasets import Dataset, Features, ClassLabel
>>> data = [0, 0, 1]
>>> features = Features({"data": ClassLabel(names=["negative", "positive"])})
>>> ds = Dataset.from_dict({"data": data}, features=features)
>>> labels = [0, 0, 1]
>>> features = Features({"label": ClassLabel(names=["negative", "positive"])})
>>> ds = Dataset.from_dict({"label": labels}, features=features)
>>> ds = ds.with_format("torch")
>>> ds[:3]
{'data': tensor([0, 0, 1])}
{'label': tensor([0, 0, 1])}
```

However, since it's not possible to convert text data to PyTorch tensors, you can't format a `string` column to PyTorch.
Instead, you can explicitly format certain columns and leave the other columns unformatted:
String and binary objects are unchanged, since PyTorch only supports numbers.

The [`Image`] and [`Audio`] feature types are also supported:

```py
>>> from datasets import Dataset, Features
>>> text = ["foo", "bar"]
>>> data = [0, 1]
>>> ds = Dataset.from_dict({"text": text, "data": data})
>>> ds = ds.with_format("torch", columns=["data"], output_all_columns=True)
>>> ds[:2]
{'data': tensor([0, 1]), 'text': ['foo', 'bar']}
>>> from datasets import Dataset, Features, Audio, Image
>>> images = ["path/to/image.png"] * 10
>>> features = Features({"image": Image()})
>>> ds = Dataset.from_dict({"image": images}, features=features)
>>> ds = ds.with_format("torch")
>>> ds[0]["image"].shape
torch.Size([512, 512, 4])
>>> ds[0]
{'image': tensor([[[255, 215, 106, 255],
[255, 215, 106, 255],
...,
[255, 255, 255, 255],
[255, 255, 255, 255]]], dtype=torch.uint8)}
>>> ds[:2]["image"].shape
torch.Size([2, 512, 512, 4])
>>> ds[:2]
{'image': tensor([[[[255, 215, 106, 255],
[255, 215, 106, 255],
...,
[255, 255, 255, 255],
[255, 255, 255, 255]]]], dtype=torch.uint8)}
```

The [`Image`] and [`Audio`] feature types are not supported yet.
```py
>>> from datasets import Dataset, Features, Audio, Image
>>> audio = ["path/to/audio.wav"] * 10
>>> features = Features({"audio": Audio()})
>>> ds = Dataset.from_dict({"audio": audio}, features=features)
>>> ds = ds.with_format("torch")
>>> ds[0]["audio"]["array"]
tensor([ 6.1035e-05, 1.5259e-05, 1.6785e-04, ..., -1.5259e-05,
-1.5259e-05, 1.5259e-05])
>>> ds[0]["audio"]["sampling_rate"]
tensor(44100)
```
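Since this PR adds Image and Audio support to the numpy formatter as well, the numpy format behaves analogously. Below is a minimal sketch (not part of the diff) that builds a hypothetical image in memory so it runs without a file on disk; it assumes Pillow is installed, and the exact reprs may vary across versions:

```py
>>> import numpy as np
>>> from datasets import Dataset, Features, Image
>>> img = np.zeros((8, 8, 3), dtype=np.uint8)  # hypothetical 8x8 RGB image
>>> ds = Dataset.from_dict({"image": [img]}, features=Features({"image": Image()}))
>>> ds = ds.with_format("numpy")
>>> ds[0]["image"].shape
(8, 8, 3)
>>> ds[0]["image"].dtype
dtype('uint8')
```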

## Data loading

50 changes: 44 additions & 6 deletions docs/source/use_with_tensorflow.mdx
@@ -81,15 +81,15 @@ To get a single tensor, you must explicitly use the Array feature type and speci

```py
>>> from datasets import Dataset, Features, ClassLabel
>>> data = [0, 0, 1]
>>> features = Features({"data": ClassLabel(names=["negative", "positive"])})
>>> ds = Dataset.from_dict({"data": data}, features=features)
>>> labels = [0, 0, 1]
>>> features = Features({"label": ClassLabel(names=["negative", "positive"])})
>>> ds = Dataset.from_dict({"label": labels}, features=features)
>>> ds = ds.with_format("tf")
>>> ds[:3]
{'data': <tf.Tensor: shape=(3,), dtype=int64, numpy=array([0, 0, 1])>
{'label': <tf.Tensor: shape=(3,), dtype=int64, numpy=array([0, 0, 1])>}
```
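As a quick illustration (not in this diff), the formatted column is a regular `tf.Tensor`, so it can be fed straight into TensorFlow ops. A minimal sketch, assuming the `ds` with the new `label` column from the example above:

```py
>>> import tensorflow as tf
>>> tf.one_hot(ds[:3]["label"], depth=2)
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[1., 0.],
       [1., 0.],
       [0., 1.]], dtype=float32)>
```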

Strings are also supported:
Strings and binary objects are also supported:

```py
>>> from datasets import Dataset, Features
@@ -111,7 +111,45 @@ You can also explicitly format certain columns and leave the other columns unfor
'text': ['foo', 'bar']}
```

The [`Image`] and [`Audio`] feature types are not supported yet.

The [`Image`] and [`Audio`] feature types are also supported:

```py
>>> from datasets import Dataset, Features, Audio, Image
>>> images = ["path/to/image.png"] * 10
>>> features = Features({"image": Image()})
>>> ds = Dataset.from_dict({"image": images}, features=features)
>>> ds = ds.with_format("tf")
>>> ds[0]
{'image': <tf.Tensor: shape=(512, 512, 4), dtype=uint8, numpy=
array([[[255, 215, 106, 255],
[255, 215, 106, 255],
...,
[255, 255, 255, 255],
[255, 255, 255, 255]]], dtype=uint8)>}
>>> ds[:2]
{'image': <tf.Tensor: shape=(2, 512, 512, 4), dtype=uint8, numpy=
array([[[[255, 215, 106, 255],
[255, 215, 106, 255],
...,
[255, 255, 255, 255],
[255, 255, 255, 255]]]], dtype=uint8)>}
```

```py
>>> from datasets import Dataset, Features, Audio, Image
>>> audio = ["path/to/audio.wav"] * 10
>>> features = Features({"audio": Audio()})
>>> ds = Dataset.from_dict({"audio": audio}, features=features)
>>> ds = ds.with_format("tf")
>>> ds[0]["audio"]["array"]
<tf.Tensor: shape=(202311,), dtype=float32, numpy=
array([ 6.1035156e-05, 1.5258789e-05, 1.6784668e-04, ...,
-1.5258789e-05, -1.5258789e-05, 1.5258789e-05], dtype=float32)>
>>> ds[0]["audio"]["sampling_rate"]
<tf.Tensor: shape=(), dtype=int32, numpy=44100>
```

## Data loading

15 changes: 10 additions & 5 deletions src/datasets/arrow_dataset.py
@@ -3978,9 +3978,14 @@ def _int64_feature(values):
"""Returns an int64_list from a list of bool / enum / int / uint."""
return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

def _feature(values: Union[float, int, str, np.ndarray]) -> "tf.train.Feature":
def _feature(values: Union[float, int, str, np.ndarray, list]) -> "tf.train.Feature":
"""Typechecks `values` and returns the corresponding tf.train.Feature."""
if isinstance(values, np.ndarray):
if isinstance(values, list):
if values and isinstance(values[0], str):
return _bytes_feature([v.encode() for v in values])
else:
raise ValueError(f"values={values} is empty or contains items that cannot be serialized")
elif isinstance(values, np.ndarray):
if values.dtype == np.dtype(float):
return _float_feature(values)
elif values.dtype == np.int64:
@@ -3991,9 +3996,9 @@ def _feature(values: Union[float, int, str, np.ndarray]) -> "tf.train.Feature":
return _bytes_feature([v.encode() for v in values])
else:
raise ValueError(
f"values={values} is an np.ndarray with items of dtype {values[0].dtype}, which cannot be serialized"
f"values={values} is empty or is an np.ndarray with items of dtype {values[0].dtype}, which cannot be serialized"
)
if hasattr(values, "dtype"):
elif hasattr(values, "dtype"):
if np.issubdtype(values.dtype, np.floating):
return _float_feature([values.item()])
elif np.issubdtype(values.dtype, np.integer):
@@ -4003,7 +4008,7 @@ def _feature(values: Union[float, int, str, np.ndarray]) -> "tf.train.Feature":
else:
raise ValueError(f"values={values} has dtype {values.dtype}, which cannot be serialized")
else:
raise ValueError(f"values={values} are not numpy objects, and so cannot be serialized")
raise ValueError(f"values={values} are not numpy objects or strings, and so cannot be serialized")

def serialize_example(ex):
feature = {key: _feature(value) for key, value in ex.items()}
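For context, a minimal sketch of what the new list-of-strings branch in `_feature` ultimately produces. It assumes `_bytes_feature` wraps a `tf.train.BytesList`, the standard TFRecord pattern, since that helper's body is outside this hunk:

```py
import tensorflow as tf

# A list of strings is encoded to bytes and stored in a BytesList feature,
# mirroring the `_bytes_feature([v.encode() for v in values])` call above.
values = ["foo", "bar"]
feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[v.encode() for v in values]))

# serialize_example (above) collects such features, presumably into a tf.train.Example
# that is then written out as a serialized record.
example = tf.train.Example(features=tf.train.Features(feature={"text": feature}))
print(feature.bytes_list.value)           # [b'foo', b'bar']
print(type(example.SerializeToString()))  # <class 'bytes'>
```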
3 changes: 3 additions & 0 deletions src/datasets/features/image.py
@@ -81,6 +81,9 @@ def encode_example(self, value: Union[str, dict, np.ndarray, "PIL.Image.Image"])
else:
raise ImportError("To support encoding images, please install 'Pillow'.")

if isinstance(value, list):
value = np.array(value)

if isinstance(value, str):
return {"path": value, "bytes": None}
elif isinstance(value, np.ndarray):
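A small sketch (not from the PR) of what the new branch does before the existing `np.ndarray` handling takes over; the pixel values are hypothetical:

```py
import numpy as np

# encode_example now normalizes a plain (possibly nested) Python list into an
# ndarray, so list inputs fall through to the existing np.ndarray branch below it.
value = [[0, 255], [255, 0]]
value = np.array(value)
print(value.shape, value.dtype)  # (2, 2) int64  (the default int dtype is platform-dependent)
```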
2 changes: 1 addition & 1 deletion src/datasets/formatting/__init__.py
@@ -23,12 +23,12 @@
ArrowFormatter,
CustomFormatter,
Formatter,
NumpyFormatter,
PandasFormatter,
PythonFormatter,
format_table,
query_table,
)
from .np_formatter import NumpyFormatter


logger = logging.get_logger(__name__)
33 changes: 5 additions & 28 deletions src/datasets/formatting/formatting.py
@@ -21,6 +21,7 @@
import pandas as pd
import pyarrow as pa

from ..features import Features
from ..features.features import _ArrayXDExtensionType, _is_zero_copy_only, decode_nested_example, pandas_types_mapper
from ..table import Table
from ..utils.py_utils import no_op_if_value_is_null
@@ -198,8 +199,8 @@ def _arrow_array_to_numpy(self, pa_array: pa.Array) -> np.ndarray:
or (isinstance(x, float) and np.isnan(x))
for x in array
):
return np.array(array, copy=False, **{**self.np_array_kwargs, "dtype": object})
return np.array(array, copy=False, **self.np_array_kwargs)
return np.array(array, copy=False, dtype=object)
return np.array(array, copy=False)


class PandasArrowExtractor(BaseArrowExtractor[pd.DataFrame, pd.Series, pd.DataFrame]):
@@ -214,7 +215,7 @@ def extract_batch(self, pa_table: pa.Table) -> pd.DataFrame:


class PythonFeaturesDecoder:
def __init__(self, features):
def __init__(self, features: Features):
self.features = features

def decode_row(self, row: dict) -> dict:
@@ -228,7 +229,7 @@ def decode_batch(self, batch: dict) -> dict:


class PandasFeaturesDecoder:
def __init__(self, features):
def __init__(self, features: Features):
self.features = features

def decode_row(self, row: pd.DataFrame) -> pd.DataFrame:
@@ -325,30 +326,6 @@ def format_batch(self, pa_table: pa.Table) -> dict:
return batch


class NumpyFormatter(Formatter[dict, np.ndarray, dict]):
def __init__(self, features=None, decoded=True, **np_array_kwargs):
super().__init__(features=features, decoded=decoded)
self.np_array_kwargs = np_array_kwargs

def format_row(self, pa_table: pa.Table) -> dict:
row = self.numpy_arrow_extractor(**self.np_array_kwargs).extract_row(pa_table)
if self.decoded:
row = self.python_features_decoder.decode_row(row)
return row

def format_column(self, pa_table: pa.Table) -> np.ndarray:
column = self.numpy_arrow_extractor(**self.np_array_kwargs).extract_column(pa_table)
if self.decoded:
column = self.python_features_decoder.decode_column(column, pa_table.column_names[0])
return column

def format_batch(self, pa_table: pa.Table) -> dict:
batch = self.numpy_arrow_extractor(**self.np_array_kwargs).extract_batch(pa_table)
if self.decoded:
batch = self.python_features_decoder.decode_batch(batch)
return batch


class PandasFormatter(Formatter):
def format_row(self, pa_table: pa.Table) -> pd.DataFrame:
row = self.pandas_arrow_extractor().extract_row(pa_table)
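The `NumpyFormatter` removed here is not gone; it moves to the new `np_formatter` module and, as the `__init__.py` change above shows, remains importable from the same package. A minimal sketch, assuming the relocated class keeps an equivalent constructor:

```py
# The public import path is unchanged because formatting/__init__.py re-exports
# NumpyFormatter from the new np_formatter module.
from datasets.formatting import NumpyFormatter

formatter = NumpyFormatter()  # assumed to accept the same optional arguments as before
```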
53 changes: 44 additions & 9 deletions src/datasets/formatting/jax_formatter.py
@@ -13,11 +13,13 @@
# limitations under the License.

# Lint as: python3
import sys
from typing import TYPE_CHECKING

import numpy as np
import pyarrow as pa

from .. import config
from ..utils.py_utils import map_nested
from .formatting import Formatter

@@ -28,47 +30,80 @@

class JaxFormatter(Formatter[dict, "jnp.ndarray", dict]):
def __init__(self, features=None, decoded=True, **jnp_array_kwargs):
super().__init__(features=features, decoded=decoded)
self.jnp_array_kwargs = jnp_array_kwargs
import jax.numpy as jnp # noqa import jax at initialization

def _consolidate(self, column):
import jax.numpy as jnp

if isinstance(column, list) and column:
if all(
isinstance(x, jnp.ndarray) and x.shape == column[0].shape and x.dtype == column[0].dtype
for x in column
):
return jnp.stack(column)
return column

def _tensorize(self, value):
import jax
import jax.numpy as jnp

if isinstance(value, (str, bytes, type(None))):
return value
elif isinstance(value, (np.character, np.ndarray)) and np.issubdtype(value.dtype, np.character):
return value.tolist()

default_dtype = {}
if np.issubdtype(value.dtype, np.integer):

if isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.integer):
# the default int precision depends on the jax config
# see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
if jax.config.jax_enable_x64:
default_dtype = {"dtype": jnp.int64}
else:
default_dtype = {"dtype": jnp.int32}
elif np.issubdtype(value.dtype, np.floating):
elif isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.floating):
default_dtype = {"dtype": jnp.float32}
elif config.PIL_AVAILABLE and "PIL" in sys.modules:
import PIL.Image

if isinstance(value, PIL.Image.Image):
value = np.asarray(value)

# calling jnp.array on a np.ndarray does copy the data
# see https://github.com/google/jax/issues/4486
return jnp.array(value, **{**default_dtype, **self.jnp_array_kwargs})

def _recursive_tensorize(self, data_struct: dict):
# support for nested types like struct of list of struct
if isinstance(data_struct, (list, np.ndarray)):
data_struct = np.array(data_struct, copy=False)
if isinstance(data_struct, np.ndarray):
if data_struct.dtype == object: # jax arrays cannot be instantiated from an array of objects
return [self.recursive_tensorize(substruct) for substruct in data_struct]
return self._consolidate([self.recursive_tensorize(substruct) for substruct in data_struct])
return self._tensorize(data_struct)

def recursive_tensorize(self, data_struct: dict):
return map_nested(self._recursive_tensorize, data_struct, map_list=False)
return map_nested(self._recursive_tensorize, data_struct)

def format_row(self, pa_table: pa.Table) -> dict:
row = self.numpy_arrow_extractor().extract_row(pa_table)
if self.decoded:
row = self.python_features_decoder.decode_row(row)
return self.recursive_tensorize(row)

def format_column(self, pa_table: pa.Table) -> "jnp.ndarray":
col = self.numpy_arrow_extractor().extract_column(pa_table)
return self.recursive_tensorize(col)
column = self.numpy_arrow_extractor().extract_column(pa_table)
if self.decoded:
column = self.python_features_decoder.decode_column(column, pa_table.column_names[0])
column = self.recursive_tensorize(column)
column = self._consolidate(column)
return column

def format_batch(self, pa_table: pa.Table) -> dict:
batch = self.numpy_arrow_extractor().extract_batch(pa_table)
return self.recursive_tensorize(batch)
if self.decoded:
batch = self.python_features_decoder.decode_batch(batch)
batch = self.recursive_tensorize(batch)
for column_name in batch:
batch[column_name] = self._consolidate(batch[column_name])
return batch