Skip to content

Commit

Permalink
support UnionAll types in invoke
Browse files Browse the repository at this point in the history
fixes #25341, fixes #24460, fixes #22554
  • Loading branch information
JeffBezanson committed Jan 10, 2018
1 parent 0abc263 commit 543f240
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 63 deletions.
6 changes: 3 additions & 3 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1681,7 +1681,7 @@ function invoke_tfunc(@nospecialize(f), @nospecialize(types), @nospecialize(argt
return Bottom
end
ft = type_typeof(f)
types = Tuple{ft, types.parameters...}
types = rewrap_unionall(Tuple{ft, unwrap_unionall(types).parameters...}, types)
argtype = Tuple{ft, argtype.parameters...}
entry = ccall(:jl_gf_invoke_lookup, Any, (Any, UInt), types, sv.params.world)
if entry === nothing
Expand Down Expand Up @@ -4649,8 +4649,8 @@ function inlineable(@nospecialize(f), @nospecialize(ft), e::Expr, atypes::Vector
invoke_tt.parameters[1] <: Tuple)
return NF
end
invoke_tt_params = invoke_tt.parameters[1].parameters
invoke_types = Tuple{ft, invoke_tt_params...}
invoke_tt = invoke_tt.parameters[1]
invoke_types = rewrap_unionall(Tuple{ft, unwrap_unionall(invoke_tt).parameters...}, invoke_tt)
invoke_entry = ccall(:jl_gf_invoke_lookup, Any, (Any, UInt),
invoke_types, sv.params.world)
invoke_entry === nothing && return NF
Expand Down
11 changes: 3 additions & 8 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -903,12 +903,6 @@ JL_CALLABLE(jl_f_apply_type)

// generic function reflection ------------------------------------------------

static void jl_check_type_tuple(jl_value_t *t, jl_sym_t *name, const char *ctx)
{
if (!jl_is_tuple_type(t))
jl_type_error_rt(jl_symbol_name(name), ctx, (jl_value_t*)jl_type_type, t);
}

JL_CALLABLE(jl_f_applicable)
{
JL_NARGSV(applicable, 1);
Expand All @@ -922,11 +916,12 @@ JL_CALLABLE(jl_f_invoke)
JL_NARGSV(invoke, 2);
jl_value_t *argtypes = args[1];
JL_GC_PUSH1(&argtypes);
jl_check_type_tuple(args[1], jl_gf_name(args[0]), "invoke");
if (!jl_is_tuple_type(jl_unwrap_unionall(args[1])))
jl_type_error_rt(jl_symbol_name(jl_gf_name(args[0])), "invoke", (jl_value_t*)jl_type_type, args[1]);
if (!jl_tuple_isa(&args[2], nargs-2, (jl_datatype_t*)argtypes))
jl_error("invoke: argument type error");
args[1] = args[0]; // move function directly in front of arguments
jl_value_t *res = jl_gf_invoke((jl_tupletype_t*)argtypes, &args[1], nargs-1);
jl_value_t *res = jl_gf_invoke(argtypes, &args[1], nargs-1);
JL_GC_POP();
return res;
}
Expand Down
2 changes: 1 addition & 1 deletion src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4523,7 +4523,7 @@ static Function *jl_cfunction_object(jl_function_t *ff, jl_value_t *declrt, jl_t
// check the cache
jl_typemap_entry_t *sf = NULL;
if (jl_cfunction_list.unknown != jl_nothing) {
sf = jl_typemap_assoc_by_type(jl_cfunction_list, (jl_tupletype_t*)cfunc_sig, NULL, /*subtype*/0, /*offs*/0, /*world*/1, /*max_world_mask*/0);
sf = jl_typemap_assoc_by_type(jl_cfunction_list, cfunc_sig, NULL, /*subtype*/0, /*offs*/0, /*world*/1, /*max_world_mask*/0);
if (sf) {
jl_value_t *v = sf->func.value;
if (v) {
Expand Down
4 changes: 2 additions & 2 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -2568,7 +2568,7 @@ static jl_method_t *jl_lookup_method_worldset(jl_methtable_t *mt, jl_datatype_t
jl_method_t *_new;
while (1) {
entry = jl_typemap_assoc_by_type(
mt->defs, sig, NULL, /*subtype*/0, /*offs*/0, world, /*max_world_mask*/0);
mt->defs, (jl_value_t*)sig, NULL, /*subtype*/0, /*offs*/0, world, /*max_world_mask*/0);
if (!entry)
break;
_new = (jl_method_t*)entry->func.value;
Expand All @@ -2581,7 +2581,7 @@ static jl_method_t *jl_lookup_method_worldset(jl_methtable_t *mt, jl_datatype_t
// If we failed to find a method (perhaps due to method deletion),
// grab anything
entry = jl_typemap_assoc_by_type(
mt->defs, sig, NULL, /*subtype*/0, /*offs*/0, /*world*/jl_world_counter, /*max_world_mask*/(~(size_t)0) >> 1);
mt->defs, (jl_value_t*)sig, NULL, /*subtype*/0, /*offs*/0, /*world*/jl_world_counter, /*max_world_mask*/(~(size_t)0) >> 1);
assert(entry);
assert(entry->max_world != ~(size_t)0);
*max_world = entry->max_world;
Expand Down
60 changes: 31 additions & 29 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ JL_DLLEXPORT jl_method_instance_t *jl_specializations_get_linfo(jl_method_t *m,
assert(world >= m->min_world && "typemap lookup is corrupted");
JL_LOCK(&m->writelock);
jl_typemap_entry_t *sf =
jl_typemap_assoc_by_type(m->specializations, (jl_tupletype_t*)type, NULL, /*subtype*/0, /*offs*/0, world, /*max_world_mask*/0);
jl_typemap_assoc_by_type(m->specializations, type, NULL, /*subtype*/0, /*offs*/0, world, /*max_world_mask*/0);
if (sf && jl_is_method_instance(sf->func.value)) {
jl_method_instance_t *linfo = (jl_method_instance_t*)sf->func.value;
assert(linfo->min_world <= sf->min_world && linfo->max_world >= sf->max_world);
Expand Down Expand Up @@ -177,7 +177,7 @@ JL_DLLEXPORT jl_method_instance_t *jl_specializations_get_linfo(jl_method_t *m,
return li;
}

JL_DLLEXPORT jl_value_t *jl_specializations_lookup(jl_method_t *m, jl_tupletype_t *type, size_t world)
JL_DLLEXPORT jl_value_t *jl_specializations_lookup(jl_method_t *m, jl_value_t *type, size_t world)
{
jl_typemap_entry_t *sf = jl_typemap_assoc_by_type(
m->specializations, type, NULL, /*subtype*/0, /*offs*/0, world, /*max_world_mask*/0);
Expand All @@ -186,7 +186,7 @@ JL_DLLEXPORT jl_value_t *jl_specializations_lookup(jl_method_t *m, jl_tupletype_
return sf->func.value;
}

JL_DLLEXPORT jl_value_t *jl_methtable_lookup(jl_methtable_t *mt, jl_tupletype_t *type, size_t world)
JL_DLLEXPORT jl_value_t *jl_methtable_lookup(jl_methtable_t *mt, jl_value_t *type, size_t world)
{
jl_typemap_entry_t *sf = jl_typemap_assoc_by_type(
mt->defs, type, NULL, /*subtype*/0, /*offs*/0, world, /*max_world_mask*/0);
Expand Down Expand Up @@ -548,7 +548,7 @@ jl_value_t *jl_nth_slot_type(jl_value_t *sig, size_t i)
// after intersection, the argument tuple type needs to be corrected to reflect the signature match
// that occurred, if the arguments contained a Type but the signature matched on the kind
// if sharp_match is returned as false, this tt may have matched only because of bug in subtyping
static jl_tupletype_t *join_tsig(jl_tupletype_t *tt, jl_tupletype_t *sig, int *sharp_match)
static jl_tupletype_t *join_tsig(jl_tupletype_t *tt, jl_value_t *sig, int *sharp_match)
{
jl_svec_t *newparams = NULL;
JL_GC_PUSH1(&newparams);
Expand Down Expand Up @@ -1070,10 +1070,10 @@ static jl_method_instance_t *jl_mt_assoc_by_type(jl_methtable_t *mt, jl_datatype
jl_method_instance_t *nf = NULL;
JL_GC_PUSH4(&env, &entry, &func, &sig);

entry = jl_typemap_assoc_by_type(mt->defs, tt, &env, /*subtype*/1, /*offs*/0, world, /*max_world_mask*/0);
entry = jl_typemap_assoc_by_type(mt->defs, (jl_value_t*)tt, &env, /*subtype*/1, /*offs*/0, world, /*max_world_mask*/0);
if (entry != NULL) {
jl_method_t *m = entry->func.method;
if (!jl_has_call_ambiguities(tt, m)) {
if (!jl_has_call_ambiguities((jl_value_t*)tt, m)) {
#ifdef TRACE_COMPILE
if (!jl_has_free_typevars((jl_value_t*)tt)) {
jl_printf(JL_STDERR, "precompile(");
Expand All @@ -1082,7 +1082,7 @@ static jl_method_instance_t *jl_mt_assoc_by_type(jl_methtable_t *mt, jl_datatype
}
#endif
int sharp_match;
sig = join_tsig(tt, (jl_tupletype_t*)m->sig, &sharp_match);
sig = join_tsig(tt, m->sig, &sharp_match);
if (!mt_cache) {
nf = jl_specializations_get_linfo(m, (jl_value_t*)sig, env, world);
}
Expand Down Expand Up @@ -1167,7 +1167,7 @@ static int check_ambiguous_visitor(jl_typemap_entry_t *oldentry, struct typemap_
// (if type-morespecific made a mistake, this also might end up finding
// that isect == type or isect == sig and return the original match)
jl_typemap_entry_t *l = jl_typemap_assoc_by_type(
map, (jl_tupletype_t*)isect, NULL, /*subtype*/0, /*offs*/0,
map, isect, NULL, /*subtype*/0, /*offs*/0,
closure->newentry->min_world, /*max_world_mask*/0);
if (l != NULL) // ok, intersection is covered
return 1;
Expand Down Expand Up @@ -1637,15 +1637,15 @@ jl_tupletype_t *arg_type_tuple(jl_value_t **args, size_t nargs)
jl_method_instance_t *jl_method_lookup_by_type(jl_methtable_t *mt, jl_tupletype_t *types,
int cache, int allow_exec, size_t world)
{
jl_typemap_entry_t *entry = jl_typemap_assoc_by_type(mt->cache, types, NULL, /*subtype*/1, jl_cachearg_offset(mt), world, /*max_world_mask*/0);
jl_typemap_entry_t *entry = jl_typemap_assoc_by_type(mt->cache, (jl_value_t*)types, NULL, /*subtype*/1, jl_cachearg_offset(mt), world, /*max_world_mask*/0);
if (entry) {
jl_method_instance_t *linfo = (jl_method_instance_t*)entry->func.value;
assert(linfo->min_world <= entry->min_world && linfo->max_world >= entry->max_world &&
"typemap consistency error: MethodInstance doesn't apply to full range of its entry");
return linfo;
}
JL_LOCK(&mt->writelock);
entry = jl_typemap_assoc_by_type(mt->cache, types, NULL, /*subtype*/1, jl_cachearg_offset(mt), world, /*max_world_mask*/0);
entry = jl_typemap_assoc_by_type(mt->cache, (jl_value_t*)types, NULL, /*subtype*/1, jl_cachearg_offset(mt), world, /*max_world_mask*/0);
if (entry) {
jl_method_instance_t *linfo = (jl_method_instance_t*)entry->func.value;
assert(linfo->min_world <= entry->min_world && linfo->max_world >= entry->max_world &&
Expand Down Expand Up @@ -1808,12 +1808,12 @@ jl_method_instance_t *jl_get_specialization1(jl_tupletype_t *types, size_t world
jl_svec_t *env = (jl_svec_t*)jl_svecref(match, 1);
jl_tupletype_t *ti = (jl_tupletype_t*)jl_unwrap_unionall(jl_svecref(match, 0));
jl_method_instance_t *nf = NULL;
if (ti == types && !jl_has_call_ambiguities(types, m)) {
if (ti == types && !jl_has_call_ambiguities((jl_value_t*)types, m)) {
jl_datatype_t *dt = jl_first_argument_datatype(jl_unwrap_unionall((jl_value_t*)types));
assert(jl_is_datatype(dt));
jl_methtable_t *mt = dt->name->mt;
int sharp_match;
sig = join_tsig(ti, (jl_tupletype_t*)m->sig, &sharp_match);
sig = join_tsig(ti, m->sig, &sharp_match);
if (sharp_match) {
JL_LOCK(&mt->writelock);
nf = cache_method(mt, &mt->cache, (jl_value_t*)mt, sig, ti, m, world, env, /*allow_exec*/1);
Expand Down Expand Up @@ -1872,7 +1872,7 @@ JL_DLLEXPORT jl_value_t *jl_get_spec_lambda(jl_tupletype_t *types, size_t world)
}

// see if a call to m with computed from `types` is ambiguous
JL_DLLEXPORT int jl_is_call_ambiguous(jl_tupletype_t *types, jl_method_t *m)
JL_DLLEXPORT int jl_is_call_ambiguous(jl_value_t *types, jl_method_t *m)
{
if (m->ambig == jl_nothing)
return 0;
Expand All @@ -1886,21 +1886,22 @@ JL_DLLEXPORT int jl_is_call_ambiguous(jl_tupletype_t *types, jl_method_t *m)

// see if a call to m with a subtype of `types` might be ambiguous
// if types is from a call signature (approximated by isleaftype), this is the same as jl_is_call_ambiguous above
JL_DLLEXPORT int jl_has_call_ambiguities(jl_tupletype_t *types, jl_method_t *m)
JL_DLLEXPORT int jl_has_call_ambiguities(jl_value_t *types, jl_method_t *m)
{
if (m->ambig == jl_nothing)
return 0;
for (size_t i = 0; i < jl_array_len(m->ambig); i++) {
jl_method_t *mambig = (jl_method_t*)jl_array_ptr_ref(m->ambig, i);
if (!jl_has_empty_intersection((jl_value_t*)mambig->sig, (jl_value_t*)types))
if (!jl_has_empty_intersection(mambig->sig, types))
return 1;
}
return 0;
}

// add type of `f` to front of argument tuple type
jl_tupletype_t *jl_argtype_with_function(jl_function_t *f, jl_tupletype_t *types)
jl_value_t *jl_argtype_with_function(jl_function_t *f, jl_value_t *types0)
{
jl_value_t *types = jl_unwrap_unionall(types0);
size_t l = jl_nparams(types);
jl_value_t *tt = (jl_value_t*)jl_alloc_svec(1+l);
size_t i;
Expand All @@ -1912,8 +1913,9 @@ jl_tupletype_t *jl_argtype_with_function(jl_function_t *f, jl_tupletype_t *types
for(i=0; i < l; i++)
jl_svecset(tt, i+1, jl_tparam(types,i));
tt = (jl_value_t*)jl_apply_tuple_type((jl_svec_t*)tt);
tt = jl_rewrap_unionall(tt, types0);
JL_GC_POP();
return (jl_tupletype_t*)tt;
return tt;
}

#ifdef JL_TRACE
Expand Down Expand Up @@ -2083,9 +2085,9 @@ JL_DLLEXPORT jl_value_t *jl_apply_generic(jl_value_t **args, uint32_t nargs)
return verify_type(res);
}

JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_datatype_t *types, size_t world)
JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_value_t *types, size_t world)
{
jl_methtable_t *mt = ((jl_datatype_t*)jl_tparam0(types))->name->mt;
jl_methtable_t *mt = jl_first_argument_datatype(types)->name->mt;
jl_svec_t *env = jl_emptysvec;
JL_GC_PUSH1(&env);
jl_typemap_entry_t *entry = jl_typemap_assoc_by_type(
Expand All @@ -2107,21 +2109,21 @@ JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_datatype_t *types, size_t world)
// every definition has its own private method table for this purpose.
//
// NOTE: assumes argument type is a subtype of the lookup type.
jl_value_t *jl_gf_invoke(jl_tupletype_t *types0, jl_value_t **args, size_t nargs)
jl_value_t *jl_gf_invoke(jl_value_t *types0, jl_value_t **args, size_t nargs)
{
size_t world = jl_get_ptls_states()->world_age;
jl_svec_t *tpenv = jl_emptysvec;
jl_tupletype_t *tt = NULL;
jl_tupletype_t *types = NULL;
jl_value_t *types = NULL;
jl_tupletype_t *sig = NULL;
JL_GC_PUSH4(&types, &tpenv, &sig, &tt);
jl_value_t *gf = args[0];
types = (jl_datatype_t*)jl_argtype_with_function(gf, (jl_tupletype_t*)types0);
types = jl_argtype_with_function(gf, types0);
jl_methtable_t *mt = jl_gf_mtable(gf);
jl_typemap_entry_t *entry = (jl_typemap_entry_t*)jl_gf_invoke_lookup(types, world);

if ((jl_value_t*)entry == jl_nothing) {
jl_method_error_bare(gf, (jl_value_t*)types0, world);
jl_method_error_bare(gf, types0, world);
// unreachable
}

Expand Down Expand Up @@ -2154,7 +2156,7 @@ jl_value_t *jl_gf_invoke(jl_tupletype_t *types0, jl_value_t **args, size_t nargs
method->invokes.unknown = jl_nothing;

int sharp_match;
sig = join_tsig(tt, (jl_tupletype_t*)method->sig, &sharp_match);
sig = join_tsig(tt, method->sig, &sharp_match);
mfunc = cache_method(mt, &method->invokes, entry->func.value, sig, tt, method, world, tpenv, 1);
}
JL_UNLOCK(&method->writelock);
Expand Down Expand Up @@ -2197,10 +2199,10 @@ static int tupletype_has_datatype(jl_tupletype_t *tt, tupletype_stack_t *stack)

JL_DLLEXPORT jl_value_t *jl_get_invoke_lambda(jl_methtable_t *mt,
jl_typemap_entry_t *entry,
jl_tupletype_t *tt,
jl_value_t *tt,
size_t world)
{
if (!jl_is_leaf_type((jl_value_t*)tt) || tupletype_has_datatype(tt, NULL))
if (!jl_is_leaf_type((jl_value_t*)tt) || tupletype_has_datatype((jl_tupletype_t*)tt, NULL))
return jl_nothing;

jl_method_t *method = entry->func.method;
Expand Down Expand Up @@ -2228,7 +2230,7 @@ JL_DLLEXPORT jl_value_t *jl_get_invoke_lambda(jl_methtable_t *mt,
JL_GC_PUSH2(&tpenv, &sig);
if (jl_is_unionall(entry->sig)) {
jl_value_t *ti =
jl_type_intersection_env((jl_value_t*)tt, (jl_value_t*)entry->sig, &tpenv);
jl_type_intersection_env(tt, (jl_value_t*)entry->sig, &tpenv);
assert(ti != (jl_value_t*)jl_bottom_type);
(void)ti;
}
Expand All @@ -2237,9 +2239,9 @@ JL_DLLEXPORT jl_value_t *jl_get_invoke_lambda(jl_methtable_t *mt,
method->invokes.unknown = jl_nothing;

int sharp_match;
sig = join_tsig(tt, (jl_tupletype_t*)method->sig, &sharp_match);
sig = join_tsig((jl_tupletype_t*)tt, method->sig, &sharp_match);
jl_method_instance_t *mfunc = cache_method(mt, &method->invokes, entry->func.value,
sig, tt, method, world, tpenv, 1);
sig, (jl_tupletype_t*)tt, method, world, tpenv, 1);
JL_GC_POP();
JL_UNLOCK(&method->writelock);
return (jl_value_t*)mfunc;
Expand Down
10 changes: 5 additions & 5 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ STATIC_INLINE jl_value_t *jl_call_method_internal(jl_method_instance_t *meth, jl
return jl_call_fptr_internal(&fptr, meth, args, nargs);
}

jl_tupletype_t *jl_argtype_with_function(jl_function_t *f, jl_tupletype_t *types);
jl_value_t *jl_argtype_with_function(jl_function_t *f, jl_value_t *types);

JL_DLLEXPORT jl_value_t *jl_apply_2va(jl_value_t *f, jl_value_t **args, uint32_t nargs);

Expand Down Expand Up @@ -516,7 +516,7 @@ int jl_is_toplevel_only_expr(jl_value_t *e);
jl_value_t *jl_call_scm_on_ast(const char *funcname, jl_value_t *expr, jl_module_t *inmodule);

jl_method_instance_t *jl_method_lookup(jl_methtable_t *mt, jl_value_t **args, size_t nargs, int cache, size_t world);
jl_value_t *jl_gf_invoke(jl_tupletype_t *types, jl_value_t **args, size_t nargs);
jl_value_t *jl_gf_invoke(jl_value_t *types, jl_value_t **args, size_t nargs);
jl_method_instance_t *jl_lookup_generic(jl_value_t **args, uint32_t nargs, uint32_t callsite, size_t world);
JL_DLLEXPORT jl_value_t *jl_matching_methods(jl_tupletype_t *types, int lim, int include_ambiguous,
size_t world, size_t *min_valid, size_t *max_valid);
Expand Down Expand Up @@ -638,10 +638,10 @@ JL_DLLEXPORT jl_array_t *jl_idtable_rehash(jl_array_t *a, size_t newsz);

JL_DLLEXPORT jl_methtable_t *jl_new_method_table(jl_sym_t *name, jl_module_t *module);
jl_method_instance_t *jl_get_specialization1(jl_tupletype_t *types, size_t world);
JL_DLLEXPORT int jl_has_call_ambiguities(jl_tupletype_t *types, jl_method_t *m);
JL_DLLEXPORT int jl_has_call_ambiguities(jl_value_t *types, jl_method_t *m);
jl_method_instance_t *jl_get_specialized(jl_method_t *m, jl_value_t *types, jl_svec_t *sp);
int jl_is_rettype_inferred(jl_method_instance_t *li);
JL_DLLEXPORT jl_value_t *jl_methtable_lookup(jl_methtable_t *mt, jl_tupletype_t *type, size_t world);
JL_DLLEXPORT jl_value_t *jl_methtable_lookup(jl_methtable_t *mt, jl_value_t *type, size_t world);
JL_DLLEXPORT jl_method_instance_t *jl_specializations_get_linfo(jl_method_t *m, jl_value_t *type, jl_svec_t *sparams, size_t world);
JL_DLLEXPORT void jl_method_instance_add_backedge(jl_method_instance_t *callee, jl_method_instance_t *caller);
JL_DLLEXPORT void jl_method_table_add_backedge(jl_methtable_t *mt, jl_value_t *typ, jl_value_t *caller);
Expand Down Expand Up @@ -926,7 +926,7 @@ jl_typemap_entry_t *jl_typemap_insert(union jl_typemap_t *cache, jl_value_t *par
jl_value_t **overwritten);

jl_typemap_entry_t *jl_typemap_assoc_by_type(
union jl_typemap_t ml_or_cache, jl_tupletype_t *types, jl_svec_t **penv,
union jl_typemap_t ml_or_cache, jl_value_t *types, jl_svec_t **penv,
int8_t subtype, int8_t offs, size_t world, size_t max_world_mask);
jl_typemap_entry_t *jl_typemap_level_assoc_exact(jl_typemap_level_t *cache, jl_value_t **args, size_t n, int8_t offs, size_t world);
jl_typemap_entry_t *jl_typemap_entry_assoc_exact(jl_typemap_entry_t *mn, jl_value_t **args, size_t n, size_t world);
Expand Down
Loading

0 comments on commit 543f240

Please sign in to comment.