Skip to content

Commit

Permalink
fix: jit-compile recombined fast_lambdify function
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Dec 1, 2021
1 parent b920c77 commit 80e542c
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 17 deletions.
10 changes: 6 additions & 4 deletions src/tensorwaves/_backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Computational back-end handling."""


from functools import partial
from typing import Callable, Union


Expand Down Expand Up @@ -49,17 +51,17 @@ def get_backend_modules(
return backend


def jit_compile(function: Callable, backend: str) -> Callable:
def jit_compile(backend: str) -> Callable:
# pylint: disable=import-outside-toplevel
backend = backend.lower()
if backend == "jax":
import jax

return jax.jit(function)
return jax.jit

if backend == "numba":
import numba

return numba.jit(function, forceobj=True, parallel=True)
return partial(numba.jit, forceobj=True, parallel=True)

raise NotImplementedError(f"Cannot JIT-compile with backend {backend}")
return lambda x: x
11 changes: 5 additions & 6 deletions src/tensorwaves/function/sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,21 +68,19 @@ def lambdify(
"""A wrapper around :func:`~sympy.utilities.lambdify.lambdify`."""
# pylint: disable=import-outside-toplevel, too-many-return-statements
def jax_lambdify() -> Callable:
return jit_compile(
return jit_compile(backend="jax")(
sp.lambdify(
symbols,
expression,
modules=modules,
printer=_JaxPrinter,
**kwargs,
),
backend="jax",
)
)

def numba_lambdify() -> Callable:
return jit_compile(
sp.lambdify(symbols, expression, modules="numpy", **kwargs),
backend="numba",
return jit_compile(backend="numba")(
sp.lambdify(symbols, expression, modules="numpy", **kwargs)
)

def tensorflow_lambdify() -> Callable:
Expand Down Expand Up @@ -149,6 +147,7 @@ def fast_lambdify(
sub_function = lambdify(sub_expression, symbols, backend, **kwargs)
sub_functions.append(sub_function)

@jit_compile(backend) # type: ignore[arg-type]
def recombined_function(*args: Any) -> Any:
new_args = [sub_function(*args) for sub_function in sub_functions]
return top_function(*new_args)
Expand Down
14 changes: 7 additions & 7 deletions tests/unit/function/test_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ def test_fast_lambdify(backend: str, max_complexity: int):

func_repr = str(function)
if 0 < max_complexity <= 4:
assert func_repr.startswith("<function fast_lambdify.<locals>")
repr_start = "<function fast_lambdify.<locals>"
else:
repr_start = "<function _lambdifygenerated"
if backend == "jax":
if sys.version_info >= (3, 7):
repr_start = "<CompiledFunction of " + repr_start
else:
repr_start = "<CompiledFunction object at 0x"
assert func_repr.startswith(repr_start)
if backend == "jax":
if sys.version_info >= (3, 7):
repr_start = "<CompiledFunction of " + repr_start
else:
repr_start = "<CompiledFunction object at 0x"
assert func_repr.startswith(repr_start)

data = (
4,
Expand Down

0 comments on commit 80e542c

Please sign in to comment.