Skip to content

Commit

Permalink
style: improve function order definition in sympy module
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Dec 1, 2021
1 parent 6f20845 commit efb2026
Showing 1 changed file with 118 additions and 117 deletions.
235 changes: 118 additions & 117 deletions src/tensorwaves/function/sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,80 +26,6 @@

from . import ParametrizedBackendFunction

_jax_known_functions = {
k: v.replace("numpy.", "jnp.") for k, v in _numpy_known_functions.items()
}
_jax_known_constants = {
k: v.replace("numpy.", "jnp.") for k, v in _numpy_known_constants.items()
}


class _JaxPrinter(NumPyPrinter): # pylint: disable=abstract-method
# pylint: disable=invalid-name
module_imports = {"jax": {"numpy as jnp"}}
_module = "jnp"
_kc = _jax_known_constants
_kf = _jax_known_functions

def _print_ComplexSqrt(self, expr: sp.Expr) -> str: # noqa: N802
x = self._print(expr.args[0])
return (
"jnp.select("
f"[jnp.less({x}, 0), True], "
f"[1j * jnp.sqrt(-{x}), jnp.sqrt({x})], "
"default=jnp.nan)"
)


def split_expression(
expression: sp.Expr,
max_complexity: int,
min_complexity: int = 1,
) -> Tuple[sp.Expr, Dict[sp.Symbol, sp.Expr]]:
"""Split an expression into a 'top expression' and several sub-expressions.
Replace nodes in the expression tree of a `sympy.Expr
<sympy.core.expr.Expr>` that lie within a certain complexity range (see
:meth:`~sympy.core.basic.Basic.count_ops`) with symbols and keep a mapping
of each to these symbols to the sub-expressions that they replaced.
.. seealso:: :doc:`/usage/faster-lambdify`
"""
i = 0
symbol_mapping: Dict[sp.Symbol, sp.Expr] = {}
n_operations = sp.count_ops(expression)
if max_complexity <= 0 or n_operations < max_complexity:
return expression, symbol_mapping
progress_bar = tqdm(
total=n_operations,
desc="Splitting expression",
unit="node",
disable=not _use_progress_bar(),
)

def recursive_split(sub_expression: sp.Expr) -> sp.Expr:
nonlocal i
for arg in sub_expression.args:
complexity = sp.count_ops(arg)
if min_complexity <= complexity <= max_complexity:
progress_bar.update(n=complexity)
symbol = sp.Symbol(f"f{i}")
i += 1
symbol_mapping[symbol] = arg
sub_expression = sub_expression.xreplace({arg: symbol})
else:
new_arg = recursive_split(arg)
sub_expression = sub_expression.xreplace({arg: new_arg})
return sub_expression

top_expression = recursive_split(expression)
remaining_symbols = top_expression.free_symbols - set(symbol_mapping)
symbol_mapping.update({s: s for s in remaining_symbols})
remainder = progress_bar.total - progress_bar.n
progress_bar.update(n=remainder) # pylint crashes if total is set directly
progress_bar.close()
return top_expression, symbol_mapping


def create_parametrized_function(
expression: sp.Expr,
Expand Down Expand Up @@ -133,49 +59,6 @@ def create_parametrized_function(
)


def fast_lambdify(
expression: sp.Expr,
symbols: Sequence[sp.Symbol],
backend: Union[str, tuple, dict],
*,
min_complexity: int = 0,
max_complexity: int,
**kwargs: Any,
) -> Callable:
"""Speed up :func:`.lambdify` with :func:`.split_expression`.
.. seealso:: :doc:`/usage/faster-lambdify`
"""
top_expression, sub_expressions = split_expression(
expression,
min_complexity=min_complexity,
max_complexity=max_complexity,
)
if not sub_expressions:
return lambdify(top_expression, symbols, backend, **kwargs)

sorted_top_symbols = sorted(sub_expressions, key=lambda s: s.name)
top_function = lambdify(
top_expression, sorted_top_symbols, backend, **kwargs
)
sub_functions: List[Callable] = []
for symbol in tqdm(
iterable=sorted_top_symbols,
desc="Lambdifying sub-expressions",
unit="expr",
disable=not _use_progress_bar(),
):
sub_expression = sub_expressions[symbol]
sub_function = lambdify(sub_expression, symbols, backend, **kwargs)
sub_functions.append(sub_function)

def recombined_function(*args: Any) -> Any:
new_args = [sub_function(*args) for sub_function in sub_functions]
return top_function(*new_args)

return recombined_function


def lambdify(
expression: sp.Expr,
symbols: Sequence[sp.Symbol],
Expand Down Expand Up @@ -230,5 +113,123 @@ def tensorflow_lambdify() -> Callable:
return sp.lambdify(symbols, expression, modules=modules, **kwargs)


def fast_lambdify(
expression: sp.Expr,
symbols: Sequence[sp.Symbol],
backend: Union[str, tuple, dict],
*,
min_complexity: int = 0,
max_complexity: int,
**kwargs: Any,
) -> Callable:
"""Speed up :func:`.lambdify` with :func:`.split_expression`.
.. seealso:: :doc:`/usage/faster-lambdify`
"""
top_expression, sub_expressions = split_expression(
expression,
min_complexity=min_complexity,
max_complexity=max_complexity,
)
if not sub_expressions:
return lambdify(top_expression, symbols, backend, **kwargs)

sorted_top_symbols = sorted(sub_expressions, key=lambda s: s.name)
top_function = lambdify(
top_expression, sorted_top_symbols, backend, **kwargs
)
sub_functions: List[Callable] = []
for symbol in tqdm(
iterable=sorted_top_symbols,
desc="Lambdifying sub-expressions",
unit="expr",
disable=not _use_progress_bar(),
):
sub_expression = sub_expressions[symbol]
sub_function = lambdify(sub_expression, symbols, backend, **kwargs)
sub_functions.append(sub_function)

def recombined_function(*args: Any) -> Any:
new_args = [sub_function(*args) for sub_function in sub_functions]
return top_function(*new_args)

return recombined_function


def split_expression(
expression: sp.Expr,
max_complexity: int,
min_complexity: int = 1,
) -> Tuple[sp.Expr, Dict[sp.Symbol, sp.Expr]]:
"""Split an expression into a 'top expression' and several sub-expressions.
Replace nodes in the expression tree of a `sympy.Expr
<sympy.core.expr.Expr>` that lie within a certain complexity range (see
:meth:`~sympy.core.basic.Basic.count_ops`) with symbols and keep a mapping
of each to these symbols to the sub-expressions that they replaced.
.. seealso:: :doc:`/usage/faster-lambdify`
"""
i = 0
symbol_mapping: Dict[sp.Symbol, sp.Expr] = {}
n_operations = sp.count_ops(expression)
if max_complexity <= 0 or n_operations < max_complexity:
return expression, symbol_mapping
progress_bar = tqdm(
total=n_operations,
desc="Splitting expression",
unit="node",
disable=not _use_progress_bar(),
)

def recursive_split(sub_expression: sp.Expr) -> sp.Expr:
nonlocal i
for arg in sub_expression.args:
complexity = sp.count_ops(arg)
if min_complexity <= complexity <= max_complexity:
progress_bar.update(n=complexity)
symbol = sp.Symbol(f"f{i}")
i += 1
symbol_mapping[symbol] = arg
sub_expression = sub_expression.xreplace({arg: symbol})
else:
new_arg = recursive_split(arg)
sub_expression = sub_expression.xreplace({arg: new_arg})
return sub_expression

top_expression = recursive_split(expression)
remaining_symbols = top_expression.free_symbols - set(symbol_mapping)
symbol_mapping.update({s: s for s in remaining_symbols})
remainder = progress_bar.total - progress_bar.n
progress_bar.update(n=remainder) # pylint crashes if total is set directly
progress_bar.close()
return top_expression, symbol_mapping


def _use_progress_bar() -> bool:
return logging.getLogger().level <= logging.WARNING


_jax_known_functions = {
k: v.replace("numpy.", "jnp.") for k, v in _numpy_known_functions.items()
}
_jax_known_constants = {
k: v.replace("numpy.", "jnp.") for k, v in _numpy_known_constants.items()
}


class _JaxPrinter(NumPyPrinter): # pylint: disable=abstract-method
# pylint: disable=invalid-name
module_imports = {"jax": {"numpy as jnp"}}
_module = "jnp"
_kc = _jax_known_constants
_kf = _jax_known_functions

def _print_ComplexSqrt(self, expr: sp.Expr) -> str: # noqa: N802
x = self._print(expr.args[0])
return (
"jnp.select("
f"[jnp.less({x}, 0), True], "
f"[1j * jnp.sqrt(-{x}), jnp.sqrt({x})], "
"default=jnp.nan)"
)

0 comments on commit efb2026

Please sign in to comment.