FEAT: implement ChainedDataTransformer (#470)
redeboer authored Nov 6, 2022
1 parent 340004b commit 00ce851
Showing 4 changed files with 79 additions and 2 deletions.
1 change: 1 addition & 0 deletions .cspell.json
@@ -169,6 +169,7 @@
"qrules",
"rightarrow",
"rtfd",
"rtol",
"scipy",
"sdist",
"seealso",
9 changes: 9 additions & 0 deletions src/tensorwaves/data/_attrs.py
@@ -0,0 +1,9 @@
from __future__ import annotations

from typing import Iterable

from tensorwaves.interface import DataTransformer


def to_tuple(items: Iterable[DataTransformer]) -> tuple[DataTransformer, ...]:
    return tuple(items)
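This helper is used as an attrs converter in the new class below: it normalizes any iterable of transformers into a tuple, so that the frozen `ChainedDataTransformer` stores its steps immutably. A minimal sketch of that behavior, using the `IdentityTransformer` that appears in the transform.py diff below:

from tensorwaves.data._attrs import to_tuple
from tensorwaves.data.transform import IdentityTransformer

steps = [IdentityTransformer(), IdentityTransformer()]  # any iterable works
assert to_tuple(steps) == tuple(steps)  # converted to an immutable tuple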
30 changes: 30 additions & 0 deletions src/tensorwaves/data/transform.py
@@ -3,6 +3,8 @@

from typing import TYPE_CHECKING, Mapping

from attrs import field, frozen

from tensorwaves.function import PositionalArgumentFunction
from tensorwaves.function.sympy import (
    _get_free_symbols,  # pyright: ignore[reportPrivateUsage]
@@ -12,10 +14,38 @@
)
from tensorwaves.interface import DataSample, DataTransformer, Function

from ._attrs import to_tuple

if TYPE_CHECKING:  # pragma: no cover
    import sympy as sp


@frozen
class ChainedDataTransformer(DataTransformer):
"""Combine multiple `.DataTransformer` classes into one.
Args:
transformer: Ordered list of transformers that you want to chain.
extend: Set to `True` in order to keep keys of each output `.DataSample` and
collect them into the final, chained `.DataSample`.
"""

transformers: tuple[DataTransformer, ...] = field(converter=to_tuple)
extend: bool = True

def __call__(self, data: DataSample) -> DataSample:
new_data = dict(data)
weights = new_data.get("weights")
for transformer in self.transformers:
if self.extend:
new_data.update(transformer(new_data))
else:
new_data = transformer(new_data)
if weights is not None:
new_data["weights"] = weights
return new_data


class IdentityTransformer(DataTransformer):
"""`.DataTransformer` that leaves a `.DataSample` intact."""

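For illustration, a minimal usage sketch of the new class (the symbols `x`, `y`, `r` and the expressions are invented for this example, not part of the commit; it assumes the numpy backend, which tensorwaves supports alongside jax):

import numpy as np
import sympy as sp

from tensorwaves.data.transform import ChainedDataTransformer, SympyDataTransformer

x, y, r = sp.symbols("x y r")
step1 = SympyDataTransformer.from_sympy({r: sp.sqrt(x**2 + y**2)}, backend="numpy")
step2 = SympyDataTransformer.from_sympy({r: 2 * r}, backend="numpy")
chained = ChainedDataTransformer([step1, step2], extend=True)

rng = np.random.default_rng(seed=0)
data = {"x": rng.uniform(size=10), "y": rng.uniform(size=10)}
transformed = chained(data)
assert set(transformed) == {"x", "y", "r"}  # extend=True keeps the input keys

Note that `__call__` re-attaches the "weights" entry after the chain runs, so an intermediate transformer cannot accidentally drop event weights.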
41 changes: 39 additions & 2 deletions tests/data/test_transform.py
@@ -1,10 +1,47 @@
# pylint: disable=invalid-name
from __future__ import annotations

import numpy as np
import pytest
import sympy as sp
from numpy import sqrt

from tensorwaves.data import IdentityTransformer, SympyDataTransformer
from tensorwaves.data.transform import (
    ChainedDataTransformer,
    IdentityTransformer,
    SympyDataTransformer,
)


class TestChainedDataTransformer:
    @pytest.mark.parametrize("extend", [False, True])
    def test_identity_chain(self, extend: bool):
        x, y, v, w = sp.symbols("x y v w")
        transform1 = _create_transformer({v: 2 * x - 5, w: -0.2 * y + 3})
        transform2 = _create_transformer({x: 0.5 * (v + 5), y: 5 * (3 - w)})
        chained_transform = ChainedDataTransformer([transform1, transform2], extend)
        rng = np.random.default_rng(seed=0)
        data = {"x": rng.uniform(size=100), "y": rng.uniform(size=100)}
        transformed_data = chained_transform(data)
        for key in data:  # pylint: disable=consider-using-dict-items
            np.testing.assert_allclose(data[key], transformed_data[key], rtol=1e-13)
        if extend:
            assert set(transformed_data) == {"x", "y", "v", "w"}
        else:
            assert set(transformed_data) == {"x", "y"}

    def test_single_chain(self):
        transform = IdentityTransformer()
        chained_transform = ChainedDataTransformer([transform])
        data = {
            "x": np.ones(5),
            "y": np.ones(5),
        }
        assert data == chained_transform(data)
        assert data is not chained_transform(data)  # DataSample returned as new dict


def _create_transformer(expressions: dict[sp.Symbol, sp.Expr]) -> SympyDataTransformer:
    return SympyDataTransformer.from_sympy(expressions, backend="jax")


class TestIdentityTransformer:
