fix for invalid caching of CodeInfo from typeinf_ext (#51872)
When inference decided it was not necessary to cache the object, it also
skipped all of the work required to make the code valid, which
typeinf_ext depends upon. This resulted in caching invalid data, causing
effects tests to break unpredictably. This change ensures that we always
update the InferenceState with the final result (when `must_be_codeinf`
is true), so that typeinf_ext can get the correct results out of it for
internal codegen use. Previously we were disregarding that flag in some
cases.

Fixes one of the issues uncovered in #51860
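
To make the failure mode concrete, here is an informal Julia sketch (not code from this commit; `frame` stands for the inference state and the helper name is hypothetical) of what a typeinf_ext-style consumer assumes once inference returns with `must_be_codeinf` set:

    # Illustrative sketch only: codegen needs a CodeInfo, not an OptimizationState or
    # nothing, and its min_world/max_world must already reflect the inferred validity range.
    function usable_for_codegen(frame)
        return frame.src isa Core.CodeInfo
    end

Before this change, the non-cached path could leave the frame holding an unfinalized object with unset world bounds, which is what produced the invalid cached data described above.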
vtjnash authored Oct 25, 2023
1 parent d0c4284 commit bde62ad
Showing 6 changed files with 40 additions and 39 deletions.
61 changes: 25 additions & 36 deletions base/compiler/typeinfer.jl
@@ -218,28 +218,28 @@ function typeinf(interp::NativeInterpreter, frame::InferenceState)
 end
 typeinf(interp::AbstractInterpreter, frame::InferenceState) = _typeinf(interp, frame)
 
-function finish!(interp::AbstractInterpreter, caller::InferenceResult)
-    # If we didn't transform the src for caching, we may have to transform
-    # it anyway for users like typeinf_ext. Do that here.
-    opt = caller.src
-    if opt isa OptimizationState{typeof(interp)} # implies `may_optimize(interp) === true`
-        if opt.ir !== nothing
-            if caller.must_be_codeinf
-                caller.src = ir_to_codeinf!(opt)
-            elseif is_inlineable(opt.src)
-                # TODO: If the CFG is too big, inlining becomes more expensive and if we're going to
-                # use this IR over and over, it's worth simplifying it. Round trips through
-                # CodeInstance do this implicitly, since they recompute the CFG, so try to
-                # match that behavior here.
-                # ir = cfg_simplify!(opt.ir)
-                caller.src = opt.ir
-            else
-                # Not cached and not inlineable - drop the ir
-                caller.src = nothing
-            end
-        end
+function finish!(interp::AbstractInterpreter, caller::InferenceState)
+    result = caller.result
+    valid_worlds = result.valid_worlds
+    if last(valid_worlds) >= get_world_counter()
+        # if we aren't cached, we don't need this edge
+        # but our caller might, so let's just make it anyways
+        store_backedges(result, caller.stmt_edges[1])
+    end
+    opt = result.src
+    if opt isa OptimizationState && result.must_be_codeinf
+        result.src = opt = ir_to_codeinf!(opt)
     end
-    return caller.src
+    if opt isa CodeInfo
+        opt.min_world = first(valid_worlds)
+        opt.max_world = last(valid_worlds)
+        caller.src = opt
+    else
+        # In this case caller.src is invalid for clients (such as typeinf_ext) to use
+        # but that is what !must_be_codeinf permits
+        # This is hopefully unreachable when must_be_codeinf is true
+    end
+    return
 end
 
 function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
@@ -266,17 +266,12 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
         end
     end
     for caller in frames
-        (; result ) = caller
-        valid_worlds = result.valid_worlds
-        if last(valid_worlds) >= get_world_counter()
-            # if we aren't cached, we don't need this edge
-            # but our caller might, so let's just make it anyways
-            store_backedges(result, caller.stmt_edges[1])
-        end
+        finish!(caller.interp, caller)
         if caller.cached
-            cache_result!(caller.interp, result)
+            cache_result!(caller.interp, caller.result)
         end
-        finish!(caller.interp, result)
+        # n.b. We do not drop result.src here, even though that wastes memory while it is still in the local cache
+        # since the user might have requested call-site inlining of it.
     end
     empty!(frames)
     return true
@@ -367,13 +362,7 @@ end
 function transform_result_for_cache(interp::AbstractInterpreter,
     linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult)
     inferred_result = result.src
-    if inferred_result isa OptimizationState{typeof(interp)}
-        # TODO respect must_be_codeinf setting here?
-        result.src = inferred_result = ir_to_codeinf!(inferred_result)
-    end
     if inferred_result isa CodeInfo
-        inferred_result.min_world = first(valid_worlds)
-        inferred_result.max_world = last(valid_worlds)
         inferred_result = maybe_compress_codeinfo(interp, linfo, inferred_result)
     end
     # The global cache can only handle objects that codegen understands
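Taken together, the typeinfer.jl changes above amount to the following ordering, shown here as a condensed, illustrative restatement (not the verbatim implementation) of the loop in `_typeinf`:

    # Illustrative only: finish! always finalizes the frame's result (backedges,
    # OptimizationState -> CodeInfo when must_be_codeinf, world bounds) before the
    # result is optionally published to the global cache.
    for caller in frames
        finish!(caller.interp, caller)
        caller.cached && cache_result!(caller.interp, caller.result)
    end

Because `finish!` now does this finalization unconditionally, `transform_result_for_cache` no longer needs to convert the source or stamp world bounds itself.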
2 changes: 1 addition & 1 deletion src/aotcompile.cpp
@@ -246,7 +246,7 @@ static void jl_ci_cache_lookup(const jl_cgparams_t &cgparams, jl_method_instance
     else {
         *src_out = jl_type_infer(mi, world, 0);
         if (*src_out) {
-            codeinst = jl_get_method_inferred(mi, (*src_out)->rettype, (*src_out)->min_world, (*src_out)->max_world);
+            codeinst = jl_get_codeinst_for_src(mi, *src_out);
            if ((*src_out)->inferred) {
                 jl_value_t *null = nullptr;
                 jl_atomic_cmpswap_relaxed(&codeinst->inferred, &null, jl_nothing);
10 changes: 10 additions & 0 deletions src/gf.c
@@ -472,6 +472,16 @@ JL_DLLEXPORT jl_code_instance_t *jl_get_method_inferred(
     return codeinst;
 }
 
+JL_DLLEXPORT jl_code_instance_t *jl_get_codeinst_for_src(
+        jl_method_instance_t *mi JL_PROPAGATES_ROOT, jl_code_info_t *src)
+{
+    // TODO: copy backedges from src to mi
+    size_t max_world = src->max_world;
+    if (max_world >= jl_atomic_load_acquire(&jl_world_counter))
+        max_world = ~(size_t)0;
+    return jl_get_method_inferred(mi, src->rettype, src->min_world, max_world);
+}
+
 JL_DLLEXPORT jl_code_instance_t *jl_new_codeinst(
         jl_method_instance_t *mi, jl_value_t *rettype,
         jl_value_t *inferred_const, jl_value_t *inferred,
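The world clamping in `jl_get_codeinst_for_src` reads roughly as follows; this is an illustrative Julia rendering of the C logic above (with `typemax(UInt)` standing in for `~(size_t)0`; the helper name is hypothetical, not part of the source):

    # Illustrative only: a source still valid in the current world is cached with an
    # open-ended max_world; an already-outdated source keeps its bounded range.
    function clamped_max_world(src_max_world::UInt, world_counter::UInt)
        return src_max_world >= world_counter ? typemax(UInt) : src_max_world
    end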
3 changes: 2 additions & 1 deletion src/jitlayers.cpp
@@ -507,6 +507,7 @@ jl_code_instance_t *jl_generate_fptr_impl(jl_method_instance_t *mi JL_PROPAGATES
             // see if it is inferred, or try to infer it for ourself.
             // (but don't bother with typeinf on macros or toplevel thunks)
             src = jl_type_infer(mi, world, 0);
+            codeinst = nullptr;
         }
     }
     jl_code_instance_t *compiled = jl_method_compiled(mi, world);
@@ -515,7 +516,7 @@ jl_code_instance_t *jl_generate_fptr_impl(jl_method_instance_t *mi JL_PROPAGATES
     }
     else if (src && jl_is_code_info(src)) {
         if (!codeinst) {
-            codeinst = jl_get_method_inferred(mi, src->rettype, src->min_world, src->max_world);
+            codeinst = jl_get_codeinst_for_src(mi, src);
             if (src->inferred) {
                 jl_value_t *null = nullptr;
                 jl_atomic_cmpswap_relaxed(&codeinst->inferred, &null, jl_nothing);
1 change: 0 additions & 1 deletion src/jl_exported_funcs.inc
@@ -226,7 +226,6 @@
     XX(jl_get_JIT) \
     XX(jl_get_julia_bin) \
     XX(jl_get_julia_bindir) \
-    XX(jl_get_method_inferred) \
     XX(jl_get_module_compile) \
     XX(jl_get_module_infer) \
     XX(jl_get_module_of_binding) \
2 changes: 2 additions & 0 deletions src/julia_internal.h
@@ -635,6 +635,8 @@ JL_DLLEXPORT jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t
 JL_DLLEXPORT jl_code_instance_t *jl_get_method_inferred(
         jl_method_instance_t *mi JL_PROPAGATES_ROOT, jl_value_t *rettype,
         size_t min_world, size_t max_world);
+JL_DLLEXPORT jl_code_instance_t *jl_get_codeinst_for_src(
+        jl_method_instance_t *mi JL_PROPAGATES_ROOT, jl_code_info_t *src);
 jl_method_instance_t *jl_get_unspecialized_from_mi(jl_method_instance_t *method JL_PROPAGATES_ROOT);
 jl_method_instance_t *jl_get_unspecialized(jl_method_t *def JL_PROPAGATES_ROOT);
 
