diff --git a/base/reflection.jl b/base/reflection.jl index 63165a62226db1..d40e4ac0dabed7 100644 --- a/base/reflection.jl +++ b/base/reflection.jl @@ -1083,6 +1083,22 @@ function isambiguous(m1::Method, m2::Method; ambiguous_bottom::Bool=false) return true end +""" + delete_method(m::Method) + +Make method `m` uncallable and force recompilation of any methods that use(d) it. +""" +function delete_method(m::Method) + ccall(:jl_method_table_disable, Void, (Any, Any), MethodTable(m), m) +end + +MethodTable(m::Method) = get_methodtable(m.sig) + +get_methodtable(u::UnionAll) = get_methodtable(u.body) +get_methodtable(sig) = _get_methodtable(sig.parameters[1]) +_get_methodtable(u::UnionAll) = _get_methodtable(u.body) +_get_methodtable(f) = f.name.mt + """ has_bottom_parameter(t) -> Bool diff --git a/src/gf.c b/src/gf.c index 5c1fcfef0525ea..fcab86e85ff9df 100644 --- a/src/gf.c +++ b/src/gf.c @@ -1131,6 +1131,8 @@ static int check_ambiguous_visitor(jl_typemap_entry_t *oldentry, struct typemap_ closure->after = 1; return 1; } + if (oldentry->max_world < ~(size_t)0) + return 1; union jl_typemap_t map = closure->defs; jl_tupletype_t *type = (jl_tupletype_t*)closure->match.type; jl_method_t *m = closure->newentry->func.method; @@ -1212,7 +1214,7 @@ static int check_ambiguous_visitor(jl_typemap_entry_t *oldentry, struct typemap_ return 1; } -static jl_value_t *check_ambiguous_matches(union jl_typemap_t defs, jl_typemap_entry_t *newentry) +static jl_value_t *check_ambiguous_matches(union jl_typemap_t defs, jl_typemap_entry_t *newentry, jl_typemap_intersection_visitor_fptr fptr) { jl_tupletype_t *type = newentry->sig; jl_tupletype_t *ttypes = (jl_tupletype_t*)jl_unwrap_unionall((jl_value_t*)type); @@ -1226,7 +1228,7 @@ static jl_value_t *check_ambiguous_matches(union jl_typemap_t defs, jl_typemap_e va = NULL; } struct ambiguous_matches_env env; - env.match.fptr = check_ambiguous_visitor; + env.match.fptr = fptr; env.match.type = (jl_value_t*)type; env.match.va = va; env.match.ti = NULL; @@ -1241,6 +1243,47 @@ static jl_value_t *check_ambiguous_matches(union jl_typemap_t defs, jl_typemap_e return env.shadowed; } +static int check_disabled_ambiguous_visitor(jl_typemap_entry_t *oldentry, struct typemap_intersection_env *closure0) +{ + struct ambiguous_matches_env *closure = container_of(closure0, struct ambiguous_matches_env, match); + if (oldentry == closure->newentry) { + closure->after = 1; + return 1; + } + if (!closure->after || oldentry->max_world < ~(size_t)0) // the second condition prevents us from confusion in multiple cycles of add/delete + return 1; + jl_tupletype_t *sig = oldentry->sig; + jl_value_t *isect = closure->match.ti; + if (closure->shadowed == NULL) + closure->shadowed = (jl_value_t*)jl_alloc_vec_any(0); + + int i, l = jl_array_len(closure->shadowed); + for (i = 0; i < l; i++) { + jl_method_t *mth = (jl_method_t*)jl_array_ptr_ref(closure->shadowed, i); + jl_value_t *isect2 = jl_type_intersection(mth->sig, (jl_value_t*)sig); + // see if the intersection was covered by precisely the disabled method + // that means we now need to record the ambiguity + if (jl_types_equal(isect, isect2)) { + jl_method_t *mambig = mth; + jl_method_t *m = oldentry->func.method; + if (m->ambig == jl_nothing) { + m->ambig = (jl_value_t*) jl_alloc_vec_any(0); + jl_gc_wb(m, m->ambig); + } + if (mambig->ambig == jl_nothing) { + mambig->ambig = (jl_value_t*) jl_alloc_vec_any(0); + jl_gc_wb(mambig, mambig->ambig); + } + jl_array_ptr_1d_push((jl_array_t*) m->ambig, (jl_value_t*) mambig); + jl_array_ptr_1d_push((jl_array_t*) mambig->ambig, (jl_value_t*) m); + } + } + + jl_array_ptr_1d_push((jl_array_t*)closure->shadowed, oldentry->func.value); + return 1; +} + + static void method_overwrite(jl_typemap_entry_t *newentry, jl_method_t *oldvalue) { // method overwritten @@ -1405,6 +1448,34 @@ void jl_method_instance_delete(jl_method_instance_t *mi) jl_uv_puts(JL_STDOUT, "<<<\n", 4); } +static int typemap_search(jl_typemap_entry_t *entry, void *closure) +{ + if ((void*)(entry->func.method) == *(jl_method_t**)closure) { + *(jl_typemap_entry_t**)closure = entry; + return 0; + } + return 1; +} + +//JL_DLLEXPORT void jl_method_table_disable(jl_methtable_t *mt, jl_typemap_entry_t *methodentry) +JL_DLLEXPORT void jl_method_table_disable(jl_methtable_t *mt, jl_method_t *method) +{ + struct invalidate_conflicting_env env; + env.invalidated = 0; + jl_typemap_entry_t *methodentry = (jl_typemap_entry_t*)(method); + if (jl_typemap_visitor(mt->defs, typemap_search, &methodentry)) + jl_error("method not in method table"); + JL_LOCK(&mt->writelock); + // Narrow the world age on the method to make it uncallable + methodentry->max_world = jl_world_counter++; + // Recompute ambiguities (deleting a more specific method might reveal ambiguities that it previously resolved) + env.max_world = methodentry->max_world; + check_ambiguous_matches(mt->defs, methodentry, check_disabled_ambiguous_visitor); // TODO: decrease repeated work? + // Invalidate the backedges + jl_typemap_visitor(methodentry->func.method->specializations, (jl_typemap_visitor_fptr)invalidate_backedges, &env); + JL_UNLOCK(&mt->writelock); +} + JL_DLLEXPORT void jl_method_table_insert(jl_methtable_t *mt, jl_method_t *method, jl_tupletype_t *simpletype) { assert(jl_is_method(method)); @@ -1430,7 +1501,7 @@ JL_DLLEXPORT void jl_method_table_insert(jl_methtable_t *mt, jl_method_t *method method_overwrite(newentry, (jl_method_t*)oldvalue); } else { - oldvalue = check_ambiguous_matches(mt->defs, newentry); + oldvalue = check_ambiguous_matches(mt->defs, newentry, check_ambiguous_visitor); if (mt->backedges) { jl_value_t **backedges = (jl_value_t**)jl_array_data(mt->backedges); size_t i, na = jl_array_len(mt->backedges); diff --git a/test/reflection.jl b/test/reflection.jl index 0af98d038865d0..74d1539cd6074c 100644 --- a/test/reflection.jl +++ b/test/reflection.jl @@ -771,3 +771,129 @@ cinfo = cinfos[] test_similar_codeinfo(cinfo, cinfo_generated) @test_throws ErrorException code_lowered(f22979, typeof.(x22979), false) + +module MethodDeletion +using Test + +# Deletion after compiling top-level call +bar1(x) = 1 +bar1(x::Int) = 2 +foo1(x) = bar1(x) +faz1(x) = foo1(x) +@test faz1(1) == 2 +@test faz1(1.0) == 1 +m = first(methods(bar1, Tuple{Int})) +Base.delete_method(m) +@test bar1(1) == 1 +@test bar1(1.0) == 1 +@test foo1(1) == 1 +@test foo1(1.0) == 1 +@test faz1(1) == 1 +@test faz1(1.0) == 1 + +# Deletion after compiling middle-level call +bar2(x) = 1 +bar2(x::Int) = 2 +foo2(x) = bar2(x) +faz2(x) = foo2(x) +@test foo2(1) == 2 +@test foo2(1.0) == 1 +m = first(methods(bar2, Tuple{Int})) +Base.delete_method(m) +@test bar2(1.0) == 1 +@test bar2(1) == 1 +@test foo2(1) == 1 +@test foo2(1.0) == 1 +@test faz2(1) == 1 +@test faz2(1.0) == 1 + +# Deletion after compiling low-level call +bar3(x) = 1 +bar3(x::Int) = 2 +foo3(x) = bar3(x) +faz3(x) = foo3(x) +@test bar3(1) == 2 +@test bar3(1.0) == 1 +m = first(methods(bar3, Tuple{Int})) +Base.delete_method(m) +@test bar3(1) == 1 +@test bar3(1.0) == 1 +@test foo3(1) == 1 +@test foo3(1.0) == 1 +@test faz3(1) == 1 +@test faz3(1.0) == 1 + +# Deletion before any compilation +bar4(x) = 1 +bar4(x::Int) = 2 +foo4(x) = bar4(x) +faz4(x) = foo4(x) +m = first(methods(bar4, Tuple{Int})) +Base.delete_method(m) +@test bar4(1) == 1 +@test bar4(1.0) == 1 +@test foo4(1) == 1 +@test foo4(1.0) == 1 +@test faz4(1) == 1 +@test faz4(1.0) == 1 + +# Methods with keyword arguments +fookw(x; direction=:up) = direction +fookw(y::Int) = 2 +@test fookw("string") == :up +@test fookw(1) == 2 +m = collect(methods(fookw))[2] +Base.delete_method(m) +@test fookw(1) == 2 +@test_throws MethodError fookw("string") + +# functions with many methods +types = (Float64, Int32, String) +for T1 in types, T2 in types, T3 in types + @eval foomany(x::$T1, y::$T2, z::$T3) = y +end +@test foomany(Int32(5), "hello", 3.2) == "hello" +m = first(methods(foomany, Tuple{Int32, String, Float64})) +Base.delete_method(m) +@test_throws MethodError foomany(Int32(5), "hello", 3.2) + +struct EmptyType end +Base.convert(::Type{EmptyType}, x::Integer) = EmptyType() +m = first(methods(convert, Tuple{Type{EmptyType}, Integer})) +Base.delete_method(m) +@test_throws MethodError convert(EmptyType, 1) + +# parametric methods +parametric(A::Array{T,N}, i::Vararg{Int,N}) where {T,N} = N +@test parametric(rand(2,2), 1, 1) == 2 +m = first(methods(parametric)) +Base.delete_method(m) +@test_throws MethodError parametric(rand(2,2), 1, 1) + +# Deletion and ambiguity detection +foo(::Int, ::Int) = 1 +foo(::Real, ::Int) = 2 +foo(::Int, ::Real) = 3 +@test all(map(g->g.ambig==nothing, methods(foo))) +Base.delete_method(first(methods(foo))) +@test !all(map(g->g.ambig==nothing, methods(foo))) +@test_throws MethodError foo(1, 1) +foo(::Int, ::Int) = 1 +foo(1, 1) +@test map(g->g.ambig==nothing, methods(foo)) == [true, false, false] +Base.delete_method(first(methods(foo))) +@test_throws MethodError foo(1, 1) +@test map(g->g.ambig==nothing, methods(foo)) == [false, false] + +# multiple deletions and ambiguities +typeparam(::Type{T}, a::Array{T}) where T<:AbstractFloat = 1 +typeparam(::Type{T}, a::Array{T}) where T = 2 +for mth in collect(methods(typeparam)) + Base.delete_method(mth) +end +typeparam(::Type{T}, a::AbstractArray{T}) where T<:AbstractFloat = 1 +typeparam(::Type{T}, a::AbstractArray{T}) where T = 2 +@test typeparam(Float64, rand(2)) == 1 +@test typeparam(Int, rand(Int, 2)) == 2 + +end