Skip to content

Commit

Permalink
introduce @noinfer macro to tell the compiler to avoid excess infer…
Browse files Browse the repository at this point in the history
…ence

This commit introduces new compiler annotation named `@noinfer`, which
requests the compiler to avoid excess inference.

In order to discuss `@noinfer`, it would help a lot to understand the
behavior of `@nospecialize`.

Its docstring says simply:
> This is only a hint for the compiler to avoid excess code generation.

More specifically, it works by _suppressing dispatches_ with complex
runtime types of the annotated arguments. This could be understood with
the example below:
```julia
julia> function invokef(f, itr)
           local r = 0
           r += f(itr[1])
           r += f(itr[2])
           r += f(itr[3])
           r
       end;

julia> _isa = isa; # just for the sake of explanation, global variable to prevent inling
julia> f(a) = _isa(a, Function);
julia> g(@nospecialize a) = _isa(a, Function);
julia> dispatchonly = Any[sin, muladd, nothing]; # untyped container can cause excessive runtime dispatch

julia> @code_typed invokef(f, dispatchonly)
CodeInfo(
1 ─ %1  = π (0, Int64)
│   %2  = Base.arrayref(true, itr, 1)::Any
│   %3  = (f)(%2)::Any
│   %4  = (%1 + %3)::Any
│   %5  = Base.arrayref(true, itr, 2)::Any
│   %6  = (f)(%5)::Any
│   %7  = (%4 + %6)::Any
│   %8  = Base.arrayref(true, itr, 3)::Any
│   %9  = (f)(%8)::Any
│   %10 = (%7 + %9)::Any
└──       return %10
) => Any

julia> @code_typed invokef(g, dispatchonly)
CodeInfo(
1 ─ %1  = π (0, Int64)
│   %2  = Base.arrayref(true, itr, 1)::Any
│   %3  = invoke f(%2::Any)::Any
│   %4  = (%1 + %3)::Any
│   %5  = Base.arrayref(true, itr, 2)::Any
│   %6  = invoke f(%5::Any)::Any
│   %7  = (%4 + %6)::Any
│   %8  = Base.arrayref(true, itr, 3)::Any
│   %9  = invoke f(%8::Any)::Any
│   %10 = (%7 + %9)::Any
└──       return %10
) => Any
```

The calls of `f` remain to be `:call` expression (thus dispatched and
compiled at runtime) while the calls of `g` are resolved as `:invoke`
expressions. This is because `@nospecialize` requests the compiler to
give up compiling `g` with concrete argument types but with precisely
declared argument types, and in this way `invokef(g, dispatchonly)` will
avoid runtime dispatches and accompanying JIT compilations (i.e. "excess
code generation").

The problem here is, it influences dispatch only, does not intervene
into inference in anyway. So there is still a possibility of "excess
inference" when the compiler sees a considerable complexity of argument
types  during inference:
```julia
julia> withinfernce = tuple(sin, muladd, "foo"); # typed container can cause excessive inference

julia> @time @code_typed invokef(f, withinfernce);
  0.000812 seconds (3.77 k allocations: 217.938 KiB, 94.34% compilation time)

julia> @time @code_typed invokef(g, withinfernce);
  0.000753 seconds (3.77 k allocations: 218.047 KiB, 92.42% compilation time)
```

The purpose of this PR is basically to provide a more drastic way to
avoid excess compilation.

Here are some ideas to implement the functionality:
1. make `@nospecialize` avoid inference also
2. add noinfer effect when `@nospecialize`d method is annotated as `@noinline` also
3. implement as `@pure`-like boolean annotation to request noinfer effect on top of `@nospecialize`
4. implement as annotation that is orthogonal to `@nospecialize`

After trying 1 ~ 3., I decided to submit 3. for now, because I think the
interface is ready to be experimented.

This is almost same as what Jameson has done at <vtjnash@8ab7b6b>.
It turned out that this approach performs very badly because some of
`@nospecialize`'d arguments still need inference to perform reasonably.
For example, it's obvious that the following definition of
`getindex(@nospecialize(t::Tuple), i::Int)` would perform very badly if
`@nospecialize` blocks inference, because of a lack of useful type
information for succeeding optimizations:
<https://github.com/JuliaLang/julia/blob/12d364e8249a07097a233ce7ea2886002459cc50/base/tuple.jl#L29-L30>

The important observation is that we often use `@nospecialize` even when
we expect inference to forward type and constant information.
Adversely, we may be able to exploit the fact that we usually don't
expect inference to forward information to a callee when we annotate it
as `@noinline`.
So the idea is to enable the inference suppression when `@nospecialize`'d
method is annotated as `@noinline` also.

It's a reasonable choice, and could be implemented efficiently after <#41922>.
But it sounds a bit weird to me to associate no infer effect with
`@noinline`, and I also think there may be some cases we want to inline
a method while _partially_ avoiding inference, e.g.:
```julia
@noinline function twof(@nospecialize(f), n) # we really want not to
inline this method body ?
    if occursin('+', string(typeof(f).name.name::Symbol))
        2 + n
    elseif occursin('*', string(typeof(f).name.name::Symbol))
        2n
    else
        zero(n)
    end
end
```

So this is what this commit implements. It basically replaces the previous
`@noinline` flag with newly-introduced annotation named `@noinfer`. It's
still associated with `@nospecialize` and it only has effect when used
together with `@nospecialize`, but now it's not associated to `@noinline`
at least, and it would help us reason about the behavior of `@noinfer`
and experiment its effect more reliably:
```julia
Base.@noinfer function twof(@nospecialize(f), n) # the compiler may or not inline this method
    if occursin('+', string(typeof(f).name.name::Symbol))
        2 + n
    elseif occursin('*', string(typeof(f).name.name::Symbol))
        2n
    else
        zero(n)
    end
end
```

Actually, we can have `@nospecialize` and `@noinfer` separately, and it
would allow us to configure compilation strategies in a more
fine-grained way.
```julia
function noinfspec(Base.@noinfer(f), @nospecialize(g))
    ...
end
```

I'm fine with this approach, if initial experiments show `@noinfer` is
useful.
  • Loading branch information
aviatesk committed Sep 20, 2021
1 parent 1843201 commit 02c46d7
Show file tree
Hide file tree
Showing 15 changed files with 250 additions and 32 deletions.
9 changes: 8 additions & 1 deletion base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,9 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
add_remark!(interp, sv, "Refusing to infer into `depwarn`")
return MethodCallResult(Any, false, false, nothing)
end
if is_noinfer(method)
sig = get_nospecialize_sig(method, sig, sparams)
end
topmost = nothing
# Limit argument type tuple growth of functions:
# look through the parents list to see if there's a call to the same method
Expand Down Expand Up @@ -593,7 +596,11 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me
end
end
force |= allconst
mi = specialize_method(match; preexisting=!force)
if is_noinfer(method)
mi = specialize_method_noinfer(match; preexisting=!force)
else
mi = specialize_method(match; preexisting=!force)
end
if mi === nothing
add_remark!(interp, sv, "[constprop] Failed to specialize")
return nothing
Expand Down
10 changes: 7 additions & 3 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@ end

function analyze_method!(match::MethodMatch, atypes::Vector{Any},
state::InliningState, @nospecialize(stmttyp), flag::UInt8)
method = match.method
(; method, sparams) = match
methsig = method.sig

# Check that we habe the correct number of arguments
Expand All @@ -818,7 +818,7 @@ function analyze_method!(match::MethodMatch, atypes::Vector{Any},
end

# Bail out if any static parameters are left as TypeVar
validate_sparams(match.sparams) || return nothing
validate_sparams(sparams) || return nothing

et = state.et

Expand All @@ -827,7 +827,11 @@ function analyze_method!(match::MethodMatch, atypes::Vector{Any},
end

# See if there exists a specialization for this method signature
mi = specialize_method(match; preexisting=true) # Union{Nothing, MethodInstance}
if is_noinfer(method)
mi = specialize_method_noinfer(match; preexisting=true)
else
mi = specialize_method(match; preexisting=true)
end
if !isa(mi, MethodInstance)
return compileable_specialization(et, match)
end
Expand Down
26 changes: 25 additions & 1 deletion base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ function is_inlineable_constant(@nospecialize(x))
return count_const_size(x) <= MAX_INLINE_CONST_SIZE
end

is_nospecialized(method::Method) = method.nospecialize 0

is_noinfer(method::Method) = method.noinfer && is_nospecialized(method)
# is_noinfer(method::Method) = is_nospecialized(method) && is_declared_noinline(method)

###########################
# MethodInstance/CodeInfo #
###########################
Expand Down Expand Up @@ -144,6 +149,20 @@ function get_compileable_sig(method::Method, @nospecialize(atypes), sparams::Sim
isa(atypes, DataType) || return nothing
mt = ccall(:jl_method_table_for, Any, (Any,), atypes)
mt === nothing && return nothing
atypes′ = ccall(:jl_normalize_to_compilable_sig, Any, (Any, Any, Any, Any),
mt, atypes, sparams, method)
is_compileable = isdispatchtuple(atypes) ||
ccall(:jl_isa_compileable_sig, Int32, (Any, Any), atypes′, method) 0
return is_compileable ? atypes′ : nothing
end

function get_nospecialize_sig(method::Method, @nospecialize(atypes), sparams::SimpleVector)
if isa(atypes, UnionAll)
atypes, sparams = normalize_typevars(method, atypes, sparams)
end
isa(atypes, DataType) || return method.sig
mt = ccall(:jl_method_table_for, Any, (Any,), atypes)
mt === nothing && return method.sig
return ccall(:jl_normalize_to_compilable_sig, Any, (Any, Any, Any, Any),
mt, atypes, sparams, method)
end
Expand Down Expand Up @@ -196,7 +215,7 @@ function specialize_method(method::Method, @nospecialize(atypes), sparams::Simpl
if preexisting
# check cached specializations
# for an existing result stored there
return ccall(:jl_specializations_lookup, Any, (Any, Any), method, atypes)::Union{Nothing,MethodInstance}
return ccall(:jl_specializations_lookup, Ref{MethodInstance}, (Any, Any), method, atypes)
end
return ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any), method, atypes, sparams)
end
Expand All @@ -205,6 +224,11 @@ function specialize_method(match::MethodMatch; kwargs...)
return specialize_method(match.method, match.spec_types, match.sparams; kwargs...)
end

function specialize_method_noinfer((; method, spec_types, sparams)::MethodMatch; kwargs...)
atypes = get_nospecialize_sig(method, spec_types, sparams)
return specialize_method(method, atypes, sparams; kwargs...)
end

# This function is used for computing alternate limit heuristics
function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams::SimpleVector)
if isdefined(method, :generator) && method.generator.expand_early && may_invoke_generator(method, sig, sparams)
Expand Down
43 changes: 38 additions & 5 deletions base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ end
@nospecialize
Applied to a function argument name, hints to the compiler that the method
should not be specialized for different types of that argument,
but instead to use precisely the declared type for each argument.
This is only a hint for avoiding excess code generation.
Can be applied to an argument within a formal argument list,
implementation should not be specialized for different types of that argument,
but instead use the declared type for that argument.
It can be applied to an argument within a formal argument list,
or in the function body.
When applied to an argument, the macro must wrap the entire argument expression.
When applied to an argument, the macro must wrap the entire argument expression, e.g.,
`@nospecialize(x::Real)` or `@nospecialize(i::Integer...)` rather than wrapping just the argument name.
When used in a function body, the macro must occur in statement position and
before any code.
Expand Down Expand Up @@ -87,6 +87,39 @@ end
f(y) = [x for x in y]
@specialize
```
!!! note
`@nospecialize` affects code generation but not inference: it limits the diversity
of the resulting native code, but it does not impose any limitations (beyond the
standard ones) on type-inference. Use [`Base.@noinfer`](@ref) together with
`@nospecialize` to additionally suppress inference.
# Example
```julia
julia> f(A::AbstractArray) = g(A)
f (generic function with 1 method)
julia> @noinline g(@nospecialize(A::AbstractArray)) = A[1]
g (generic function with 1 method)
julia> @code_typed f([1.0])
CodeInfo(
1 ─ %1 = invoke Main.g(_2::AbstractArray)::Float64
└── return %1
) => Float64
```
Here, the `@nospecialize` annotation results in the equivalent of
```julia
f(A::AbstractArray) = invoke(g, Tuple{AbstractArray}, A)
```
ensuring that only one version of native code will be generated for `g`,
one that is generic for any `AbstractArray`.
However, the specific return type is still inferred for both `g` and `f`,
and this is still used in optimizing the callers of `f` and `g`.
"""
macro nospecialize(vars...)
if nfields(vars) === 1
Expand Down
51 changes: 43 additions & 8 deletions base/expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,10 +336,12 @@ macro noinline(x)
end

"""
@pure ex
@pure(ex)
Base.@pure function f(args...)
...
end
Base.@pure f(args...) = ...
`@pure` gives the compiler a hint for the definition of a pure function,
`Base.@pure` gives the compiler a hint for the definition of a pure function,
helping for type inference.
This macro is intended for internal compiler use and may be subject to changes.
Expand All @@ -349,16 +351,16 @@ macro pure(ex)
end

"""
@constprop setting ex
@constprop(setting, ex)
Base.@constprop setting ex
Base.@constprop(setting, ex)
`@constprop` controls the mode of interprocedural constant propagation for the
`Base.@constprop` controls the mode of interprocedural constant propagation for the
annotated function. Two `setting`s are supported:
- `@constprop :aggressive ex`: apply constant propagation aggressively.
- `Base.@constprop :aggressive ex`: apply constant propagation aggressively.
For a method where the return type depends on the value of the arguments,
this can yield improved inference results at the cost of additional compile time.
- `@constprop :none ex`: disable constant propagation. This can reduce compile
- `Base.@constprop :none ex`: disable constant propagation. This can reduce compile
times for functions that Julia might otherwise deem worthy of constant-propagation.
Common cases are for functions with `Bool`- or `Symbol`-valued arguments or keyword arguments.
"""
Expand All @@ -371,6 +373,39 @@ macro constprop(setting, ex)
throw(ArgumentError("@constprop $setting not supported"))
end

"""
Base.@noinfer function f(args...)
@nospecialize ...
...
end
Base.@noinfer f(@nospecialize args...) = ...
Tells the compiler to infer `f` using the declared types of `@nospecialize`d arguments.
This can be used to limit the number of compiler-generated specializations during inference.
# Example
```julia
julia> f(A::AbstractArray) = g(A)
f (generic function with 1 method)
julia> @noinline Base.@noinfer g(@nospecialize(A::AbstractArray)) = A[1]
g (generic function with 1 method)
julia> @code_typed f([1.0])
CodeInfo(
1 ─ %1 = invoke Main.g(_2::AbstractArray)::Any
└── return %1
) => Any
```
In this example, `f` will be inferred for each specific type of `A`,
but `g` will only be inferred once.
"""
macro noinfer(ex)
esc(isa(ex, Expr) ? pushmeta!(ex, :noinfer) : ex)
end

"""
@propagate_inbounds
Expand Down
2 changes: 2 additions & 0 deletions doc/src/base/base.md
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,8 @@ Base.@inline
Base.@noinline
Base.@nospecialize
Base.@specialize
Base.@noinfer
Base.@constprop
Base.gensym
Base.@gensym
var"name"
Expand Down
4 changes: 3 additions & 1 deletion src/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ jl_sym_t *static_parameter_sym; jl_sym_t *inline_sym;
jl_sym_t *noinline_sym; jl_sym_t *generated_sym;
jl_sym_t *generated_only_sym; jl_sym_t *isdefined_sym;
jl_sym_t *propagate_inbounds_sym; jl_sym_t *specialize_sym;
jl_sym_t *nospecialize_sym; jl_sym_t *noinfer_sym;
jl_sym_t *aggressive_constprop_sym; jl_sym_t *no_constprop_sym;
jl_sym_t *nospecialize_sym; jl_sym_t *macrocall_sym;
jl_sym_t *macrocall_sym;
jl_sym_t *colon_sym; jl_sym_t *hygienicscope_sym;
jl_sym_t *throw_undef_if_not_sym; jl_sym_t *getfield_undefref_sym;
jl_sym_t *gc_preserve_begin_sym; jl_sym_t *gc_preserve_end_sym;
Expand Down Expand Up @@ -403,6 +404,7 @@ void jl_init_common_symbols(void)
isdefined_sym = jl_symbol("isdefined");
nospecialize_sym = jl_symbol("nospecialize");
specialize_sym = jl_symbol("specialize");
noinfer_sym = jl_symbol("noinfer");
optlevel_sym = jl_symbol("optlevel");
compile_sym = jl_symbol("compile");
infer_sym = jl_symbol("infer");
Expand Down
2 changes: 2 additions & 0 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,7 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li
write_int8(s->s, m->isva);
write_int8(s->s, m->pure);
write_int8(s->s, m->is_for_opaque_closure);
write_int8(s->s, m->noinfer);
write_int8(s->s, m->constprop);
jl_serialize_value(s, (jl_value_t*)m->slot_syms);
jl_serialize_value(s, (jl_value_t*)m->roots);
Expand Down Expand Up @@ -1526,6 +1527,7 @@ static jl_value_t *jl_deserialize_value_method(jl_serializer_state *s, jl_value_
m->isva = read_int8(s->s);
m->pure = read_int8(s->s);
m->is_for_opaque_closure = read_int8(s->s);
m->noinfer = read_int8(s->s);
m->constprop = read_int8(s->s);
m->slot_syms = jl_deserialize_value(s, (jl_value_t**)&m->slot_syms);
jl_gc_wb(m, m->slot_syms);
Expand Down
8 changes: 2 additions & 6 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -2053,10 +2053,8 @@ JL_DLLEXPORT jl_value_t *jl_normalize_to_compilable_sig(jl_methtable_t *mt, jl_t
intptr_t nspec = (mt == jl_type_type_mt || mt == jl_nonfunction_mt ? m->nargs + 1 : mt->max_args + 2);
jl_compilation_sig(ti, env, m, nspec, &newparams);
tt = (newparams ? jl_apply_tuple_type(newparams) : ti);
int is_compileable = ((jl_datatype_t*)ti)->isdispatchtuple ||
jl_isa_compileable_sig(tt, m);
JL_GC_POP();
return is_compileable ? (jl_value_t*)tt : jl_nothing;
return (jl_value_t*)tt;
}

// compile-time method lookup
Expand Down Expand Up @@ -2100,9 +2098,7 @@ jl_method_instance_t *jl_get_specialization1(jl_tupletype_t *types JL_PROPAGATES
}
else {
tt = jl_normalize_to_compilable_sig(mt, ti, env, m);
if (tt != jl_nothing) {
nf = jl_specializations_get_linfo(m, (jl_value_t*)tt, env);
}
nf = jl_specializations_get_linfo(m, (jl_value_t*)tt, env);
}
}
}
Expand Down
14 changes: 9 additions & 5 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2329,7 +2329,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_code_info_type =
jl_new_datatype(jl_symbol("CodeInfo"), core,
jl_any_type, jl_emptysvec,
jl_perm_symsvec(19,
jl_perm_symsvec(20,
"code",
"codelocs",
"ssavaluetypes",
Expand All @@ -2348,8 +2348,9 @@ void jl_init_types(void) JL_GC_DISABLED
"inlineable",
"propagate_inbounds",
"pure",
"noinfer",
"constprop"),
jl_svec(19,
jl_svec(20,
jl_array_any_type,
jl_array_int32_type,
jl_any_type,
Expand All @@ -2368,14 +2369,15 @@ void jl_init_types(void) JL_GC_DISABLED
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_uint8_type),
jl_emptysvec,
0, 1, 19);
0, 1, 20);

jl_method_type =
jl_new_datatype(jl_symbol("Method"), core,
jl_any_type, jl_emptysvec,
jl_perm_symsvec(26,
jl_perm_symsvec(27,
"name",
"module",
"file",
Expand All @@ -2401,8 +2403,9 @@ void jl_init_types(void) JL_GC_DISABLED
"isva",
"pure",
"is_for_opaque_closure",
"noinfer",
"constprop"),
jl_svec(26,
jl_svec(27,
jl_symbol_type,
jl_module_type,
jl_symbol_type,
Expand All @@ -2428,6 +2431,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_uint8_type),
jl_emptysvec,
0, 1, 10);
Expand Down
2 changes: 2 additions & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ typedef struct _jl_code_info_t {
uint8_t inlineable;
uint8_t propagate_inbounds;
uint8_t pure;
uint8_t noinfer;
// uint8 settings
uint8_t constprop; // 0 = use heuristic; 1 = aggressive; 2 = none
} jl_code_info_t;
Expand Down Expand Up @@ -319,6 +320,7 @@ typedef struct _jl_method_t {
uint8_t isva;
uint8_t pure;
uint8_t is_for_opaque_closure;
uint8_t noinfer;
// uint8 settings
uint8_t constprop; // 0x00 = use heuristic; 0x01 = aggressive; 0x02 = none

Expand Down
3 changes: 2 additions & 1 deletion src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1393,8 +1393,9 @@ extern jl_sym_t *static_parameter_sym; extern jl_sym_t *inline_sym;
extern jl_sym_t *noinline_sym; extern jl_sym_t *generated_sym;
extern jl_sym_t *generated_only_sym; extern jl_sym_t *isdefined_sym;
extern jl_sym_t *propagate_inbounds_sym; extern jl_sym_t *specialize_sym;
extern jl_sym_t *nospecialize_sym; extern jl_sym_t *noinfer_sym;
extern jl_sym_t *aggressive_constprop_sym; extern jl_sym_t *no_constprop_sym;
extern jl_sym_t *nospecialize_sym; extern jl_sym_t *macrocall_sym;
extern jl_sym_t *macrocall_sym;
extern jl_sym_t *colon_sym; extern jl_sym_t *hygienicscope_sym;
extern jl_sym_t *throw_undef_if_not_sym; extern jl_sym_t *getfield_undefref_sym;
extern jl_sym_t *gc_preserve_begin_sym; extern jl_sym_t *gc_preserve_end_sym;
Expand Down
Loading

0 comments on commit 02c46d7

Please sign in to comment.