From d406c7ef1bf2b4f33876e2643bcf0c157f79c053 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Fri, 3 Dec 2021 20:25:54 +0800 Subject: [PATCH] make flattened `Broadcasted` more compiler friendly. 1. make `cat_nested` better inferred by switching to direct self-recursion. 2. `make_makeargs` now create a tuple of functions which take in the whole argument list and return the corresponding input for the broadcasted function. --- base/broadcast.jl | 121 ++++++++++++++++------------------------------ test/broadcast.jl | 19 ++++++-- 2 files changed, 59 insertions(+), 81 deletions(-) diff --git a/base/broadcast.jl b/base/broadcast.jl index 1e057789509ed..314c96c0cf8c1 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -341,20 +341,16 @@ function flatten(bc::Broadcasted) isflat(bc) && return bc # concatenate the nested arguments into {a, b, c, d} args = cat_nested(bc) - # build a function `makeargs` that takes a "flat" argument list and - # and creates the appropriate input arguments for `f`, e.g., - # makeargs = (w, x, y, z) -> (w, g(x, y), z) - # - # `makeargs` is built recursively and looks a bit like this: - # makeargs(w, x, y, z) = (w, makeargs1(x, y, z)...) - # = (w, g(x, y), makeargs2(z)...) - # = (w, g(x, y), z) - let makeargs = make_makeargs(()->(), bc.args), f = bc.f - newf = @inline function(args::Vararg{Any,N}) where N - f(makeargs(args...)...) - end - return Broadcasted(bc.style, newf, args, bc.axes) - end + # build a tuple of functions `makeargs`. Its elements take + # the whole "flat" argument list and generate the appropriate + # input arguments for the broadcasted function `f`, e.g., + # makeargs[1] = ((w, x, y, z)) -> w + # makeargs[2] = ((w, x, y, z)) -> g(x, y) + # makeargs[3] = ((w, x, y, z)) -> z + makeargs = make_makeargs(bc.args) + f = Base.maybeconstructor(bc.f) + newf = (args...) 
-> (@inline; f(prepare_args(makeargs, args)...)) + return Broadcasted(bc.style, newf, args, bc.axes) end const NestedTuple = Tuple{<:Broadcasted,Vararg{Any}} @@ -363,78 +359,47 @@ _isflat(args::NestedTuple) = false _isflat(args::Tuple) = _isflat(tail(args)) _isflat(args::Tuple{}) = true -cat_nested(t::Broadcasted, rest...) = (cat_nested(t.args...)..., cat_nested(rest...)...) -cat_nested(t::Any, rest...) = (t, cat_nested(rest...)...) -cat_nested() = () +cat_nested(bc::Broadcasted) = cat_nested_args(bc.args) +cat_nested_args(::Tuple{}) = () +cat_nested_args(t::Tuple{Any}) = cat_nested(t[1]) +cat_nested_args(t::Tuple) = (cat_nested(t[1])..., cat_nested_args(tail(t))...) +cat_nested(a) = (a,) """ - make_makeargs(makeargs_tail::Function, t::Tuple) -> Function + make_makeargs(t::Tuple) -> Tuple{Vararg{Function}} Each element of `t` is one (consecutive) node in a broadcast tree. -Ignoring `makeargs_tail` for the moment, the job of `make_makeargs` is -to return a function that takes in flattened argument list and returns a -tuple (each entry corresponding to an entry in `t`, having evaluated -the corresponding element in the broadcast tree). As an additional -complication, the passed in tuple may be longer than the number of leaves -in the subtree described by `t`. The `makeargs_tail` function should -be called on such additional arguments (but not the arguments consumed -by `t`). +The returned `Tuple` are functions which take in the (whole) flattened +list and generate the inputs for the corresponding broadcasted function. """ -@inline make_makeargs(makeargs_tail, t::Tuple{}) = makeargs_tail -@inline function make_makeargs(makeargs_tail, t::Tuple) - makeargs = make_makeargs(makeargs_tail, tail(t)) - (head, tail...)->(head, makeargs(tail...)...) +make_makeargs(args::Tuple) = _make_makeargs(args, 1)[1] + +# We build `makeargs` by traversing the broadcast nodes recursively. +# note: `n` indicates the flattened index of the next unused argument. 
+@inline function _make_makeargs(args::Tuple, n::Int) + head, n = _make_makeargs1(args[1], n) + rest, n = _make_makeargs(tail(args), n) + (head, rest...), n end -function make_makeargs(makeargs_tail, t::Tuple{<:Broadcasted, Vararg{Any}}) - bc = t[1] - # c.f. the same expression in the function on leaf nodes above. Here - # we recurse into siblings in the broadcast tree. - let makeargs_tail = make_makeargs(makeargs_tail, tail(t)), - # Here we recurse into children. It would be valid to pass in makeargs_tail - # here, and not use it below. However, in that case, our recursion is no - # longer purely structural because we're building up one argument (the closure) - # while destructuing another. - makeargs_head = make_makeargs((args...)->args, bc.args), - f = bc.f - # Create two functions, one that splits of the first length(bc.args) - # elements from the tuple and one that yields the remaining arguments. - # N.B. We can't call headargs on `args...` directly because - # args is flattened (i.e. our children have not been evaluated - # yet). - headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args) - return @inline function(args::Vararg{Any,N}) where N - args1 = makeargs_head(args...) - a, b = headargs(args1...), makeargs_tail(tailargs(args1...)...) - (f(a...), b...) 
- end - end +_make_makeargs(::Tuple{}, n::Int) = (), n + +# A helper struct to store the flattened index statically +struct Pick{N} <: Function end +(::Pick{N})(@nospecialize(args::Tuple)) where {N} = args[N] + +# For flat nodes, we just consume one argument (n += 1), and return the "Pick" function +@inline _make_makeargs1(_, n::Int) = Pick{n}(), n + 1 +# For nested nodes, we form the `makeargs1` based on the child `makeargs` (n += length(cat_nested(bc))) +@inline function _make_makeargs1(bc::Broadcasted, n::Int) + makeargs, n = _make_makeargs(bc.args, n) + f = Base.maybeconstructor(bc.f) + makeargs1 = (args::Tuple) -> (@inline; f(prepare_args(makeargs, args)...)) + makeargs1, n end -@inline function make_headargs(t::Tuple) - let headargs = make_headargs(tail(t)) - return @inline function(head, tail::Vararg{Any,N}) where N - (head, headargs(tail...)...) - end - end -end -@inline function make_headargs(::Tuple{}) - return @inline function(tail::Vararg{Any,N}) where N - () - end -end - -@inline function make_tailargs(t::Tuple) - let tailargs = make_tailargs(tail(t)) - return @inline function(head, tail::Vararg{Any,N}) where N - tailargs(tail...) - end - end -end -@inline function make_tailargs(::Tuple{}) - return @inline function(tail::Vararg{Any,N}) where N - tail - end -end +@inline prepare_args(makeargs::Tuple, @nospecialize(x::Tuple)) = (makeargs[1](x), prepare_args(tail(makeargs), x)...) 
+@inline prepare_args(makeargs::Tuple{Any}, @nospecialize(x::Tuple)) = (makeargs[1](x),) +prepare_args(::Tuple{}, ::Tuple) = () ## Broadcasting utilities ## diff --git a/test/broadcast.jl b/test/broadcast.jl index 87858dd0f08fc..6cf05fbea139c 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -774,14 +774,27 @@ let X = zeros(2, 3) end # issue #27988: inference of Broadcast.flatten -using .Broadcast: Broadcasted +using .Broadcast: Broadcasted, cat_nested let bc = Broadcasted(+, (Broadcasted(*, (1, 2)), Broadcasted(*, (Broadcasted(*, (3, 4)), 5)))) - @test @inferred(Broadcast.cat_nested(bc)) == (1,2,3,4,5) + @test @inferred(cat_nested(bc)) == (1,2,3,4,5) @test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 62 bc = Broadcasted(+, (Broadcasted(*, (1, Broadcasted(/, (2.0, 2.5)))), Broadcasted(*, (Broadcasted(*, (3, 4)), 5)))) - @test @inferred(Broadcast.cat_nested(bc)) == (1,2.0,2.5,3,4,5) + @test @inferred(cat_nested(bc)) == (1,2.0,2.5,3,4,5) @test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 60.8 + # 1 .* 1 .- 1 .* 1 .^2 .+ 1 .* 1 .+ 1 .^ 3 + bc = Broadcasted(+, (Broadcasted(+, (Broadcasted(-, (Broadcasted(*, (1, 1)), Broadcasted(*, (1, Broadcasted(Base.literal_pow, (Ref(^), 1, Ref(Val(2)))))))), Broadcasted(*, (1, 1)))), Broadcasted(Base.literal_pow, (Base.RefValue{typeof(^)}(^), 1, Base.RefValue{Val{3}}(Val{3}()))))) + @test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 2 + # @. 1 + 1 * (1 + 1 + 1 + 1) + bc = Broadcasted(+, (1, Broadcasted(*, (1, Broadcasted(+, (1, 1, 1, 1)))))) + @test @inferred(cat_nested(bc)) == (1, 1, 1, 1, 1, 1) # `cat_nested` failed to infer this + @test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == Broadcast.materialize(bc) + # @. 
1 + (1 + 1) + 1 + (1 + 1) + 1 + (1 + 1) + 1 + bc = Broadcasted(+, (1, Broadcasted(+, (1, 1)), 1, Broadcasted(+, (1, 1)), 1, Broadcasted(+, (1, 1)), 1)) + @test @inferred(cat_nested(bc)) == (1, 1, 1, 1, 1, 1, 1, 1, 1, 1) + @test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == Broadcast.materialize(bc) + bc = Broadcasted(Float32, (Broadcasted(+, (1, 1)),)) + @test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == Broadcast.materialize(bc) end let