diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index 0055a79ea6fa..4e113872b7cd 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -568,13 +568,3 @@ def __missing__(self, key): backend_specific_translations: Dict[str, _TranslationRuleAdapter] backend_specific_translations = _BackendSpecificTranslationsAdapter() - -# TODO(phawkins): remove lower_fun completely after updating users. -def lower_fun(fun: Callable, *, multiple_results: bool, backend=None, - new_style: bool = False) -> Callable: - def f(*args, **kw): - raise RuntimeError("XLA translation rules are deprecated and " - "jax.interpreters.xla.lower_fun is no longer supported. " - "Add an MLIR lowering via jax.interpreters.mlir " - "instead.") - return f diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index e6fd5b394357..0312df7017e4 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -40,7 +40,6 @@ extend_axis_env as extend_axis_env, extend_name_stack as extend_name_stack, jaxpr_collectives as jaxpr_collectives, - lower_fun as lower_fun, make_device_array as make_device_array, make_op_metadata as make_op_metadata, new_name_stack as new_name_stack,