From 8b26eb6de828e9581d2a8cc7f9ac2730365438d0 Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Fri, 10 Sep 2021 06:00:43 -0500 Subject: [PATCH] Allow constant-propagation to be disabled (#42125) Our heuristics for constant propagation are imperfect (and probably never will be perfect), and I've now seen many examples of methods that no developer would ask to have const-propped get that treatment. In some cases the cost for latency/precompilation is very large. This renames `@aggressive_constprop` to `@constprop` and allows two settings, `:aggressive` and `:none`. Closes #38983 Co-authored-by: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> Co-authored-by: Martin Holters --- base/char.jl | 24 +-- base/compiler/abstractinterpretation.jl | 8 +- base/expr.jl | 30 ++- src/ast.c | 3 +- src/dump.c | 4 +- src/ircode.c | 47 +++-- src/jltypes.c | 8 +- src/julia.h | 6 +- src/julia_internal.h | 17 +- src/method.c | 10 +- stdlib/Serialization/src/Serialization.jl | 12 +- test/compiler/inference.jl | 10 +- test/compiler/inline.jl | 230 ++++++++++++++++++++++ 13 files changed, 346 insertions(+), 63 deletions(-) diff --git a/base/char.jl b/base/char.jl index 0584471cb6a33..c8b1c28166bbf 100644 --- a/base/char.jl +++ b/base/char.jl @@ -45,10 +45,10 @@ represents a valid Unicode character. 
""" Char -@aggressive_constprop (::Type{T})(x::Number) where {T<:AbstractChar} = T(UInt32(x)) -@aggressive_constprop AbstractChar(x::Number) = Char(x) -@aggressive_constprop (::Type{T})(x::AbstractChar) where {T<:Union{Number,AbstractChar}} = T(codepoint(x)) -@aggressive_constprop (::Type{T})(x::AbstractChar) where {T<:Union{Int32,Int64}} = codepoint(x) % T +@constprop :aggressive (::Type{T})(x::Number) where {T<:AbstractChar} = T(UInt32(x)) +@constprop :aggressive AbstractChar(x::Number) = Char(x) +@constprop :aggressive (::Type{T})(x::AbstractChar) where {T<:Union{Number,AbstractChar}} = T(codepoint(x)) +@constprop :aggressive (::Type{T})(x::AbstractChar) where {T<:Union{Int32,Int64}} = codepoint(x) % T (::Type{T})(x::T) where {T<:AbstractChar} = x """ @@ -75,7 +75,7 @@ return a different-sized integer (e.g. `UInt8`). """ function codepoint end -@aggressive_constprop codepoint(c::Char) = UInt32(c) +@constprop :aggressive codepoint(c::Char) = UInt32(c) struct InvalidCharError{T<:AbstractChar} <: Exception char::T @@ -124,7 +124,7 @@ See also [`decode_overlong`](@ref) and [`show_invalid`](@ref). """ isoverlong(c::AbstractChar) = false -@aggressive_constprop function UInt32(c::Char) +@constprop :aggressive function UInt32(c::Char) # TODO: use optimized inline LLVM u = bitcast(UInt32, c) u < 0x80000000 && return u >> 24 @@ -148,7 +148,7 @@ that support overlong encodings should implement `Base.decode_overlong`. 
""" function decode_overlong end -@aggressive_constprop function decode_overlong(c::Char) +@constprop :aggressive function decode_overlong(c::Char) u = bitcast(UInt32, c) l1 = leading_ones(u) t0 = trailing_zeros(u) & 56 @@ -158,7 +158,7 @@ function decode_overlong end ((u & 0x007f0000) >> 4) | ((u & 0x7f000000) >> 6) end -@aggressive_constprop function Char(u::UInt32) +@constprop :aggressive function Char(u::UInt32) u < 0x80 && return bitcast(Char, u << 24) u < 0x00200000 || throw_code_point_err(u) c = ((u << 0) & 0x0000003f) | ((u << 2) & 0x00003f00) | @@ -169,14 +169,14 @@ end bitcast(Char, c) end -@aggressive_constprop @noinline UInt32_cold(c::Char) = UInt32(c) -@aggressive_constprop function (T::Union{Type{Int8},Type{UInt8}})(c::Char) +@constprop :aggressive @noinline UInt32_cold(c::Char) = UInt32(c) +@constprop :aggressive function (T::Union{Type{Int8},Type{UInt8}})(c::Char) i = bitcast(Int32, c) i ≥ 0 ? ((i >>> 24) % T) : T(UInt32_cold(c)) end -@aggressive_constprop @noinline Char_cold(b::UInt32) = Char(b) -@aggressive_constprop function Char(b::Union{Int8,UInt8}) +@constprop :aggressive @noinline Char_cold(b::UInt32) = Char(b) +@constprop :aggressive function Char(b::Union{Int8,UInt8}) 0 ≤ b ≤ 0x7f ? 
bitcast(Char, (b % UInt32) << 24) : Char_cold(UInt32(b)) end diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index a057a1879412c..41fe8f5034bcc 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -570,6 +570,12 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me sv::InferenceState) const_prop_entry_heuristic(interp, result, sv) || return nothing method = match.method + if method.constprop == 0x02 + add_remark!(interp, sv, "[constprop] Disabled by method parameter") + return nothing + end + force = force_const_prop(interp, f, method) + force || const_prop_entry_heuristic(interp, result, sv) || return nothing nargs::Int = method.nargs method.isva && (nargs -= 1) if length(argtypes) < nargs @@ -639,7 +645,7 @@ function is_allconst(argtypes::Vector{Any}) end function force_const_prop(interp::AbstractInterpreter, @nospecialize(f), method::Method) - return method.aggressive_constprop || + return method.constprop == 0x01 || InferenceParams(interp).aggressive_constant_propagation || istopfunction(f, :getproperty) || istopfunction(f, :setproperty!) diff --git a/base/expr.jl b/base/expr.jl index 2bc59717fea47..8382858249f9f 100644 --- a/base/expr.jl +++ b/base/expr.jl @@ -240,16 +240,26 @@ macro pure(ex) end """ - @aggressive_constprop ex - @aggressive_constprop(ex) - -`@aggressive_constprop` requests more aggressive interprocedural constant -propagation for the annotated function. For a method where the return type -depends on the value of the arguments, this can yield improved inference results -at the cost of additional compile time. -""" -macro aggressive_constprop(ex) - esc(isa(ex, Expr) ? pushmeta!(ex, :aggressive_constprop) : ex) + @constprop setting ex + @constprop(setting, ex) + +`@constprop` controls the mode of interprocedural constant propagation for the +annotated function. 
Two `setting`s are supported: + +- `@constprop :aggressive ex`: apply constant propagation aggressively. + For a method where the return type depends on the value of the arguments, + this can yield improved inference results at the cost of additional compile time. +- `@constprop :none ex`: disable constant propagation. This can reduce compile + times for functions that Julia might otherwise deem worthy of constant-propagation. + Common cases are for functions with `Bool`- or `Symbol`-valued arguments or keyword arguments. +""" +macro constprop(setting, ex) + if isa(setting, QuoteNode) + setting = setting.value + end + setting === :aggressive && return esc(isa(ex, Expr) ? pushmeta!(ex, :aggressive_constprop) : ex) + setting === :none && return esc(isa(ex, Expr) ? pushmeta!(ex, :no_constprop) : ex) + throw(ArgumentError("@constprop $setting not supported")) end """ diff --git a/src/ast.c b/src/ast.c index de2492db08c94..cd54c9fac712f 100644 --- a/src/ast.c +++ b/src/ast.c @@ -59,7 +59,7 @@ jl_sym_t *static_parameter_sym; jl_sym_t *inline_sym; jl_sym_t *noinline_sym; jl_sym_t *generated_sym; jl_sym_t *generated_only_sym; jl_sym_t *isdefined_sym; jl_sym_t *propagate_inbounds_sym; jl_sym_t *specialize_sym; -jl_sym_t *aggressive_constprop_sym; +jl_sym_t *aggressive_constprop_sym; jl_sym_t *no_constprop_sym; jl_sym_t *nospecialize_sym; jl_sym_t *macrocall_sym; jl_sym_t *colon_sym; jl_sym_t *hygienicscope_sym; jl_sym_t *throw_undef_if_not_sym; jl_sym_t *getfield_undefref_sym; @@ -399,6 +399,7 @@ void jl_init_common_symbols(void) polly_sym = jl_symbol("polly"); propagate_inbounds_sym = jl_symbol("propagate_inbounds"); aggressive_constprop_sym = jl_symbol("aggressive_constprop"); + no_constprop_sym = jl_symbol("no_constprop"); isdefined_sym = jl_symbol("isdefined"); nospecialize_sym = jl_symbol("nospecialize"); specialize_sym = jl_symbol("specialize"); diff --git a/src/dump.c b/src/dump.c index f7a0ced4a6ab6..b749248dfac68 100644 --- a/src/dump.c +++ b/src/dump.c @@ -670,7 
+670,7 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li write_int8(s->s, m->isva); write_int8(s->s, m->pure); write_int8(s->s, m->is_for_opaque_closure); - write_int8(s->s, m->aggressive_constprop); + write_int8(s->s, m->constprop); jl_serialize_value(s, (jl_value_t*)m->slot_syms); jl_serialize_value(s, (jl_value_t*)m->roots); jl_serialize_value(s, (jl_value_t*)m->ccallable); @@ -1524,7 +1524,7 @@ static jl_value_t *jl_deserialize_value_method(jl_serializer_state *s, jl_value_ m->isva = read_int8(s->s); m->pure = read_int8(s->s); m->is_for_opaque_closure = read_int8(s->s); - m->aggressive_constprop = read_int8(s->s); + m->constprop = read_int8(s->s); m->slot_syms = jl_deserialize_value(s, (jl_value_t**)&m->slot_syms); jl_gc_wb(m, m->slot_syms); m->roots = (jl_array_t*)jl_deserialize_value(s, (jl_value_t**)&m->roots); diff --git a/src/ircode.c b/src/ircode.c index 212febe121a75..414c80213d5a4 100644 --- a/src/ircode.c +++ b/src/ircode.c @@ -381,6 +381,17 @@ static void jl_encode_value_(jl_ircode_state *s, jl_value_t *v, int as_literal) } } +static jl_code_info_flags_t code_info_flags(uint8_t pure, uint8_t propagate_inbounds, uint8_t inlineable, uint8_t inferred, uint8_t constprop) +{ + jl_code_info_flags_t flags = { .packed = 0 }; // zero the bits not assigned below so the packed byte written by jl_compress_ir is deterministic + flags.bits.pure = pure; + flags.bits.propagate_inbounds = propagate_inbounds; + flags.bits.inlineable = inlineable; + flags.bits.inferred = inferred; + flags.bits.constprop = constprop; + return flags; +} + // --- decoding --- static jl_value_t *jl_decode_value(jl_ircode_state *s) JL_GC_DISABLED; @@ -702,12 +713,8 @@ JL_DLLEXPORT jl_array_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code) jl_current_task->ptls }; - uint8_t flags = (code->aggressive_constprop << 4) - | (code->inferred << 3) - | (code->inlineable << 2) - | (code->propagate_inbounds << 1) - | (code->pure << 0); - write_uint8(s.s, flags); + jl_code_info_flags_t flags = code_info_flags(code->pure, code->propagate_inbounds, code->inlineable, 
code->inferred, code->constprop); + write_uint8(s.s, flags.packed); size_t nslots = jl_array_len(code->slotflags); assert(nslots >= m->nargs && nslots < INT32_MAX); // required by generated functions @@ -787,12 +794,13 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t }; jl_code_info_t *code = jl_new_code_info_uninit(); - uint8_t flags = read_uint8(s.s); - code->aggressive_constprop = !!(flags & (1 << 4)); - code->inferred = !!(flags & (1 << 3)); - code->inlineable = !!(flags & (1 << 2)); - code->propagate_inbounds = !!(flags & (1 << 1)); - code->pure = !!(flags & (1 << 0)); + jl_code_info_flags_t flags; + flags.packed = read_uint8(s.s); + code->constprop = flags.bits.constprop; + code->inferred = flags.bits.inferred; + code->inlineable = flags.bits.inlineable; + code->propagate_inbounds = flags.bits.propagate_inbounds; + code->pure = flags.bits.pure; size_t nslots = read_int32(&src); code->slotflags = jl_alloc_array_1d(jl_array_uint8_type, nslots); @@ -847,8 +855,9 @@ JL_DLLEXPORT uint8_t jl_ir_flag_inferred(jl_array_t *data) if (jl_is_code_info(data)) return ((jl_code_info_t*)data)->inferred; assert(jl_typeis(data, jl_array_uint8_type)); - uint8_t flags = ((uint8_t*)data->data)[0]; - return !!(flags & (1 << 3)); + jl_code_info_flags_t flags; + flags.packed = ((uint8_t*)data->data)[0]; + return flags.bits.inferred; } JL_DLLEXPORT uint8_t jl_ir_flag_inlineable(jl_array_t *data) @@ -856,8 +865,9 @@ JL_DLLEXPORT uint8_t jl_ir_flag_inlineable(jl_array_t *data) if (jl_is_code_info(data)) return ((jl_code_info_t*)data)->inlineable; assert(jl_typeis(data, jl_array_uint8_type)); - uint8_t flags = ((uint8_t*)data->data)[0]; - return !!(flags & (1 << 2)); + jl_code_info_flags_t flags; + flags.packed = ((uint8_t*)data->data)[0]; + return flags.bits.inlineable; } JL_DLLEXPORT uint8_t jl_ir_flag_pure(jl_array_t *data) @@ -865,8 +875,9 @@ JL_DLLEXPORT uint8_t jl_ir_flag_pure(jl_array_t *data) if (jl_is_code_info(data)) return 
((jl_code_info_t*)data)->pure; assert(jl_typeis(data, jl_array_uint8_type)); - uint8_t flags = ((uint8_t*)data->data)[0]; - return !!(flags & (1 << 0)); + jl_code_info_flags_t flags; + flags.packed = ((uint8_t*)data->data)[0]; + return flags.bits.pure; } JL_DLLEXPORT jl_value_t *jl_compress_argnames(jl_array_t *syms) diff --git a/src/jltypes.c b/src/jltypes.c index 43171ee332e87..1330650ee8892 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -2348,7 +2348,7 @@ void jl_init_types(void) JL_GC_DISABLED "inlineable", "propagate_inbounds", "pure", - "aggressive_constprop"), + "constprop"), jl_svec(19, jl_array_any_type, jl_array_int32_type, @@ -2368,7 +2368,7 @@ void jl_init_types(void) JL_GC_DISABLED jl_bool_type, jl_bool_type, jl_bool_type, - jl_bool_type), + jl_uint8_type), jl_emptysvec, 0, 1, 19); @@ -2401,7 +2401,7 @@ void jl_init_types(void) JL_GC_DISABLED "isva", "pure", "is_for_opaque_closure", - "aggressive_constprop"), + "constprop"), jl_svec(26, jl_symbol_type, jl_module_type, @@ -2428,7 +2428,7 @@ void jl_init_types(void) JL_GC_DISABLED jl_bool_type, jl_bool_type, jl_bool_type, - jl_bool_type), + jl_uint8_type), jl_emptysvec, 0, 1, 10); diff --git a/src/julia.h b/src/julia.h index b2a8bd15bcb22..99e155b2b10eb 100644 --- a/src/julia.h +++ b/src/julia.h @@ -296,7 +296,8 @@ typedef struct _jl_code_info_t { uint8_t inlineable; uint8_t propagate_inbounds; uint8_t pure; - uint8_t aggressive_constprop; + // uint8 settings + uint8_t constprop; // 0 = use heuristic; 1 = aggressive; 2 = none } jl_code_info_t; // This type describes a single method definition, and stores data @@ -344,7 +345,8 @@ typedef struct _jl_method_t { uint8_t isva; uint8_t pure; uint8_t is_for_opaque_closure; - uint8_t aggressive_constprop; + // uint8 settings + uint8_t constprop; // 0x00 = use heuristic; 0x01 = aggressive; 0x02 = none // hidden fields: // lock for modifications to the method diff --git a/src/julia_internal.h b/src/julia_internal.h index 673d4459bef03..702f7586cf4c6 100644 --- 
a/src/julia_internal.h +++ b/src/julia_internal.h @@ -457,9 +457,24 @@ STATIC_INLINE jl_value_t *undefref_check(jl_datatype_t *dt, jl_value_t *v) JL_NO return v; } +// -- helper types -- // + +typedef struct { + uint8_t pure:1; + uint8_t propagate_inbounds:1; + uint8_t inlineable:1; + uint8_t inferred:1; + uint8_t constprop:2; // 0 = use heuristic; 1 = aggressive; 2 = none +} jl_code_info_flags_bitfield_t; + +typedef union { + jl_code_info_flags_bitfield_t bits; + uint8_t packed; +} jl_code_info_flags_t; // -- functions -- // +// jl_code_info_flag_t code_info_flags(uint8_t pure, uint8_t propagate_inbounds, uint8_t inlineable, uint8_t inferred, uint8_t constprop); jl_code_info_t *jl_type_infer(jl_method_instance_t *li, size_t world, int force); jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *meth JL_PROPAGATES_ROOT, size_t world); jl_code_instance_t *jl_generate_fptr(jl_method_instance_t *mi JL_PROPAGATES_ROOT, size_t world); @@ -1358,7 +1373,7 @@ extern jl_sym_t *static_parameter_sym; extern jl_sym_t *inline_sym; extern jl_sym_t *noinline_sym; extern jl_sym_t *generated_sym; extern jl_sym_t *generated_only_sym; extern jl_sym_t *isdefined_sym; extern jl_sym_t *propagate_inbounds_sym; extern jl_sym_t *specialize_sym; -extern jl_sym_t *aggressive_constprop_sym; +extern jl_sym_t *aggressive_constprop_sym; extern jl_sym_t *no_constprop_sym; extern jl_sym_t *nospecialize_sym; extern jl_sym_t *macrocall_sym; extern jl_sym_t *colon_sym; extern jl_sym_t *hygienicscope_sym; extern jl_sym_t *throw_undef_if_not_sym; extern jl_sym_t *getfield_undefref_sym; diff --git a/src/method.c b/src/method.c index 22145a4349853..dd15f67920849 100644 --- a/src/method.c +++ b/src/method.c @@ -288,7 +288,9 @@ static void jl_code_info_set_ir(jl_code_info_t *li, jl_expr_t *ir) else if (ma == (jl_value_t*)propagate_inbounds_sym) li->propagate_inbounds = 1; else if (ma == (jl_value_t*)aggressive_constprop_sym) - li->aggressive_constprop = 1; + li->constprop = 1; + else if (ma 
== (jl_value_t*)no_constprop_sym) + li->constprop = 2; else jl_array_ptr_set(meta, ins++, ma); } @@ -379,7 +381,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void) src->propagate_inbounds = 0; src->pure = 0; src->edges = jl_nothing; - src->aggressive_constprop = 0; + src->constprop = 0; return src; } @@ -566,7 +568,7 @@ static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src) } m->called = called; m->pure = src->pure; - m->aggressive_constprop = src->aggressive_constprop; + m->constprop = src->constprop; jl_add_function_name_to_lineinfo(src, (jl_value_t*)m->name); jl_array_t *copy = NULL; @@ -682,7 +684,7 @@ JL_DLLEXPORT jl_method_t *jl_new_method_uninit(jl_module_t *module) m->primary_world = 1; m->deleted_world = ~(size_t)0; m->is_for_opaque_closure = 0; - m->aggressive_constprop = 0; + m->constprop = 0; JL_MUTEX_INIT(&m->writelock); return m; } diff --git a/stdlib/Serialization/src/Serialization.jl b/stdlib/Serialization/src/Serialization.jl index 592db96565c7a..110abcff18601 100644 --- a/stdlib/Serialization/src/Serialization.jl +++ b/stdlib/Serialization/src/Serialization.jl @@ -79,7 +79,7 @@ const TAGS = Any[ @assert length(TAGS) == 255 -const ser_version = 15 # do not make changes without bumping the version #! +const ser_version = 16 # do not make changes without bumping the version #! 
format_version(::AbstractSerializer) = ser_version format_version(s::Serializer) = s.version @@ -418,7 +418,7 @@ function serialize(s::AbstractSerializer, meth::Method) serialize(s, meth.nargs) serialize(s, meth.isva) serialize(s, meth.is_for_opaque_closure) - serialize(s, meth.aggressive_constprop) + serialize(s, meth.constprop) if isdefined(meth, :source) serialize(s, Base._uncompressed_ast(meth, meth.source)) else @@ -1014,12 +1014,12 @@ function deserialize(s::AbstractSerializer, ::Type{Method}) nargs = deserialize(s)::Int32 isva = deserialize(s)::Bool is_for_opaque_closure = false - aggressive_constprop = false + constprop = 0x00 template_or_is_opaque = deserialize(s) if isa(template_or_is_opaque, Bool) is_for_opaque_closure = template_or_is_opaque if format_version(s) >= 14 - aggressive_constprop = deserialize(s)::Bool + constprop = UInt8(deserialize(s)::Union{Bool,UInt8}) # format versions 14-15 stored this flag as a Bool end template = deserialize(s) else @@ -1039,7 +1039,7 @@ function deserialize(s::AbstractSerializer, ::Type{Method}) meth.nargs = nargs meth.isva = isva meth.is_for_opaque_closure = is_for_opaque_closure - meth.aggressive_constprop = aggressive_constprop + meth.constprop = constprop if template !== nothing # TODO: compress template meth.source = template::CodeInfo @@ -1163,7 +1163,7 @@ function deserialize(s::AbstractSerializer, ::Type{CodeInfo}) ci.propagate_inbounds = deserialize(s) ci.pure = deserialize(s) if format_version(s) >= 14 - ci.aggressive_constprop = deserialize(s)::Bool + ci.constprop = UInt8(deserialize(s)::Union{Bool,UInt8}) # format versions 14-15 stored this flag as a Bool end return ci end diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index d4d0f6700c179..03292aaea89d8 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -3169,9 +3169,9 @@ g38888() = S38888(Base.inferencebarrier(3), nothing) f_inf_error_bottom(x::Vector) = isempty(x) ? 
error(x[1]) : x @test Core.Compiler.return_type(f_inf_error_bottom, Tuple{Vector{Any}}) == Vector{Any} -# @aggressive_constprop +# @constprop :aggressive @noinline g_nonaggressive(y, x) = Val{x}() -@noinline @Base.aggressive_constprop g_aggressive(y, x) = Val{x}() +@noinline Base.@constprop :aggressive g_aggressive(y, x) = Val{x}() f_nonaggressive(x) = g_nonaggressive(x, 1) f_aggressive(x) = g_aggressive(x, 1) @@ -3181,6 +3181,12 @@ f_aggressive(x) = g_aggressive(x, 1) @test Base.return_types(f_nonaggressive, Tuple{Int})[1] == Val @test Base.return_types(f_aggressive, Tuple{Int})[1] == Val{1} +# @constprop :none +@noinline Base.@constprop :none g_noaggressive(flag::Bool) = flag ? 1 : 1.0 +ftrue_noaggressive() = g_noaggressive(true) +@test only(Base.return_types(ftrue_noaggressive, Tuple{})) == Union{Int,Float64} + + function splat_lotta_unions() a = Union{Tuple{Int},Tuple{String,Vararg{Int}},Tuple{Int,Vararg{Int}}}[(2,)][1] b = Union{Int8,Int16,Int32,Int64,Int128}[1][1] diff --git a/test/compiler/inline.jl b/test/compiler/inline.jl index eb1c3ddd2a963..9218a484a3257 100644 --- a/test/compiler/inline.jl +++ b/test/compiler/inline.jl @@ -381,6 +381,236 @@ using Base.Experimental: @opaque f_oc_getfield(x) = (@opaque ()->x)() @test fully_eliminated(f_oc_getfield, Tuple{Int}) +# check if `x` is a statically-resolved call of a function whose name is `sym` +isinvoke(@nospecialize(x), sym::Symbol) = isinvoke(x, mi->mi.def.name===sym) +function isinvoke(@nospecialize(x), pred) + if Meta.isexpr(x, :invoke) + return pred(x.args[1]::Core.MethodInstance) + end + return false +end +code_typed1(args...; kwargs...) 
= (first(only(code_typed(args...; kwargs...)))::Core.CodeInfo).code + +@testset "@inline/@noinline annotation before definition" begin + M = Module() + @eval M begin + @inline function _def_inline(x) + # this call won't be resolved and thus will prevent inlining to happen if we don't + # annotate `@inline` at the top of this function body + return unresolved_call(x) + end + def_inline(x) = _def_inline(x) + @noinline _def_noinline(x) = x # obviously will be inlined otherwise + def_noinline(x) = _def_noinline(x) + + # test that they don't conflict with other "before-definition" macros + @inline Base.@constprop :aggressive function _def_inline_noconflict(x) + # this call won't be resolved and thus will prevent inlining to happen if we don't + # annotate `@inline` at the top of this function body + return unresolved_call(x) + end + def_inline_noconflict(x) = _def_inline_noconflict(x) + @noinline Base.@constprop :aggressive _def_noinline_noconflict(x) = x # obviously will be inlined otherwise + def_noinline_noconflict(x) = _def_noinline_noconflict(x) + end + + let code = code_typed1(M.def_inline, (Int,)) + @test all(code) do x + !isinvoke(x, :_def_inline) + end + end + let code = code_typed1(M.def_noinline, (Int,)) + @test any(code) do x + isinvoke(x, :_def_noinline) + end + end + # test that they don't conflict with other "before-definition" macros + let code = code_typed1(M.def_inline_noconflict, (Int,)) + @test all(code) do x + !isinvoke(x, :_def_inline_noconflict) + end + end + let code = code_typed1(M.def_noinline_noconflict, (Int,)) + @test any(code) do x + isinvoke(x, :_def_noinline_noconflict) + end + end +end + +@testset "@inline/@noinline annotation within a function body" begin + M = Module() + @eval M begin + function _body_inline(x) + @inline + # this call won't be resolved and thus will prevent inlining to happen if we don't + # annotate `@inline` at the top of this function body + return unresolved_call(x) + end + body_inline(x) = _body_inline(x) + 
function _body_noinline(x) + @noinline + return x # obviously will be inlined otherwise + end + body_noinline(x) = _body_noinline(x) + + # test annotations for `do` blocks + @inline simple_caller(a) = a() + function do_inline(x) + simple_caller() do + @inline + # this call won't be resolved and thus will prevent inlining to happen if we don't + # annotate `@inline` at the top of this anonymous function body + return unresolved_call(x) + end + end + function do_noinline(x) + simple_caller() do + @noinline + return x # obviously will be inlined otherwise + end + end + end + + let code = code_typed1(M.body_inline, (Int,)) + @test all(code) do x + !isinvoke(x, :_body_inline) + end + end + let code = code_typed1(M.body_noinline, (Int,)) + @test any(code) do x + isinvoke(x, :_body_noinline) + end + end + # test annotations for `do` blocks + let code = code_typed1(M.do_inline, (Int,)) + # what we test here is that both `simple_caller` and the anonymous function that the + # `do` block creates should inlined away, and as a result there is only the unresolved call + @test all(code) do x + !isinvoke(x, :simple_caller) && + !isinvoke(x, mi->startswith(string(mi.def.name), '#')) + end + end + let code = code_typed1(M.do_noinline, (Int,)) + # the anonymous function that the `do` block created shouldn't be inlined here + @test any(code) do x + isinvoke(x, mi->startswith(string(mi.def.name), '#')) + end + end +end + +@testset "callsite @inline/@noinline annotations" begin + M = Module() + @eval M begin + # this global variable prevents inference to fold everything as constant, and/or the optimizer to inline the call accessing to this + g = 0 + + @noinline noinlined_explicit(x) = x + force_inline_explicit(x) = @inline noinlined_explicit(x) + force_inline_block_explicit(x) = @inline noinlined_explicit(x) + noinlined_explicit(x) + noinlined_implicit(x) = g + force_inline_implicit(x) = @inline noinlined_implicit(x) + force_inline_block_implicit(x) = @inline noinlined_implicit(x) + 
noinlined_implicit(x) + + @inline inlined_explicit(x) = x + force_noinline_explicit(x) = @noinline inlined_explicit(x) + force_noinline_block_explicit(x) = @noinline inlined_explicit(x) + inlined_explicit(x) + inlined_implicit(x) = x + force_noinline_implicit(x) = @noinline inlined_implicit(x) + force_noinline_block_implicit(x) = @noinline inlined_implicit(x) + inlined_implicit(x) + + # test callsite annotations for constant-prop'ed calls + + @noinline Base.@constprop :aggressive noinlined_constprop_explicit(a) = a+g + force_inline_constprop_explicit() = @inline noinlined_constprop_explicit(0) + Base.@constprop :aggressive noinlined_constprop_implicit(a) = a+g + force_inline_constprop_implicit() = @inline noinlined_constprop_implicit(0) + + @inline Base.@constprop :aggressive inlined_constprop_explicit(a) = a+g + force_noinline_constprop_explicit() = @noinline inlined_constprop_explicit(0) + @inline Base.@constprop :aggressive inlined_constprop_implicit(a) = a+g + force_noinline_constprop_implicit() = @noinline inlined_constprop_implicit(0) + + @noinline notinlined(a) = a + function nested(a0, b0) + @noinline begin + a = @inline notinlined(a0) # this call should be inlined + b = notinlined(b0) # this call should NOT be inlined + return a, b + end + end + end + + let code = code_typed1(M.force_inline_explicit, (Int,)) + @test all(x->!isinvoke(x, :noinlined_explicit), code) + end + let code = code_typed1(M.force_inline_block_explicit, (Int,)) + @test all(code) do x + !isinvoke(x, :noinlined_explicit) && + !isinvoke(x, :(+)) + end + end + let code = code_typed1(M.force_inline_implicit, (Int,)) + @test all(x->!isinvoke(x, :noinlined_implicit), code) + end + let code = code_typed1(M.force_inline_block_implicit, (Int,)) + @test all(x->!isinvoke(x, :noinlined_implicit), code) + end + + let code = code_typed1(M.force_noinline_explicit, (Int,)) + @test any(x->isinvoke(x, :inlined_explicit), code) + end + let code = code_typed1(M.force_noinline_block_explicit, (Int,)) + 
@test count(x->isinvoke(x, :inlined_explicit), code) == 2 + end + let code = code_typed1(M.force_noinline_implicit, (Int,)) + @test any(x->isinvoke(x, :inlined_implicit), code) + end + let code = code_typed1(M.force_noinline_block_implicit, (Int,)) + @test count(x->isinvoke(x, :inlined_implicit), code) == 2 + end + + let code = code_typed1(M.force_inline_constprop_explicit) + @test all(x->!isinvoke(x, :noinlined_constprop_explicit), code) + end + let code = code_typed1(M.force_inline_constprop_implicit) + @test all(x->!isinvoke(x, :noinlined_constprop_implicit), code) + end + + let code = code_typed1(M.force_noinline_constprop_explicit) + @test any(x->isinvoke(x, :inlined_constprop_explicit), code) + end + let code = code_typed1(M.force_noinline_constprop_implicit) + @test any(x->isinvoke(x, :inlined_constprop_implicit), code) + end + + let code = code_typed1(M.nested, (Int,Int)) + @test count(x->isinvoke(x, :notinlined), code) == 1 + end +end + +# force constant-prop' for `setproperty!` +# https://github.com/JuliaLang/julia/pull/41882 +let code = @eval Module() begin + # if we don't force constant-prop', `T = fieldtype(Foo, ::Symbol)` will be union-split to + # `Union{Type{Any},Type{Int}}` and it will make `convert(T, nothing)` too costly + # and it leads to inlining failure + mutable struct Foo + val + _::Int + end + + function setter(xs) + for x in xs + x.val = nothing + end + end + + $code_typed1(setter, (Vector{Foo},)) + end + + @test !any(x->isinvoke(x, :setproperty!), code) +end + # Issue #41299 - inlining deletes error check in :> g41299(f::Tf, args::Vararg{Any,N}) where {Tf,N} = f(args...) @test_throws TypeError g41299(>:, 1, 2)