Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: jit-compile recombined fast_lambdify function #361

Merged
merged 7 commits into from
Dec 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/usage/step3.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
"source": [
"As explained in the {doc}`previous step <step2>`, a {class}`.ParametrizedFunction` can compute a list of intensities (real numbers) for an input {obj}`.DataSample`. At this stage, we want to optimize the parameters of this {class}`.ParametrizedFunction`, so that it matches the distribution of our data sample. This is what we call 'fitting'.\n",
"\n",
"First, we load the relevant data from the previous steps."
"First, we load the relevant data from the previous steps. Notice that we use {func}`.create_parametrized_function` with the argument `max_complexity`, which speeds up lambdification (see {doc}`/usage/faster-lambdify`)."
]
},
{
Expand Down Expand Up @@ -89,6 +89,7 @@
" expression=model.expression.doit(),\n",
" parameters=model.parameter_defaults,\n",
" backend=\"jax\",\n",
" max_complexity=100,\n",
")"
]
},
Expand Down
18 changes: 18 additions & 0 deletions src/tensorwaves/_backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Computational back-end handling."""


from functools import partial
from typing import Callable, Union


Expand Down Expand Up @@ -47,3 +49,19 @@ def get_backend_modules(
return tnp.__dict__

return backend


def jit_compile(backend: Union[str, tuple, dict]) -> Callable:
    """Return a decorator that JIT-compiles a function for a given back-end.

    Args:
        backend: Name of the computational back-end. The name is matched
            case-insensitively. Non-string back-ends (a tuple or dict, as
            also accepted by the lambdify helpers) carry no compiler name
            and therefore select the no-op decorator.

    Returns:
        ``jax.jit`` for ``"jax"``, ``numba.jit`` pre-configured with
        ``forceobj=True, parallel=True`` for ``"numba"``, and an identity
        function (which returns its argument unchanged) for anything else.
    """
    # pylint: disable=import-outside-toplevel
    if not isinstance(backend, str):
        # Callers may forward a tuple/dict back-end unchanged; calling
        # ``.lower()`` on those would raise AttributeError, so fall back to
        # "no compilation" instead of crashing.
        return lambda x: x

    backend = backend.lower()
    if backend == "jax":
        # Imported lazily so that jax is only required when it is requested.
        import jax

        return jax.jit

    if backend == "numba":
        # Imported lazily so that numba is only required when it is requested.
        import numba

        return partial(numba.jit, forceobj=True, parallel=True)

    return lambda x: x
243 changes: 120 additions & 123 deletions src/tensorwaves/function/sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,83 +21,11 @@
)
from tqdm.auto import tqdm

from tensorwaves._backend import get_backend_modules
from tensorwaves._backend import get_backend_modules, jit_compile
from tensorwaves.interface import ParameterValue

from . import ParametrizedBackendFunction

# Copy of `_numpy_known_functions` in which every "numpy." prefix in the
# printed name is replaced by "jnp."; used as `_kf` by `_JaxPrinter`.
_jax_known_functions = {
k: v.replace("numpy.", "jnp.") for k, v in _numpy_known_functions.items()
}
# Same substitution applied to `_numpy_known_constants`; used as `_kc` by
# `_JaxPrinter`.
_jax_known_constants = {
k: v.replace("numpy.", "jnp.") for k, v in _numpy_known_constants.items()
}


class _JaxPrinter(NumPyPrinter):  # pylint: disable=abstract-method
    """Variant of `NumPyPrinter` that emits code in the ``jnp`` namespace."""

    # pylint: disable=invalid-name
    _module = "jnp"
    module_imports = {"jax": {"numpy as jnp"}}
    _kf = _jax_known_functions
    _kc = _jax_known_constants

    def _print_ComplexSqrt(self, expr: sp.Expr) -> str:  # noqa: N802
        """Print a ``ComplexSqrt`` node as a ``jnp.select`` call.

        The generated code yields ``1j * sqrt(-x)`` where the argument is
        negative and ``sqrt(x)`` everywhere else.
        """
        arg = self._print(expr.args[0])
        conditions = f"[jnp.less({arg}, 0), True]"
        choices = f"[1j * jnp.sqrt(-{arg}), jnp.sqrt({arg})]"
        return f"jnp.select({conditions}, {choices}, default=jnp.nan)"


def split_expression(
    expression: sp.Expr,
    max_complexity: int,
    min_complexity: int = 0,
) -> Tuple[sp.Expr, Dict[sp.Symbol, sp.Expr]]:
    """Break an expression into a 'top expression' and sub-expressions.

    Walk the expression tree of a `sympy.Expr <sympy.core.expr.Expr>` and
    substitute each node whose operation count (see
    :meth:`~sympy.core.basic.Basic.count_ops`) lies strictly between
    ``min_complexity`` and ``max_complexity`` with a generated symbol
    (``f0``, ``f1``, ...). The returned mapping links each generated symbol
    to the sub-expression that it replaced. An expression with fewer
    operations than ``max_complexity`` is returned unchanged, together with
    an empty mapping.

    .. seealso:: :doc:`/usage/faster-lambdify`
    """
    symbol_mapping: Dict[sp.Symbol, sp.Expr] = {}
    n_operations = sp.count_ops(expression)
    if n_operations < max_complexity:
        return expression, symbol_mapping

    progress_bar = tqdm(
        total=n_operations,
        desc="Splitting expression",
        unit="node",
        disable=not _use_progress_bar(),
    )
    n_symbols = 0

    def recursive_split(node: sp.Expr) -> sp.Expr:
        nonlocal n_symbols
        # Iterate over the arguments of the node as it was on entry; `node`
        # itself is rebound after every substitution.
        for child in node.args:
            n_child_ops = sp.count_ops(child)
            if not (min_complexity < n_child_ops < max_complexity):
                # Outside the complexity window: descend into the child.
                replacement = recursive_split(child)
            else:
                progress_bar.update(n=n_child_ops)
                replacement = sp.Symbol(f"f{n_symbols}")
                n_symbols += 1
                symbol_mapping[replacement] = child
            node = node.xreplace({child: replacement})
        return node

    top_expression = recursive_split(expression)
    # Push the bar to 100%; pylint crashes if total is set directly.
    progress_bar.update(n=progress_bar.total - progress_bar.n)
    progress_bar.close()
    return top_expression, symbol_mapping


def create_parametrized_function(
expression: sp.Expr,
Expand Down Expand Up @@ -131,6 +59,58 @@ def create_parametrized_function(
)


def lambdify(
    expression: sp.Expr,
    symbols: Sequence[sp.Symbol],
    backend: Union[str, tuple, dict],
    **kwargs: Any,
) -> Callable:
    """Lambdify a SymPy expression for a given computational back-end.

    A wrapper around :func:`~sympy.utilities.lambdify.lambdify` that selects
    a strategy from ``backend``:

    - ``"jax"``: lambdify with the JAX printer, then JIT-compile the result.
    - ``"numba"``: lambdify with ``modules="numpy"``, then JIT-compile.
    - ``"tensorflow"`` / ``"tf"``: lambdify with
      ``tensorflow.experimental.numpy`` as module.
    - a tuple of modules: the same strategies, chosen by looking for those
      names in each module's ``__name__`` (checked in the order above).
    - anything else: plain ``sympy.lambdify`` with the modules returned by
      ``get_backend_modules``.

    Additional keyword arguments are forwarded to ``sympy.lambdify``.
    """
    # pylint: disable=import-outside-toplevel, too-many-return-statements
    modules = get_backend_modules(backend)

    def via_jax() -> Callable:
        compile_with_jax = jit_compile(backend="jax")
        return compile_with_jax(
            sp.lambdify(
                symbols,
                expression,
                modules=modules,
                printer=_JaxPrinter,
                **kwargs,
            )
        )

    def via_numba() -> Callable:
        compile_with_numba = jit_compile(backend="numba")
        return compile_with_numba(
            sp.lambdify(symbols, expression, modules="numpy", **kwargs)
        )

    def via_tensorflow() -> Callable:
        # pylint: disable=import-error
        import tensorflow.experimental.numpy as tnp  # pyright: reportMissingImports=false

        return sp.lambdify(symbols, expression, modules=tnp, **kwargs)

    def any_module_named(*fragments: str) -> bool:
        # True if some entry of the backend tuple has one of the fragments
        # in its ``__name__``.
        return any(
            any(fragment in module.__name__ for fragment in fragments)
            for module in backend
        )

    if isinstance(backend, str):
        strategies = {
            "jax": via_jax,
            "numba": via_numba,
            "tensorflow": via_tensorflow,
            "tf": via_tensorflow,
        }
        strategy = strategies.get(backend)
        if strategy is not None:
            return strategy()
    elif isinstance(backend, tuple):
        if any_module_named("jax"):
            return via_jax()
        if any_module_named("numba"):
            return via_numba()
        if any_module_named("tensorflow", "tf"):
            return via_tensorflow()

    return sp.lambdify(symbols, expression, modules=modules, **kwargs)


def fast_lambdify(
expression: sp.Expr,
symbols: Sequence[sp.Symbol],
Expand Down Expand Up @@ -167,71 +147,88 @@ def fast_lambdify(
sub_function = lambdify(sub_expression, symbols, backend, **kwargs)
sub_functions.append(sub_function)

@jit_compile(backend) # type: ignore[arg-type]
def recombined_function(*args: Any) -> Any:
new_args = [sub_function(*args) for sub_function in sub_functions]
return top_function(*new_args)

return recombined_function


def lambdify(
def split_expression(
expression: sp.Expr,
symbols: Sequence[sp.Symbol],
backend: Union[str, tuple, dict],
**kwargs: Any,
) -> Callable:
"""A wrapper around :func:`~sympy.utilities.lambdify.lambdify`."""
# pylint: disable=import-outside-toplevel,too-many-return-statements
def jax_lambdify() -> Callable:
import jax
max_complexity: int,
min_complexity: int = 1,
) -> Tuple[sp.Expr, Dict[sp.Symbol, sp.Expr]]:
"""Split an expression into a 'top expression' and several sub-expressions.

return jax.jit(
sp.lambdify(
symbols,
expression,
modules=modules,
printer=_JaxPrinter,
**kwargs,
)
)
Replace nodes in the expression tree of a `sympy.Expr
<sympy.core.expr.Expr>` that lie within a certain complexity range (see
:meth:`~sympy.core.basic.Basic.count_ops`) with symbols and keep a mapping
of each of these symbols to the sub-expression that it replaced.

def numba_lambdify() -> Callable:
# pylint: disable=import-error
import numba
.. seealso:: :doc:`/usage/faster-lambdify`
"""
i = 0
symbol_mapping: Dict[sp.Symbol, sp.Expr] = {}
n_operations = sp.count_ops(expression)
if max_complexity <= 0 or n_operations < max_complexity:
return expression, symbol_mapping
progress_bar = tqdm(
total=n_operations,
desc="Splitting expression",
unit="node",
disable=not _use_progress_bar(),
)

return numba.jit(
sp.lambdify(symbols, expression, modules="numpy", **kwargs),
forceobj=True,
parallel=True,
)
def recursive_split(sub_expression: sp.Expr) -> sp.Expr:
nonlocal i
for arg in sub_expression.args:
complexity = sp.count_ops(arg)
if min_complexity <= complexity <= max_complexity:
progress_bar.update(n=complexity)
symbol = sp.Symbol(f"f{i}")
i += 1
symbol_mapping[symbol] = arg
sub_expression = sub_expression.xreplace({arg: symbol})
else:
new_arg = recursive_split(arg)
sub_expression = sub_expression.xreplace({arg: new_arg})
return sub_expression

def tensorflow_lambdify() -> Callable:
# pylint: disable=import-error
import tensorflow.experimental.numpy as tnp # pyright: reportMissingImports=false
top_expression = recursive_split(expression)
remaining_symbols = top_expression.free_symbols - set(symbol_mapping)
symbol_mapping.update({s: s for s in remaining_symbols})
remainder = progress_bar.total - progress_bar.n
progress_bar.update(n=remainder) # pylint crashes if total is set directly
progress_bar.close()
return top_expression, symbol_mapping

return sp.lambdify(symbols, expression, modules=tnp, **kwargs)

modules = get_backend_modules(backend)
if isinstance(backend, str):
if backend == "jax":
return jax_lambdify()
if backend == "numba":
return numba_lambdify()
if backend in {"tensorflow", "tf"}:
return tensorflow_lambdify()
def _use_progress_bar() -> bool:
    """Return `True` if the root logger's level is ``WARNING`` or lower."""
    root_level = logging.getLogger().level
    return root_level <= logging.WARNING

if isinstance(backend, tuple):
if any("jax" in x.__name__ for x in backend):
return jax_lambdify()
if any("numba" in x.__name__ for x in backend):
return numba_lambdify()
if any(
"tensorflow" in x.__name__ or "tf" in x.__name__ for x in backend
):
return tensorflow_lambdify()

return sp.lambdify(symbols, expression, modules=modules, **kwargs)
# Copy of `_numpy_known_functions` in which every "numpy." prefix in the
# printed name is replaced by "jnp."; used as `_kf` by `_JaxPrinter`.
_jax_known_functions = {
k: v.replace("numpy.", "jnp.") for k, v in _numpy_known_functions.items()
}
# Same substitution applied to `_numpy_known_constants`; used as `_kc` by
# `_JaxPrinter`.
_jax_known_constants = {
k: v.replace("numpy.", "jnp.") for k, v in _numpy_known_constants.items()
}


def _use_progress_bar() -> bool:
    """Return `True` if the root logger's level is ``WARNING`` or lower."""
    root_level = logging.getLogger().level
    return root_level <= logging.WARNING
class _JaxPrinter(NumPyPrinter):  # pylint: disable=abstract-method
    """Variant of `NumPyPrinter` that emits code in the ``jnp`` namespace."""

    # pylint: disable=invalid-name
    _module = "jnp"
    module_imports = {"jax": {"numpy as jnp"}}
    _kf = _jax_known_functions
    _kc = _jax_known_constants

    def _print_ComplexSqrt(self, expr: sp.Expr) -> str:  # noqa: N802
        """Print a ``ComplexSqrt`` node as a ``jnp.select`` call.

        The generated code yields ``1j * sqrt(-x)`` where the argument is
        negative and ``sqrt(x)`` everywhere else.
        """
        arg = self._print(expr.args[0])
        conditions = f"[jnp.less({arg}, 0), True]"
        choices = f"[1j * jnp.sqrt(-{arg}), jnp.sqrt({arg})]"
        return f"jnp.select({conditions}, {choices}, default=jnp.nan)"
Loading