Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix for invalid caching of CodeInfo from typeinf_ext #51872

Merged
merged 1 commit into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 25 additions & 36 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,28 +218,28 @@ function typeinf(interp::NativeInterpreter, frame::InferenceState)
end
typeinf(interp::AbstractInterpreter, frame::InferenceState) = _typeinf(interp, frame)

function finish!(interp::AbstractInterpreter, caller::InferenceResult)
# If we didn't transform the src for caching, we may have to transform
# it anyway for users like typeinf_ext. Do that here.
opt = caller.src
if opt isa OptimizationState{typeof(interp)} # implies `may_optimize(interp) === true`
if opt.ir !== nothing
if caller.must_be_codeinf
caller.src = ir_to_codeinf!(opt)
elseif is_inlineable(opt.src)
# TODO: If the CFG is too big, inlining becomes more expensive and if we're going to
# use this IR over and over, it's worth simplifying it. Round trips through
# CodeInstance do this implicitly, since they recompute the CFG, so try to
# match that behavior here.
# ir = cfg_simplify!(opt.ir)
caller.src = opt.ir
else
# Not cached and not inlineable - drop the ir
caller.src = nothing
end
end
function finish!(interp::AbstractInterpreter, caller::InferenceState)
result = caller.result
valid_worlds = result.valid_worlds
if last(valid_worlds) >= get_world_counter()
# if we aren't cached, we don't need this edge
# but our caller might, so let's just make it anyways
store_backedges(result, caller.stmt_edges[1])
end
opt = result.src
if opt isa OptimizationState && result.must_be_codeinf
result.src = opt = ir_to_codeinf!(opt)
end
return caller.src
if opt isa CodeInfo
opt.min_world = first(valid_worlds)
opt.max_world = last(valid_worlds)
Comment on lines +234 to +235
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, this is the critical change that allows typeinf_ext to always use CodeInfo with correct information valid for caching.

caller.src = opt
else
# In this case caller.src is invalid for clients (such as typeinf_ext) to use
# but that is what !must_be_codeinf permits
# This is hopefully unreachable when must_be_codeinf is true
end
return
end

function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
Expand All @@ -266,17 +266,12 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
end
end
for caller in frames
(; result ) = caller
valid_worlds = result.valid_worlds
if last(valid_worlds) >= get_world_counter()
# if we aren't cached, we don't need this edge
# but our caller might, so let's just make it anyways
store_backedges(result, caller.stmt_edges[1])
end
finish!(caller.interp, caller)
if caller.cached
cache_result!(caller.interp, result)
cache_result!(caller.interp, caller.result)
end
finish!(caller.interp, result)
# n.b. We do not drop result.src here, even though that wastes memory while it is still in the local cache
# since the user might have requested call-site inlining of it.
end
empty!(frames)
return true
Expand Down Expand Up @@ -367,13 +362,7 @@ end
function transform_result_for_cache(interp::AbstractInterpreter,
linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult)
inferred_result = result.src
if inferred_result isa OptimizationState{typeof(interp)}
# TODO respect must_be_codeinf setting here?
result.src = inferred_result = ir_to_codeinf!(inferred_result)
end
if inferred_result isa CodeInfo
inferred_result.min_world = first(valid_worlds)
inferred_result.max_world = last(valid_worlds)
Comment on lines -375 to -376
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's great to see this function now focused purely on code transformation, without also modifying additional information to make the result valid for caching. That part in particular was specific to native compilation, so it wasn't necessarily required for external abstract interpreters.

Copy link
Sponsor Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that was my thinking too. It feels like a bit of a hack to pull the final result out of the InferenceFrame object, but that is what our tools currently expect, so we do need to make sure the final results get populated there. This is still not fully correct: too much of the caching logic is implemented only in Julia. Ideally, all customers would call jl_get_codeinst_for_src and let that function deal with extracting all of the backedges, world ages, and IPO information. Currently that extraction happens only when Julia/inference is allocating a cache entry, but not when C/codegen is choosing to allocate a cache entry.

inferred_result = maybe_compress_codeinfo(interp, linfo, inferred_result)
end
# The global cache can only handle objects that codegen understands
Expand Down
2 changes: 1 addition & 1 deletion src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ static void jl_ci_cache_lookup(const jl_cgparams_t &cgparams, jl_method_instance
else {
*src_out = jl_type_infer(mi, world, 0);
if (*src_out) {
codeinst = jl_get_method_inferred(mi, (*src_out)->rettype, (*src_out)->min_world, (*src_out)->max_world);
codeinst = jl_get_codeinst_for_src(mi, *src_out);
if ((*src_out)->inferred) {
jl_value_t *null = nullptr;
jl_atomic_cmpswap_relaxed(&codeinst->inferred, &null, jl_nothing);
Expand Down
10 changes: 10 additions & 0 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,16 @@ JL_DLLEXPORT jl_code_instance_t *jl_get_method_inferred(
return codeinst;
}

// Obtain the code instance for `mi` that matches the inferred source `src`.
//
// The return type and lower world bound are taken directly from `src`. The
// upper world bound is taken from `src` as well, except that a bound at or
// beyond the current value of `jl_world_counter` is replaced by the largest
// `size_t` value before being handed to `jl_get_method_inferred`.
//
// Returns whatever `jl_get_method_inferred` returns for those arguments.
JL_DLLEXPORT jl_code_instance_t *jl_get_codeinst_for_src(
        jl_method_instance_t *mi JL_PROPAGATES_ROOT, jl_code_info_t *src)
{
    // TODO: copy backedges from src to mi
    const size_t src_max_world = src->max_world;
    const size_t current_world = jl_atomic_load_acquire(&jl_world_counter);
    // ~(size_t)0 has every bit set, i.e. it is the maximum representable world.
    const size_t last_world =
        (src_max_world >= current_world) ? ~(size_t)0 : src_max_world;
    return jl_get_method_inferred(mi, src->rettype, src->min_world, last_world);
}

JL_DLLEXPORT jl_code_instance_t *jl_new_codeinst(
jl_method_instance_t *mi, jl_value_t *rettype,
jl_value_t *inferred_const, jl_value_t *inferred,
Expand Down
3 changes: 2 additions & 1 deletion src/jitlayers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,7 @@ jl_code_instance_t *jl_generate_fptr_impl(jl_method_instance_t *mi JL_PROPAGATES
// see if it is inferred, or try to infer it for ourself.
// (but don't bother with typeinf on macros or toplevel thunks)
src = jl_type_infer(mi, world, 0);
codeinst = nullptr;
}
}
jl_code_instance_t *compiled = jl_method_compiled(mi, world);
Expand All @@ -515,7 +516,7 @@ jl_code_instance_t *jl_generate_fptr_impl(jl_method_instance_t *mi JL_PROPAGATES
}
else if (src && jl_is_code_info(src)) {
if (!codeinst) {
codeinst = jl_get_method_inferred(mi, src->rettype, src->min_world, src->max_world);
codeinst = jl_get_codeinst_for_src(mi, src);
if (src->inferred) {
jl_value_t *null = nullptr;
jl_atomic_cmpswap_relaxed(&codeinst->inferred, &null, jl_nothing);
Expand Down
1 change: 0 additions & 1 deletion src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@
XX(jl_get_JIT) \
XX(jl_get_julia_bin) \
XX(jl_get_julia_bindir) \
XX(jl_get_method_inferred) \
XX(jl_get_module_compile) \
XX(jl_get_module_infer) \
XX(jl_get_module_of_binding) \
Expand Down
2 changes: 2 additions & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,8 @@ JL_DLLEXPORT jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t
JL_DLLEXPORT jl_code_instance_t *jl_get_method_inferred(
jl_method_instance_t *mi JL_PROPAGATES_ROOT, jl_value_t *rettype,
size_t min_world, size_t max_world);
JL_DLLEXPORT jl_code_instance_t *jl_get_codeinst_for_src(
jl_method_instance_t *mi JL_PROPAGATES_ROOT, jl_code_info_t *src);
jl_method_instance_t *jl_get_unspecialized_from_mi(jl_method_instance_t *method JL_PROPAGATES_ROOT);
jl_method_instance_t *jl_get_unspecialized(jl_method_t *def JL_PROPAGATES_ROOT);

Expand Down