Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PostprocessingDataset: add composition function #1609

Merged
merged 13 commits into from
Aug 30, 2024
58 changes: 55 additions & 3 deletions returnn/datasets/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from itertools import islice
from numpy.random import RandomState
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar

from returnn.datasets.basic import DatasetSeq
from returnn.datasets.util.vocabulary import Vocabulary
Expand All @@ -15,7 +15,7 @@
from .basic import init_dataset
from .cached2 import CachedDataset2

__all__ = ["PostprocessingDataset"]
__all__ = ["PostprocessingDataset", "LaplaceOrdering", "Sequential"]


class PostprocessingDataset(CachedDataset2):
Expand Down Expand Up @@ -54,6 +54,30 @@ class PostprocessingDataset(CachedDataset2):
"data": {"dims": [time_dim, new_data_dim]},
},
}

The dataset itself does not support its own seq ordering and relies on the wrapped
dataset for seq ordering instead. Specifying a ``seq_ordering`` other than ``default``
results in an error.

However, we provide an iterator that implements the common `laplace:.NUM_SEQS_PER_BIN`-variant
of seq ordering that any custom ``map_seq_stream``-style postprocessing iterator can be composed
with to implement the ordering via :class:`LaplaceOrdering`.

Like this::

from returnn.datasets.postprocessing import LaplaceOrdering, Sequential

def my_map_seq_stream(iterator):
...

train = {
"class": "PostprocessingDataset",
# ...
"map_seq_stream": Sequential(
my_map_seq_stream,
LaplaceOrdering(num_seqs_per_bin=1000),
),
}
"""

def __init__(
Expand Down Expand Up @@ -104,7 +128,7 @@ def __init__(

self._dataset = init_dataset(self._dataset_def, parent_dataset=self)
if self._map_seq_stream is None:
# if the stream mapper is set, the num_seqs may change and the estimation is less accurate
# if the stream mapper is set, the num_seqs may change and the estimation is less accurate
self._estimated_num_seqs = self._dataset.estimated_num_seqs
self._data_iter: Optional[Iterator[Tuple[int, TensorDict]]] = None

Expand Down Expand Up @@ -233,6 +257,8 @@ class LaplaceOrdering(Callable[[Iterator[TensorDict]], Iterator[TensorDict]]):
"""
Iterator compatible with :class:`PostprocessingDataset`'s ``map_seq_stream`` applying
laplace sequence ordering based on the number of segments per bin.

To be composed with any custom data postprocessing logic via :class:`Sequential`.
"""

def __init__(self, num_seqs_per_bin: int, length_key: str = "data"):
Expand Down Expand Up @@ -264,3 +290,29 @@ def _get_seq_len(self, tdict: TensorDict) -> int:
:return: segment length of the segment in `tdict` as measured by `self.length_key` for comparison.
"""
return tdict.data[self.length_key].raw_tensor.shape[0]


# Type variable constrained to the two postprocessing styles: a single segment
# (``map_seq``-style, one ``TensorDict``) or an iterator over segments
# (``map_seq_stream``-style). Used so ``Sequential.__call__`` preserves whichever
# style it is given.
T = TypeVar("T", TensorDict, Iterator[TensorDict])


class Sequential:
    """
    Callable that composes multiple postprocessing functions into one by sequential application,
    i.e. Sequential(f, g)(x) = (g ∘ f)(x) = g(f(x)).

    Can either compose ``map_seq``-style single-segment processor functions or ``map_seq_stream``-style
    iterators operating on multiple segments. Just make sure not to mix both styles.
    """

    def __init__(self, *postprocessing_funcs: Callable):
        """
        :param postprocessing_funcs: Postprocessing functions to compose.
            Applied left to right, i.e. the first function given is applied first.
        """
        self.funcs = postprocessing_funcs

    def __call__(self, arg: "T", **kwargs) -> "T":
        """
        :param arg: either a single segment (``map_seq``-style) or an iterator over
            segments (``map_seq_stream``-style); must match the style of the composed functions.
        :param kwargs: forwarded unchanged to every composed function.
        :return: result of sequential application of the postprocessing functions.
        """
        for func in self.funcs:
            arg = func(arg, **kwargs)
        return arg
6 changes: 6 additions & 0 deletions tests/test_Dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,6 +1111,12 @@ def _repeat2(input_iter: Iterator[TensorDict], **kwargs) -> Iterator[TensorDict]
assert prev_len is None or classes.shape[0] <= prev_len or i == 3
prev_len = classes.shape[0]

# test composition
from returnn.datasets.postprocessing import Sequential

func = Sequential(lambda x: x * 10, lambda y: y + 1)
assert func(2) == 21


if __name__ == "__main__":
better_exchook.install()
Expand Down
Loading