Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove deprecated partial eval functions (prep work for stackless) #23768

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 0 additions & 60 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,43 +730,6 @@ def get_referent(self):
return self


@profiler.annotate_function
def trace_to_jaxpr(
fun: lu.WrappedFun, pvals: Sequence[PartialVal],
instantiate: bool | Sequence[bool] = False,
) -> tuple[Jaxpr, list[PartialVal], list[core.Value]]:
"""
Partially evaluate a function, building a jaxpr for un-evaluated computation.

Args:
fun: lu.WrappedFun representing the function to be partially evaluated. The
function must be flattened, in the sense of accepting jaxpr type arguments
and returning a flat list of jaxpr type outputs.
pvals: sequence of PartialVals of length equal to the number of inputs to
`fun` indicating which inputs are known or unknown.
instantiate: optional bool or sequence of bools of length equal to the
number of outputs of `fun` indicating which outputs should be forced to be
treated as unknown and hence instantiated in the jaxpr. If a single bool,
the value is applied to all outputs. Default False.

Returns:
A triple where the first element is a jaxpr representing the computation
which depends on unknown inputs; the second element is a list of PartialVals
of length equal to the length of the output of `fun` representing which
outputs are known and unknown (along with their values and abstract values,
respectively); the third element is a list of known residual values. The
returned jaxpr takes as inputs the known residual values followed by values
of the originally unknown inputs.
"""
current_name_stack = source_info_util.current_name_stack()
with core.new_main(JaxprTrace, name_stack=current_name_stack) as main:
fun = trace_to_subjaxpr(fun, main, instantiate)
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
assert not env
del main, fun, env

return jaxpr, out_pvals, consts

@profiler.annotate_function
def trace_to_jaxpr_nounits(
fun: lu.WrappedFun, pvals: Sequence[PartialVal],
Expand Down Expand Up @@ -2784,29 +2747,6 @@ def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params):
return prim.bind(*subfuns, *args, **bind_params)


# TODO(mattjj): the following are deprecated; update callers to _nounits version
# See https://github.com/google/jax/pull/9498
@lu.transformation
def trace_to_subjaxpr(main: core.MainTrace, instantiate: bool | Sequence[bool],
pvals: Sequence[PartialVal]):
assert all(isinstance(pv, PartialVal) for pv in pvals), pvals
trace = main.with_cur_sublevel()
in_tracers = map(trace.new_arg, pvals)
ans = yield in_tracers, {}
assert isinstance(ans, (list, tuple)), (
f"Got unexpected return type when tracing function to jaxpr: {ans}")
assert all(isinstance(x, core.Tracer) or core.valid_jaxtype(x) for x in ans), (
f"Got unexpected return type when tracing function to jaxpr: {ans}")
instantiate = [instantiate] * len(ans) if isinstance(instantiate, bool) else instantiate
out_tracers = map(trace.full_raise, map(core.full_lower, ans))
out_tracers = map(partial(instantiate_const_at, trace), instantiate, out_tracers)
jaxpr, consts, env = tracers_to_jaxpr(in_tracers, out_tracers)
out_pvals = [t.pval for t in out_tracers]
del trace, in_tracers, out_tracers
yield jaxpr, (out_pvals, consts, env)

partial_eval_jaxpr: Callable

def instantiate_const_at(trace: JaxprTrace, instantiate: bool, tracer):
if instantiate:
return trace.instantiate_const(trace.full_raise(tracer))
Expand Down
2 changes: 0 additions & 2 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,11 @@
recipe_to_eqn as recipe_to_eqn,
result_info as result_info,
sig_info as sig_info,
trace_to_jaxpr as trace_to_jaxpr,
trace_to_jaxpr_dynamic as _trace_to_jaxpr_dynamic,
trace_to_jaxpr_dynamic2 as trace_to_jaxpr_dynamic2,
trace_to_jaxpr_final as trace_to_jaxpr_final,
trace_to_jaxpr_final2 as trace_to_jaxpr_final2,
trace_to_jaxpr_nounits as trace_to_jaxpr_nounits,
trace_to_subjaxpr as trace_to_subjaxpr,
trace_to_subjaxpr_dynamic as trace_to_subjaxpr_dynamic,
trace_to_subjaxpr_dynamic2 as trace_to_subjaxpr_dynamic2,
trace_to_subjaxpr_nounits as trace_to_subjaxpr_nounits,
Expand Down