Skip to content

Commit

Permalink
Support method deletion
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Dec 12, 2017
1 parent a8054ed commit 3357f95
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 3 deletions.
16 changes: 16 additions & 0 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1083,6 +1083,22 @@ function isambiguous(m1::Method, m2::Method; ambiguous_bottom::Bool=false)
return true
end

"""
delete_method(m::Method)
Make method `m` uncallable and force recompilation of any methods that use(d) it.
"""
function delete_method(m::Method)
ccall(:jl_method_table_disable, Void, (Any, Any), MethodTable(m), m)
end

MethodTable(m::Method) = get_methodtable(m.sig)

get_methodtable(u::UnionAll) = get_methodtable(u.body)
get_methodtable(sig) = _get_methodtable(sig.parameters[1])
_get_methodtable(u::UnionAll) = _get_methodtable(u.body)
_get_methodtable(f) = f.name.mt

"""
has_bottom_parameter(t) -> Bool
Expand Down
77 changes: 74 additions & 3 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -1131,6 +1131,8 @@ static int check_ambiguous_visitor(jl_typemap_entry_t *oldentry, struct typemap_
closure->after = 1;
return 1;
}
if (oldentry->max_world < ~(size_t)0)
return 1;
union jl_typemap_t map = closure->defs;
jl_tupletype_t *type = (jl_tupletype_t*)closure->match.type;
jl_method_t *m = closure->newentry->func.method;
Expand Down Expand Up @@ -1212,7 +1214,7 @@ static int check_ambiguous_visitor(jl_typemap_entry_t *oldentry, struct typemap_
return 1;
}

static jl_value_t *check_ambiguous_matches(union jl_typemap_t defs, jl_typemap_entry_t *newentry)
static jl_value_t *check_ambiguous_matches(union jl_typemap_t defs, jl_typemap_entry_t *newentry, jl_typemap_intersection_visitor_fptr fptr)
{
jl_tupletype_t *type = newentry->sig;
jl_tupletype_t *ttypes = (jl_tupletype_t*)jl_unwrap_unionall((jl_value_t*)type);
Expand All @@ -1226,7 +1228,7 @@ static jl_value_t *check_ambiguous_matches(union jl_typemap_t defs, jl_typemap_e
va = NULL;
}
struct ambiguous_matches_env env;
env.match.fptr = check_ambiguous_visitor;
env.match.fptr = fptr;
env.match.type = (jl_value_t*)type;
env.match.va = va;
env.match.ti = NULL;
Expand All @@ -1241,6 +1243,47 @@ static jl_value_t *check_ambiguous_matches(union jl_typemap_t defs, jl_typemap_e
return env.shadowed;
}

static int check_disabled_ambiguous_visitor(jl_typemap_entry_t *oldentry, struct typemap_intersection_env *closure0)
{
struct ambiguous_matches_env *closure = container_of(closure0, struct ambiguous_matches_env, match);
if (oldentry == closure->newentry) {
closure->after = 1;
return 1;
}
if (!closure->after || oldentry->max_world < ~(size_t)0) // the second condition prevents us from confusion in multiple cycles of add/delete
return 1;
jl_tupletype_t *sig = oldentry->sig;
jl_value_t *isect = closure->match.ti;
if (closure->shadowed == NULL)
closure->shadowed = (jl_value_t*)jl_alloc_vec_any(0);

int i, l = jl_array_len(closure->shadowed);
for (i = 0; i < l; i++) {
jl_method_t *mth = (jl_method_t*)jl_array_ptr_ref(closure->shadowed, i);
jl_value_t *isect2 = jl_type_intersection(mth->sig, (jl_value_t*)sig);
// see if the intersection was covered by precisely the disabled method
// that means we now need to record the ambiguity
if (jl_types_equal(isect, isect2)) {
jl_method_t *mambig = mth;
jl_method_t *m = oldentry->func.method;
if (m->ambig == jl_nothing) {
m->ambig = (jl_value_t*) jl_alloc_vec_any(0);
jl_gc_wb(m, m->ambig);
}
if (mambig->ambig == jl_nothing) {
mambig->ambig = (jl_value_t*) jl_alloc_vec_any(0);
jl_gc_wb(mambig, mambig->ambig);
}
jl_array_ptr_1d_push((jl_array_t*) m->ambig, (jl_value_t*) mambig);
jl_array_ptr_1d_push((jl_array_t*) mambig->ambig, (jl_value_t*) m);
}
}

jl_array_ptr_1d_push((jl_array_t*)closure->shadowed, oldentry->func.value);
return 1;
}


static void method_overwrite(jl_typemap_entry_t *newentry, jl_method_t *oldvalue)
{
// method overwritten
Expand Down Expand Up @@ -1405,6 +1448,34 @@ void jl_method_instance_delete(jl_method_instance_t *mi)
jl_uv_puts(JL_STDOUT, "<<<\n", 4);
}

static int typemap_search(jl_typemap_entry_t *entry, void *closure)
{
if ((void*)(entry->func.method) == *(jl_method_t**)closure) {
*(jl_typemap_entry_t**)closure = entry;
return 0;
}
return 1;
}

//JL_DLLEXPORT void jl_method_table_disable(jl_methtable_t *mt, jl_typemap_entry_t *methodentry)
JL_DLLEXPORT void jl_method_table_disable(jl_methtable_t *mt, jl_method_t *method)
{
struct invalidate_conflicting_env env;
env.invalidated = 0;
jl_typemap_entry_t *methodentry = (jl_typemap_entry_t*)(method);
if (jl_typemap_visitor(mt->defs, typemap_search, &methodentry))
jl_error("method not in method table");
JL_LOCK(&mt->writelock);
// Narrow the world age on the method to make it uncallable
methodentry->max_world = jl_world_counter++;
// Recompute ambiguities (deleting a more specific method might reveal ambiguities that it previously resolved)
env.max_world = methodentry->max_world;
check_ambiguous_matches(mt->defs, methodentry, check_disabled_ambiguous_visitor); // TODO: decrease repeated work?
// Invalidate the backedges
jl_typemap_visitor(methodentry->func.method->specializations, (jl_typemap_visitor_fptr)invalidate_backedges, &env);
JL_UNLOCK(&mt->writelock);
}

JL_DLLEXPORT void jl_method_table_insert(jl_methtable_t *mt, jl_method_t *method, jl_tupletype_t *simpletype)
{
assert(jl_is_method(method));
Expand All @@ -1430,7 +1501,7 @@ JL_DLLEXPORT void jl_method_table_insert(jl_methtable_t *mt, jl_method_t *method
method_overwrite(newentry, (jl_method_t*)oldvalue);
}
else {
oldvalue = check_ambiguous_matches(mt->defs, newentry);
oldvalue = check_ambiguous_matches(mt->defs, newentry, check_ambiguous_visitor);
if (mt->backedges) {
jl_value_t **backedges = (jl_value_t**)jl_array_data(mt->backedges);
size_t i, na = jl_array_len(mt->backedges);
Expand Down
126 changes: 126 additions & 0 deletions test/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -771,3 +771,129 @@ cinfo = cinfos[]
test_similar_codeinfo(cinfo, cinfo_generated)

@test_throws ErrorException code_lowered(f22979, typeof.(x22979), false)

module MethodDeletion
using Test

# Deletion after compiling top-level call
bar1(x) = 1
bar1(x::Int) = 2
foo1(x) = bar1(x)
faz1(x) = foo1(x)
@test faz1(1) == 2
@test faz1(1.0) == 1
m = first(methods(bar1, Tuple{Int}))
Base.delete_method(m)
@test bar1(1) == 1
@test bar1(1.0) == 1
@test foo1(1) == 1
@test foo1(1.0) == 1
@test faz1(1) == 1
@test faz1(1.0) == 1

# Deletion after compiling middle-level call
bar2(x) = 1
bar2(x::Int) = 2
foo2(x) = bar2(x)
faz2(x) = foo2(x)
@test foo2(1) == 2
@test foo2(1.0) == 1
m = first(methods(bar2, Tuple{Int}))
Base.delete_method(m)
@test bar2(1.0) == 1
@test bar2(1) == 1
@test foo2(1) == 1
@test foo2(1.0) == 1
@test faz2(1) == 1
@test faz2(1.0) == 1

# Deletion after compiling low-level call
bar3(x) = 1
bar3(x::Int) = 2
foo3(x) = bar3(x)
faz3(x) = foo3(x)
@test bar3(1) == 2
@test bar3(1.0) == 1
m = first(methods(bar3, Tuple{Int}))
Base.delete_method(m)
@test bar3(1) == 1
@test bar3(1.0) == 1
@test foo3(1) == 1
@test foo3(1.0) == 1
@test faz3(1) == 1
@test faz3(1.0) == 1

# Deletion before any compilation
bar4(x) = 1
bar4(x::Int) = 2
foo4(x) = bar4(x)
faz4(x) = foo4(x)
m = first(methods(bar4, Tuple{Int}))
Base.delete_method(m)
@test bar4(1) == 1
@test bar4(1.0) == 1
@test foo4(1) == 1
@test foo4(1.0) == 1
@test faz4(1) == 1
@test faz4(1.0) == 1

# Methods with keyword arguments
fookw(x; direction=:up) = direction
fookw(y::Int) = 2
@test fookw("string") == :up
@test fookw(1) == 2
m = collect(methods(fookw))[2]
Base.delete_method(m)
@test fookw(1) == 2
@test_throws MethodError fookw("string")

# functions with many methods
types = (Float64, Int32, String)
for T1 in types, T2 in types, T3 in types
@eval foomany(x::$T1, y::$T2, z::$T3) = y
end
@test foomany(Int32(5), "hello", 3.2) == "hello"
m = first(methods(foomany, Tuple{Int32, String, Float64}))
Base.delete_method(m)
@test_throws MethodError foomany(Int32(5), "hello", 3.2)

struct EmptyType end
Base.convert(::Type{EmptyType}, x::Integer) = EmptyType()
m = first(methods(convert, Tuple{Type{EmptyType}, Integer}))
Base.delete_method(m)
@test_throws MethodError convert(EmptyType, 1)

# parametric methods
parametric(A::Array{T,N}, i::Vararg{Int,N}) where {T,N} = N
@test parametric(rand(2,2), 1, 1) == 2
m = first(methods(parametric))
Base.delete_method(m)
@test_throws MethodError parametric(rand(2,2), 1, 1)

# Deletion and ambiguity detection
foo(::Int, ::Int) = 1
foo(::Real, ::Int) = 2
foo(::Int, ::Real) = 3
@test all(map(g->g.ambig==nothing, methods(foo)))
Base.delete_method(first(methods(foo)))
@test !all(map(g->g.ambig==nothing, methods(foo)))
@test_throws MethodError foo(1, 1)
foo(::Int, ::Int) = 1
foo(1, 1)
@test map(g->g.ambig==nothing, methods(foo)) == [true, false, false]
Base.delete_method(first(methods(foo)))
@test_throws MethodError foo(1, 1)
@test map(g->g.ambig==nothing, methods(foo)) == [false, false]

# multiple deletions and ambiguities
typeparam(::Type{T}, a::Array{T}) where T<:AbstractFloat = 1
typeparam(::Type{T}, a::Array{T}) where T = 2
for mth in collect(methods(typeparam))
Base.delete_method(mth)
end
typeparam(::Type{T}, a::AbstractArray{T}) where T<:AbstractFloat = 1
typeparam(::Type{T}, a::AbstractArray{T}) where T = 2
@test typeparam(Float64, rand(2)) == 1
@test typeparam(Int, rand(Int, 2)) == 2

end

0 comments on commit 3357f95

Please sign in to comment.