Skip to content

Commit

Permalink
fix: lambdify ComplexSqrt to TensorFlow (#365)
Browse files Browse the repository at this point in the history
* chore: define custom _TensorflowPrinter
* chore: extract _replace_module function
* chore: forward JAX ComplexSqrt printer to numpycode
* chore: move _backend to function sub-module
* feat: expand available functions in TF printer
* fix: improve find_function exception message
* fix: lambdify ComplexSqrt to TensorFlow
* test: lambdify ComplexSqrt with different backends
* test: try more functions in test_find_function
  • Loading branch information
redeboer authored Dec 2, 2021
1 parent 4aa7b7e commit 3e50f2d
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 44 deletions.
2 changes: 2 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ ignore =
extend-select =
TI100
per-file-ignores =
# printer methods
src/tensorwaves/function/sympy.py: N802
# imported but unused
src/tensorwaves/optimizer/__init__.py: F401
radon-max-cc = 8
Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ filterwarnings =
ignore:.*the imp module is deprecated in favour of importlib.*:DeprecationWarning
ignore:Passing a schema to Validator.iter_errors is deprecated.*:DeprecationWarning
ignore:invalid value encountered in log.*:RuntimeWarning
ignore:invalid value encountered in sqrt:RuntimeWarning
norecursedirs =
_build
markers =
Expand Down
2 changes: 1 addition & 1 deletion src/tensorwaves/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np

from tensorwaves._backend import find_function
from tensorwaves.function._backend import find_function
from tensorwaves.interface import (
DataSample,
Estimator,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,15 @@ def find_function(function_name: str, backend: str) -> Callable:
return backend_modules[function_name]
if isinstance(backend_modules, (tuple, list)):
for module in backend_modules:
if function_name in module.__dict__:
return module.__dict__[function_name]
raise ValueError(f"Could not find function {function_name} in backend")
if isinstance(module, dict):
module_dict = module
else:
module_dict = module.__dict__
if function_name in module_dict:
return module_dict[function_name]
raise ValueError(
f'Could not find function "{function_name}" in backend "{backend}"'
)


def get_backend_modules(
Expand All @@ -38,14 +44,15 @@ def get_backend_modules(
if backend in {"numpy", "numba"}:
import numpy as np

return (np, np.__dict__)
return np, np.__dict__
# returning only np.__dict__ does not work well with conditionals
if backend in {"tensorflow", "tf"}:
# pylint: disable=import-error
# pyright: reportMissingImports=false
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

return tnp.__dict__
return tnp.__dict__, tf

return backend

Expand Down
59 changes: 35 additions & 24 deletions src/tensorwaves/function/sympy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# pylint: disable=abstract-method invalid-name protected-access
"""Lambdify `sympy` expression trees to a `.Function`."""

import logging
import re
from typing import (
Any,
Callable,
Expand All @@ -13,14 +15,10 @@
)

import sympy as sp
from sympy.printing.numpy import (
NumPyPrinter,
_numpy_known_constants,
_numpy_known_functions,
)
from sympy.printing.numpy import NumPyPrinter
from tqdm.auto import tqdm

from tensorwaves._backend import get_backend_modules, jit_compile
from tensorwaves.function._backend import get_backend_modules, jit_compile
from tensorwaves.interface import ParameterValue

from . import ParametrizedBackendFunction
Expand Down Expand Up @@ -99,7 +97,13 @@ def tensorflow_lambdify() -> Callable:
# pylint: disable=import-error
import tensorflow.experimental.numpy as tnp # pyright: reportMissingImports=false

return sp.lambdify(symbols, expression, modules=tnp, **kwargs)
return sp.lambdify(
symbols,
expression,
modules=tnp,
printer=_TensorflowPrinter,
**kwargs,
)

modules = get_backend_modules(backend)
if isinstance(backend, str):
Expand Down Expand Up @@ -222,26 +226,33 @@ def _use_progress_bar() -> bool:
return logging.getLogger().level <= logging.WARNING


_jax_known_functions = {
k: v.replace("numpy.", "jnp.") for k, v in _numpy_known_functions.items()
}
_jax_known_constants = {
k: v.replace("numpy.", "jnp.") for k, v in _numpy_known_constants.items()
}
def _replace_module(
mapping: Dict[str, str], old: str, new: str
) -> Dict[str, str]:
return {
k: re.sub(fr"^{old}\.(.*)$", fr"{new}.\1", v)
for k, v in mapping.items()
}


class _CustomNumPyPrinter(NumPyPrinter):
def _print_ComplexSqrt(self, expr: sp.Expr) -> str:
return expr._numpycode(self)

class _JaxPrinter(NumPyPrinter): # pylint: disable=abstract-method
# pylint: disable=invalid-name

class _JaxPrinter(_CustomNumPyPrinter):
module_imports = {"jax": {"numpy as jnp"}}
_module = "jnp"
_kc = _jax_known_constants
_kf = _jax_known_functions
_kc = _replace_module(NumPyPrinter._kc, "numpy", "jnp")
_kf = _replace_module(NumPyPrinter._kf, "numpy", "jnp")


def _print_ComplexSqrt(self, expr: sp.Expr) -> str: # noqa: N802
class _TensorflowPrinter(_CustomNumPyPrinter):
module_imports = {"tensorflow.experimental": {"numpy as tnp"}}
_module = "tnp"
_kc = _replace_module(NumPyPrinter._kc, "numpy", "tnp")
_kf = _replace_module(NumPyPrinter._kf, "numpy", "tnp")

def _print_ComplexSqrt(self, expr: sp.Expr) -> str:
x = self._print(expr.args[0])
return (
"jnp.select("
f"[jnp.less({x}, 0), True], "
f"[1j * jnp.sqrt(-{x}), jnp.sqrt({x})], "
"default=jnp.nan)"
)
return f"sqrt({x})"
30 changes: 30 additions & 0 deletions tests/unit/function/test_ampform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# pylint: disable=import-outside-toplevel
import numpy as np
import pytest

from tensorwaves.function._backend import find_function
from tensorwaves.function.sympy import create_parametrized_function


@pytest.mark.parametrize("backend", ["jax", "math", "numba", "numpy", "tf"])
def test_complex_sqrt(backend: str):
import sympy as sp
from ampform.dynamics.math import ComplexSqrt
from numpy.lib.scimath import sqrt as complex_sqrt

x = sp.Symbol("x")
expr = ComplexSqrt(x)
function = create_parametrized_function(
expr.doit(), parameters={}, backend=backend
)
if backend == "math":
values = -4
else:
linspace = find_function("linspace", backend)
kwargs = {}
if backend == "tf":
kwargs["dtype"] = find_function("complex64", backend)
values = linspace(-4, +4, 9, **kwargs)
data = {"x": values}
output_array = function(data) # type: ignore[arg-type]
np.testing.assert_almost_equal(output_array, complex_sqrt(data["x"]))
25 changes: 25 additions & 0 deletions tests/unit/function/test_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from tensorwaves.function._backend import find_function


def test_find_function():
# pylint: disable=import-error, import-outside-toplevel
# pyright: reportMissingImports=false
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

assert find_function("array", backend="numpy") is np.array
assert find_function("linspace", backend="numpy") is np.linspace
assert find_function("log", backend="numpy") is np.log
assert find_function("mean", backend="numpy") is np.mean
assert find_function("mean", backend="numba") is np.mean

assert find_function("array", backend="jax") is jnp.array
assert find_function("linspace", backend="jax") is jnp.linspace
assert find_function("mean", backend="jax") is jnp.mean

assert find_function("array", backend="tf") is tnp.array
assert find_function("linspace", backend="tf") is tnp.linspace
assert find_function("mean", backend="tf") is tnp.mean
assert find_function("Tensor", backend="tf") is tf.Tensor
14 changes: 0 additions & 14 deletions tests/unit/test_backend.py

This file was deleted.

0 comments on commit 3e50f2d

Please sign in to comment.