Skip to content

Commit

Permalink
Remove jax.interpreters.xla.lower_fun.
Browse files Browse the repository at this point in the history
This function has been a stub that does nothing useful for a long time, and the only user I can find is Equinox which already guards this with a hasattr(xla, 'lower_fun') guard.

PiperOrigin-RevId: 510142446
  • Loading branch information
hawkinsp authored and jax authors committed Feb 16, 2023
1 parent a9e886f commit c6a99b6
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 11 deletions.
10 changes: 0 additions & 10 deletions jax/_src/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,13 +568,3 @@ def __missing__(self, key):

backend_specific_translations: Dict[str, _TranslationRuleAdapter]
backend_specific_translations = _BackendSpecificTranslationsAdapter()

# TODO(phawkins): remove lower_fun completely after updating users.
def lower_fun(fun: Callable, *, multiple_results: bool, backend=None,
new_style: bool = False) -> Callable:
def f(*args, **kw):
raise RuntimeError("XLA translation rules are deprecated and "
"jax.interpreters.xla.lower_fun is no longer supported. "
"Add an MLIR lowering via jax.interpreters.mlir "
"instead.")
return f
1 change: 0 additions & 1 deletion jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
extend_axis_env as extend_axis_env,
extend_name_stack as extend_name_stack,
jaxpr_collectives as jaxpr_collectives,
lower_fun as lower_fun,
make_device_array as make_device_array,
make_op_metadata as make_op_metadata,
new_name_stack as new_name_stack,
Expand Down

0 comments on commit c6a99b6

Please sign in to comment.