PostprocessingDataset: add composition function
NeoLegends committed Aug 29, 2024
1 parent 448cffe commit 27437fe
Showing 2 changed files with 65 additions and 1 deletion.
60 changes: 59 additions & 1 deletion returnn/datasets/postprocessing.py
@@ -6,7 +6,7 @@

from itertools import islice
from numpy.random import Generator, PCG64
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

from returnn.datasets.basic import DatasetSeq
from returnn.datasets.util.vocabulary import Vocabulary
@@ -54,6 +54,30 @@ class PostprocessingDataset(CachedDataset2):
"data": {"dims": [time_dim, new_data_dim]},
},
}
The dataset itself does not support its own seq ordering and relies on the wrapped
dataset for seq ordering instead. Specifying a ``seq_ordering`` other than ``default``
results in an error.
However, we provide an iterator that implements the common ``laplace:.NUM_SEQS_PER_BIN``
variant of seq ordering, which any custom ``map_seq_stream``-style postprocessing iterator
can be composed with to implement the ordering. Like this::
    from returnn.datasets.postprocessing import compose, LaplaceOrdering

    def my_map_seq_stream(iterator):
        ...

    train = {
        "class": "PostprocessingDataset",
        # ...
        "map_seq_stream": compose(
            LaplaceOrdering(num_seqs_per_bin=1000),
            my_map_seq_stream,
        ),
    }
"""

def __init__(
@@ -228,10 +252,44 @@ def _make_tensor_template_from_input(self, data_key: str) -> Tensor:
        return Tensor(data_key, dims=dims, dtype=dtype, sparse_dim=sparse_dim)


def compose(*postprocessing_funcs: Callable):
    """
    Composes multiple postprocessing functions into one by sequential application,
    i.e. ``compose(f, g)(x) = (f ∘ g)(x) = f(g(x))``.

    Can either compose ``map_seq``-style single-segment processor functions or
    ``map_seq_stream``-style iterators operating on multiple segments.
    The functions are applied in reverse order, i.e. the last argument is applied first.

    :return: composite function applying ``postprocessing_funcs`` in reverse order
    """

    def wrapper(arg: Union[TensorDict, Iterator[TensorDict]], **kwargs) -> Union[TensorDict, Iterator[TensorDict]]:
        """composite postprocessing function"""

        # If we are passed an iterator, do not check for the concrete type (which may be
        # some generator type), but for the parent `Iterator` type instead.
        arg_type = Iterator if isinstance(arg, Iterator) else type(arg)

        for i, func in enumerate(reversed(postprocessing_funcs)):
            arg = func(arg, **kwargs)
            assert isinstance(arg, arg_type), (
                f"Function {i} returned a value of type {type(arg)}, "
                f"but must return the same type as the input ({arg_type}) to ensure valid composition. "
                "Did you mix map_seq- and map_seq_stream-style postprocessing functions?"
            )
        return arg

    return wrapper


class LaplaceOrdering(Callable[[Iterator[TensorDict]], Iterator[TensorDict]]):
    """
    Iterator compatible with :class:`PostprocessingDataset`'s ``map_seq_stream``
    applying laplace sequence ordering based on the number of segments per bin.

    To be composed with any custom data postprocessing logic via :func:`compose`.
    """

    def __init__(self, num_seqs_per_bin: int, length_key: str = "data"):
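For illustration, here is a minimal sketch of how ``compose`` chains two
``map_seq_stream``-style processors. The helper functions and the data-key access
are hypothetical (not part of this commit), and the ``TensorDict`` import path is
an assumption::

    from typing import Iterator

    from returnn.datasets.postprocessing import compose
    from returnn.tensor import TensorDict  # assumed import path


    def drop_short_seqs(input_iter: Iterator[TensorDict], **kwargs) -> Iterator[TensorDict]:
        """hypothetical map_seq_stream-style processor: skip segments shorter than 2 frames"""
        for tensor_dict in input_iter:
            if tensor_dict.data["data"].raw_tensor.shape[0] >= 2:
                yield tensor_dict


    def passthrough(input_iter: Iterator[TensorDict], **kwargs) -> Iterator[TensorDict]:
        """hypothetical map_seq_stream-style processor: yield every segment unchanged"""
        yield from input_iter


    # passthrough is applied first, then drop_short_seqs; both consume and
    # produce iterators, so the isinstance check in compose's wrapper passes.
    pipeline = compose(drop_short_seqs, passthrough)

Mixing a ``map_seq``-style function (TensorDict in, TensorDict out) into the same
``compose`` call would trip the assertion, since the return type would no longer
match the iterator input.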
6 changes: 6 additions & 0 deletions tests/test_Dataset.py
@@ -1111,6 +1111,12 @@ def _repeat2(input_iter: Iterator[TensorDict], **kwargs) -> Iterator[TensorDict]
        assert prev_len is None or classes.shape[0] <= prev_len or i == 3
        prev_len = classes.shape[0]

    # test composition
    from returnn.datasets.postprocessing import compose

    func = compose(lambda x: x * 10, lambda y: y + 1)
    assert func(2) == 30


if __name__ == "__main__":
better_exchook.install()
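As a usage note, the two new entry points are meant to be combined in a RETURNN
config roughly as sketched below; the wrapped ``dataset`` dict and
``my_map_seq_stream`` are placeholders, not part of this commit::

    from returnn.datasets.postprocessing import compose, LaplaceOrdering


    def my_map_seq_stream(input_iter, **kwargs):
        """custom map_seq_stream-style postprocessing; a pass-through in this sketch"""
        yield from input_iter


    train = {
        "class": "PostprocessingDataset",
        "dataset": {...},  # placeholder for the wrapped dataset
        # my_map_seq_stream runs first (compose applies right-to-left), then
        # LaplaceOrdering bins the resulting segments, 1000 per bin, using the
        # length of the "data" key (the default length_key).
        "map_seq_stream": compose(
            LaplaceOrdering(num_seqs_per_bin=1000),
            my_map_seq_stream,
        ),
    }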
