From 3e50f2d49dfaa288a42efa77b2cb95e23d7a4c02 Mon Sep 17 00:00:00 2001 From: Remco de Boer Date: Thu, 2 Dec 2021 09:39:59 +0100 Subject: [PATCH] fix: lambdify ComplexSqrt to TensorFlow (#365) * chore: define custom _TensorflowPrinter * chore: extract _replace_module function * chore: forward JAX ComplexSqrt printer to numpycode * chore: move _backend to function sub-module * feat: expand available functions in TF printer * fix: improve find_function exception message * fix: lambdify ComplexSqrt to TensorFlow * test: lambdify ComplexSqrt with different backends * test: try more functions in test_find_function --- .flake8 | 2 + pytest.ini | 1 + src/tensorwaves/estimator.py | 2 +- src/tensorwaves/{ => function}/_backend.py | 17 +++++-- src/tensorwaves/function/sympy.py | 59 +++++++++++++--------- tests/unit/function/test_ampform.py | 30 +++++++++++ tests/unit/function/test_backend.py | 25 +++++++++ tests/unit/test_backend.py | 14 ----- 8 files changed, 106 insertions(+), 44 deletions(-) rename src/tensorwaves/{ => function}/_backend.py (80%) create mode 100644 tests/unit/function/test_ampform.py create mode 100644 tests/unit/function/test_backend.py delete mode 100644 tests/unit/test_backend.py diff --git a/.flake8 b/.flake8 index 460d773d..e100db1e 100644 --- a/.flake8 +++ b/.flake8 @@ -33,6 +33,8 @@ ignore = extend-select = TI100 per-file-ignores = + # printer methods + src/tensorwaves/function/sympy.py: N802 # imported but unused src/tensorwaves/optimizer/__init__.py: F401 radon-max-cc = 8 diff --git a/pytest.ini b/pytest.ini index 7c11bd4d..63964cc4 100644 --- a/pytest.ini +++ b/pytest.ini @@ -16,6 +16,7 @@ filterwarnings = ignore:.*the imp module is deprecated in favour of importlib.*:DeprecationWarning ignore:Passing a schema to Validator.iter_errors is deprecated.*:DeprecationWarning ignore:invalid value encountered in log.*:RuntimeWarning + ignore:invalid value encountered in sqrt:RuntimeWarning norecursedirs = _build markers = diff --git a/src/tensorwaves/estimator.py b/src/tensorwaves/estimator.py index eef5fb9a..4f7114dc 100644 --- a/src/tensorwaves/estimator.py +++ b/src/tensorwaves/estimator.py @@ -7,7 +7,7 @@ import numpy as np -from tensorwaves._backend import find_function +from tensorwaves.function._backend import find_function from tensorwaves.interface import ( DataSample, Estimator, diff --git a/src/tensorwaves/_backend.py b/src/tensorwaves/function/_backend.py similarity index 80% rename from src/tensorwaves/_backend.py rename to src/tensorwaves/function/_backend.py index d7a21a27..971f4b9c 100644 --- a/src/tensorwaves/_backend.py +++ b/src/tensorwaves/function/_backend.py @@ -11,9 +11,15 @@ def find_function(function_name: str, backend: str) -> Callable: return backend_modules[function_name] if isinstance(backend_modules, (tuple, list)): for module in backend_modules: - if function_name in module.__dict__: - return module.__dict__[function_name] - raise ValueError(f"Could not find function {function_name} in backend") + if isinstance(module, dict): + module_dict = module + else: + module_dict = module.__dict__ + if function_name in module_dict: + return module_dict[function_name] + raise ValueError( + f'Could not find function "{function_name}" in backend "{backend}"' + ) def get_backend_modules( @@ -38,14 +44,15 @@ def get_backend_modules( if backend in {"numpy", "numba"}: import numpy as np - return (np, np.__dict__) + return np, np.__dict__ # returning only np.__dict__ does not work well with conditionals if backend in {"tensorflow", "tf"}: # pylint: disable=import-error # pyright: reportMissingImports=false + import tensorflow as tf import tensorflow.experimental.numpy as tnp - return tnp.__dict__ + return tnp.__dict__, tf return backend diff --git a/src/tensorwaves/function/sympy.py b/src/tensorwaves/function/sympy.py index d6aec971..308a44b6 100644 --- a/src/tensorwaves/function/sympy.py +++ b/src/tensorwaves/function/sympy.py @@ -1,6 +1,8 @@ +# pylint: disable=abstract-method invalid-name protected-access """Lambdify `sympy` expression trees to a `.Function`.""" import logging +import re from typing import ( Any, Callable, @@ -13,14 +15,10 @@ ) import sympy as sp -from sympy.printing.numpy import ( - NumPyPrinter, - _numpy_known_constants, - _numpy_known_functions, -) +from sympy.printing.numpy import NumPyPrinter from tqdm.auto import tqdm -from tensorwaves._backend import get_backend_modules, jit_compile +from tensorwaves.function._backend import get_backend_modules, jit_compile from tensorwaves.interface import ParameterValue from . import ParametrizedBackendFunction @@ -99,7 +97,13 @@ def tensorflow_lambdify() -> Callable: # pylint: disable=import-error import tensorflow.experimental.numpy as tnp # pyright: reportMissingImports=false - return sp.lambdify(symbols, expression, modules=tnp, **kwargs) + return sp.lambdify( + symbols, + expression, + modules=tnp, + printer=_TensorflowPrinter, + **kwargs, + ) modules = get_backend_modules(backend) if isinstance(backend, str): @@ -222,26 +226,33 @@ def _use_progress_bar() -> bool: return logging.getLogger().level <= logging.WARNING -_jax_known_functions = { - k: v.replace("numpy.", "jnp.") for k, v in _numpy_known_functions.items() -} -_jax_known_constants = { - k: v.replace("numpy.", "jnp.") for k, v in _numpy_known_constants.items() -} +def _replace_module( + mapping: Dict[str, str], old: str, new: str +) -> Dict[str, str]: + return { + k: re.sub(fr"^{old}\.(.*)$", fr"{new}.\1", v) + for k, v in mapping.items() + } + +class _CustomNumPyPrinter(NumPyPrinter): + def _print_ComplexSqrt(self, expr: sp.Expr) -> str: + return expr._numpycode(self) -class _JaxPrinter(NumPyPrinter): # pylint: disable=abstract-method - # pylint: disable=invalid-name + +class _JaxPrinter(_CustomNumPyPrinter): module_imports = {"jax": {"numpy as jnp"}} _module = "jnp" - _kc = _jax_known_constants - _kf = _jax_known_functions + _kc = _replace_module(NumPyPrinter._kc, "numpy", "jnp") + _kf = _replace_module(NumPyPrinter._kf, "numpy", "jnp") + - def _print_ComplexSqrt(self, expr: sp.Expr) -> str: # noqa: N802 +class _TensorflowPrinter(_CustomNumPyPrinter): + module_imports = {"tensorflow.experimental": {"numpy as tnp"}} + _module = "tnp" + _kc = _replace_module(NumPyPrinter._kc, "numpy", "tnp") + _kf = _replace_module(NumPyPrinter._kf, "numpy", "tnp") + + def _print_ComplexSqrt(self, expr: sp.Expr) -> str: x = self._print(expr.args[0]) - return ( - "jnp.select(" - f"[jnp.less({x}, 0), True], " - f"[1j * jnp.sqrt(-{x}), jnp.sqrt({x})], " - "default=jnp.nan)" - ) + return f"sqrt({x})" diff --git a/tests/unit/function/test_ampform.py b/tests/unit/function/test_ampform.py new file mode 100644 index 00000000..a0ac5f7c --- /dev/null +++ b/tests/unit/function/test_ampform.py @@ -0,0 +1,30 @@ +# pylint: disable=import-outside-toplevel +import numpy as np +import pytest + +from tensorwaves.function._backend import find_function +from tensorwaves.function.sympy import create_parametrized_function + + +@pytest.mark.parametrize("backend", ["jax", "math", "numba", "numpy", "tf"]) +def test_complex_sqrt(backend: str): + import sympy as sp + from ampform.dynamics.math import ComplexSqrt + from numpy.lib.scimath import sqrt as complex_sqrt + + x = sp.Symbol("x") + expr = ComplexSqrt(x) + function = create_parametrized_function( + expr.doit(), parameters={}, backend=backend + ) + if backend == "math": + values = -4 + else: + linspace = find_function("linspace", backend) + kwargs = {} + if backend == "tf": + kwargs["dtype"] = find_function("complex64", backend) + values = linspace(-4, +4, 9, **kwargs) + data = {"x": values} + output_array = function(data) # type: ignore[arg-type] + np.testing.assert_almost_equal(output_array, complex_sqrt(data["x"])) diff --git a/tests/unit/function/test_backend.py b/tests/unit/function/test_backend.py new file mode 100644 index 00000000..41d4320f --- /dev/null +++ b/tests/unit/function/test_backend.py @@ -0,0 +1,25 @@ +from tensorwaves.function._backend import find_function + + +def test_find_function(): + # pylint: disable=import-error, import-outside-toplevel + # pyright: reportMissingImports=false + import jax.numpy as jnp + import numpy as np + import tensorflow as tf + import tensorflow.experimental.numpy as tnp + + assert find_function("array", backend="numpy") is np.array + assert find_function("linspace", backend="numpy") is np.linspace + assert find_function("log", backend="numpy") is np.log + assert find_function("mean", backend="numpy") is np.mean + assert find_function("mean", backend="numba") is np.mean + + assert find_function("array", backend="jax") is jnp.array + assert find_function("linspace", backend="jax") is jnp.linspace + assert find_function("mean", backend="jax") is jnp.mean + + assert find_function("array", backend="tf") is tnp.array + assert find_function("linspace", backend="tf") is tnp.linspace + assert find_function("mean", backend="tf") is tnp.mean + assert find_function("Tensor", backend="tf") is tf.Tensor diff --git a/tests/unit/test_backend.py b/tests/unit/test_backend.py deleted file mode 100644 index f436897f..00000000 --- a/tests/unit/test_backend.py +++ /dev/null @@ -1,14 +0,0 @@ -from tensorwaves._backend import find_function - - -def test_find_function(): - # pylint: disable=import-error, import-outside-toplevel - # pyright: reportMissingImports=false - import jax.numpy as jnp - import numpy as np - import tensorflow.experimental.numpy as tnp - - assert find_function("mean", backend="numpy") is np.mean - assert find_function("log", backend="numpy") is np.log - assert find_function("mean", backend="tf") is tnp.mean - assert find_function("mean", backend="jax") is jnp.mean