Skip to content

Commit

Permalink
Make JAX backend import lazy. (#135)
Browse files Browse the repository at this point in the history
* Make JAX backend import lazy.

We observed the following surprising behavior: importing tensorflow also imports jax, if installed. This happens because tensorflow has an opt_einsum dependency that is imported eagerly, and the opt_einsum jax backend eagerly imports jax. To avoid this, make the jax import from opt_einsum lazy.

* Update miniconda md5 path to fix CI builds.
  • Loading branch information
hawkinsp authored Apr 15, 2020
1 parent bc72874 commit 7756811
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
2 changes: 1 addition & 1 deletion devtools/travis-ci/before_install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ else
MINICONDA=Miniconda3-latest-Linux-x86_64.sh
fi
MINICONDA_HOME=$HOME/miniconda
MINICONDA_MD5=$(curl -s https://repo.continuum.io/miniconda/ | grep -A3 $MINICONDA | sed -n '4p' | sed -n 's/ *<td>\(.*\)<\/td> */\1/p')
MINICONDA_MD5=$(wget -qO- https://repo.continuum.io/miniconda/ | grep -A3 $MINICONDA | sed -n '4p' | sed -n 's/ *<td>\(.*\)<\/td> */\1/p')
wget -q https://repo.continuum.io/miniconda/$MINICONDA
if [[ $MINICONDA_MD5 != $(md5sum $MINICONDA | cut -d ' ' -f 1) ]]; then
echo "Miniconda MD5 mismatch"
Expand Down
17 changes: 13 additions & 4 deletions opt_einsum/backends/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,29 @@

__all__ = ["build_expression", "evaluate_constants"]

try:

_JAX = None


def _get_jax_and_to_jax():
global _JAX
if _JAX is None:
import jax

@to_backend_cache_wrap
@jax.jit
def to_jax(x):
return x

except ImportError:
pass
_JAX = jax, to_jax

return _JAX


def build_expression(_, expr): # pragma: no cover
"""Build a jax function based on ``arrays`` and ``expr``.
"""
import jax
jax, _ = _get_jax_and_to_jax()

jax_expr = jax.jit(expr._contract)

Expand All @@ -37,4 +44,6 @@ def evaluate_constants(const_arrays, expr): # pragma: no cover
"""Convert constant arguments to jax arrays, and perform any possible
constant contractions.
"""
jax, to_jax = _get_jax_and_to_jax()

return expr(*[to_jax(x) for x in const_arrays], backend='jax', evaluate_constants=True)

0 comments on commit 7756811

Please sign in to comment.