Skip to content

Commit

Permalink
ensure jl_compilation_sig does not narrow Vararg (#48152)
Browse files Browse the repository at this point in the history
Some code cleanup, and an early exit path that avoids trying to create a
compilation signature from something that cannot be turned into one.
Previously we might try a little too hard to make one, even if it meant
we ignored that it was expected to be Varargs.

Fix #48085

(cherry picked from commit 45c81b1)
  • Loading branch information
vtjnash authored and KristofferC committed Jan 10, 2023
1 parent acb7e09 commit 5547468
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 24 deletions.
78 changes: 54 additions & 24 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -584,15 +584,12 @@ jl_value_t *jl_nth_slot_type(jl_value_t *sig, size_t i) JL_NOTSAFEPOINT
{
sig = jl_unwrap_unionall(sig);
size_t len = jl_nparams(sig);
if (len == 0)
return NULL;
if (i < len-1)
return jl_tparam(sig, i);
if (jl_is_vararg(jl_tparam(sig, len-1)))
return jl_unwrap_vararg(jl_tparam(sig, len-1));
if (i == len-1)
return jl_tparam(sig, i);
return NULL;
jl_value_t *p = jl_tparam(sig, len-1);
if (jl_is_vararg(p))
p = jl_unwrap_vararg(p);
return p;
}

// if concrete_match returns false, the sig may specify `Type{T::DataType}`, while the `tt` contained DataType
Expand Down Expand Up @@ -660,31 +657,62 @@ static jl_value_t *ml_matches(jl_methtable_t *mt,

// get the compilation signature specialization for this method
static void jl_compilation_sig(
jl_tupletype_t *const tt, // the original tupletype of the call : this is expected to be a relative simple type (no Varags, Union, UnionAll, etc.)
jl_tupletype_t *const tt, // the original tupletype of the call (or DataType from precompile)
jl_svec_t *sparams,
jl_method_t *definition,
intptr_t nspec,
// output:
jl_svec_t **const newparams JL_REQUIRE_ROOTED_SLOT)
{
assert(jl_is_tuple_type(tt));
jl_value_t *decl = definition->sig;
size_t nargs = definition->nargs; // == jl_nparams(jl_unwrap_unionall(decl));

if (definition->generator) {
// staged functions aren't optimized
// so assume the caller was intelligent about calling us
return;
}
if (definition->sig == (jl_value_t*)jl_anytuple_type && jl_atomic_load_relaxed(&definition->unspecialized)) {

if (decl == (jl_value_t*)jl_anytuple_type && jl_atomic_load_relaxed(&definition->unspecialized)) {
*newparams = jl_anytuple_type->parameters; // handle builtin methods
return;
}

jl_value_t *decl = definition->sig;
assert(jl_is_tuple_type(tt));
// some early sanity checks
size_t i, np = jl_nparams(tt);
size_t nargs = definition->nargs; // == jl_nparams(jl_unwrap_unionall(decl));
switch (jl_va_tuple_kind((jl_datatype_t*)decl)) {
case JL_VARARG_NONE:
if (jl_is_va_tuple(tt))
// odd
return;
if (np != nargs)
// there are not enough input parameters to make this into a compilation sig
return;
break;
case JL_VARARG_INT:
case JL_VARARG_BOUND:
if (jl_is_va_tuple(tt))
// the length needed is not known, but required for compilation
return;
if (np < nargs - 1)
// there are not enough input parameters to make this into a compilation sig
return;
break;
case JL_VARARG_UNBOUND:
if (np < nspec && jl_is_va_tuple(tt))
// there are insufficient given parameters for jl_isa_compileable_sig now to like this type
// (there were probably fewer methods defined when we first selected this signature)
return;
break;
}

jl_value_t *type_i = NULL;
JL_GC_PUSH1(&type_i);
for (i = 0; i < np; i++) {
jl_value_t *elt = jl_tparam(tt, i);
if (jl_is_vararg(elt))
elt = jl_unwrap_vararg(elt);
jl_value_t *decl_i = jl_nth_slot_type(decl, i);
type_i = jl_rewrap_unionall(decl_i, decl);
size_t i_arg = (i < nargs - 1 ? i : nargs - 1);
Expand Down Expand Up @@ -732,16 +760,14 @@ static void jl_compilation_sig(
if (!jl_has_free_typevars(decl_i) && !jl_is_kind(decl_i)) {
if (decl_i != elt) {
if (!*newparams) *newparams = jl_svec_copy(tt->parameters);
// n.b. it is possible here that !(elt <: decl_i), if elt was something unusual from intersection
// so this might narrow the result slightly, though still being compatible with the declared signature
jl_svecset(*newparams, i, (jl_value_t*)decl_i);
}
continue;
}
}

if (jl_is_vararg(elt)) {
continue;
}

if (jl_types_equal(elt, (jl_value_t*)jl_type_type)) { // elt == Type{T} where T
// not triggered for isdispatchtuple(tt), this attempts to handle
// some cases of adapting a random signature into a compilation signature
Expand Down Expand Up @@ -827,7 +853,7 @@ static void jl_compilation_sig(
// in general, here we want to find the biggest type that's not a
// supertype of any other method signatures. so far we are conservative
// and the types we find should be bigger.
if (jl_nparams(tt) >= nspec && jl_va_tuple_kind((jl_datatype_t*)decl) == JL_VARARG_UNBOUND) {
if (np >= nspec && jl_va_tuple_kind((jl_datatype_t*)decl) == JL_VARARG_UNBOUND) {
if (!*newparams) *newparams = tt->parameters;
type_i = jl_svecref(*newparams, nspec - 2);
// if all subsequent arguments are subtypes of type_i, specialize
Expand Down Expand Up @@ -2075,7 +2101,9 @@ JL_DLLEXPORT jl_value_t *jl_matching_methods(jl_tupletype_t *types, jl_value_t *
if (ambig != NULL)
*ambig = 0;
jl_value_t *unw = jl_unwrap_unionall((jl_value_t*)types);
if (jl_is_tuple_type(unw) && (unw == (jl_value_t*)jl_emptytuple_type || jl_tparam0(unw) == jl_bottom_type))
if (!jl_is_tuple_type(unw))
return (jl_value_t*)jl_an_empty_vec_any;
if (unw == (jl_value_t*)jl_emptytuple_type || jl_tparam0(unw) == jl_bottom_type)
return (jl_value_t*)jl_an_empty_vec_any;
if (mt == jl_nothing)
mt = (jl_value_t*)jl_method_table_for(unw);
Expand Down Expand Up @@ -2172,8 +2200,8 @@ jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *mi, size_t
if (codeinst)
return codeinst;

// if mi has a better (wider) signature for compilation use that instead
// and just copy it here for caching
// if mi has a better (wider) signature preferred for compilation use that
// instead and just copy it here for caching
jl_method_instance_t *mi2 = jl_normalize_to_compilable_mi(mi);
if (mi2 != mi) {
jl_code_instance_t *codeinst2 = jl_compile_method_internal(mi2, world);
Expand Down Expand Up @@ -2362,7 +2390,7 @@ JL_DLLEXPORT jl_value_t *jl_normalize_to_compilable_sig(jl_methtable_t *mt, jl_t
jl_method_instance_t *jl_normalize_to_compilable_mi(jl_method_instance_t *mi JL_PROPAGATES_ROOT)
{
jl_method_t *def = mi->def.method;
if (!jl_is_method(def))
if (!jl_is_method(def) || !jl_is_datatype(mi->specTypes))
return mi;
jl_methtable_t *mt = jl_method_get_table(def);
if ((jl_value_t*)mt == jl_nothing)
Expand Down Expand Up @@ -2444,7 +2472,7 @@ jl_method_instance_t *jl_get_specialization1(jl_tupletype_t *types JL_PROPAGATES

// Get a MethodInstance for a precompile() call. This uses a special kind of lookup that
// tries to find a method for which the requested signature is compileable.
jl_method_instance_t *jl_get_compile_hint_specialization(jl_tupletype_t *types JL_PROPAGATES_ROOT, size_t world, size_t *min_valid, size_t *max_valid, int mt_cache)
static jl_method_instance_t *jl_get_compile_hint_specialization(jl_tupletype_t *types JL_PROPAGATES_ROOT, size_t world, size_t *min_valid, size_t *max_valid, int mt_cache)
{
if (jl_has_free_typevars((jl_value_t*)types))
return NULL; // don't poison the cache due to a malformed query
Expand All @@ -2467,7 +2495,7 @@ jl_method_instance_t *jl_get_compile_hint_specialization(jl_tupletype_t *types J
if (n == 1) {
match = (jl_method_match_t*)jl_array_ptr_ref(matches, 0);
}
else {
else if (jl_is_datatype(types)) {
// first, select methods for which `types` is compileable
size_t count = 0;
for (i = 0; i < n; i++) {
Expand Down Expand Up @@ -2838,7 +2866,9 @@ JL_DLLEXPORT jl_value_t *jl_apply_generic(jl_value_t *F, jl_value_t **args, uint
static jl_method_match_t *_gf_invoke_lookup(jl_value_t *types JL_PROPAGATES_ROOT, jl_value_t *mt, size_t world, size_t *min_valid, size_t *max_valid)
{
jl_value_t *unw = jl_unwrap_unionall((jl_value_t*)types);
if (jl_is_tuple_type(unw) && jl_tparam0(unw) == jl_bottom_type)
if (!jl_is_tuple_type(unw))
return NULL;
if (jl_tparam0(unw) == jl_bottom_type)
return NULL;
if (mt == jl_nothing)
mt = (jl_value_t*)jl_method_table_for(unw);
Expand Down
4 changes: 4 additions & 0 deletions test/compiler/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -787,3 +787,7 @@ f47247(a::Ref{Int}, b::Nothing) = setfield!(a, :x, b)
f(x) = Core.bitcast(UInt64, x)
@test occursin("llvm.trap", get_llvm(f, Tuple{Union{}}))
end

f48085(@nospecialize x...) = length(x)
@test Core.Compiler.get_compileable_sig(which(f48085, (Vararg{Any},)), Tuple{typeof(f48085), Vararg{Int}}, Core.svec()) === nothing
@test Core.Compiler.get_compileable_sig(which(f48085, (Vararg{Any},)), Tuple{typeof(f48085), Int, Vararg{Int}}, Core.svec()) === Tuple{typeof(f48085), Any, Vararg{Any}}

0 comments on commit 5547468

Please sign in to comment.