Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend and improve RecursiveApply, fix DSS inference #1334

Merged
merged 10 commits into from
Jun 16, 2023
79 changes: 70 additions & 9 deletions src/RecursiveApply/RecursiveApply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,52 @@ module RecursiveApply

export ⊞, ⊠, ⊟

# These functions need to be generated for type stability (since T.parameters is
# a SimpleVector, the compiler cannot always infer its size and elements).
@generated first_param(::Type{T}) where {T} = :($(first(T.parameters)))
@generated tail_params(::Type{T}) where {T} =
:($(Tuple{Base.tail((T.parameters...,))...}))

# Applying `rmaptype` returns `Tuple{...}` for tuple
# types, which cannot follow the recursion pattern as
# it cannot be splatted, so we add a separate method,
# `rmaptype_Tuple`, for the part of the recursion.
rmaptype_Tuple(fn::F, ::Type{Tuple{}}) where {F} = ()
rmaptype_Tuple(fn::F, ::Type{T}) where {F, E, T <: Tuple{E}} =
(rmaptype(fn, first_param(T)),)
rmaptype_Tuple(fn::F, ::Type{T}) where {F, T <: Tuple} =
(rmaptype(fn, first_param(T)), rmaptype_Tuple(fn, tail_params(T))...)

rmaptype_Tuple(_, ::Type{Tuple{}}, ::Type{Tuple{}}) = ()
rmaptype_Tuple(_, ::Type{Tuple{}}, ::Type{T}) where {T <: Tuple} = ()
rmaptype_Tuple(_, ::Type{T}, ::Type{Tuple{}}) where {T <: Tuple} = ()
rmaptype_Tuple(
fn::F,
::Type{T1},
::Type{T2},
) where {F, T1 <: Tuple, T2 <: Tuple} = (
rmaptype(fn, first_param(T1), first_param(T2)),
rmaptype_Tuple(fn, tail_params(T1), tail_params(T2))...,
)

"""
rmap(fn, X...)

Recursively apply `fn` to each element of `X`
"""
rmap(fn::F, X) where {F} = fn(X)
rmap(fn::F, X::Tuple{}) where {F} = ()
rmap(fn::F, X::Tuple) where {F} =
(rmap(fn, first(X)), rmap(fn, Base.tail(X))...)
rmap(fn::F, X::NamedTuple{names}) where {F, names} =
NamedTuple{names}(rmap(fn, Tuple(X)))

rmap(fn::F, X, Y) where {F} = fn(X, Y)
rmap(fn::F, X::Tuple) where {F} = map(x -> rmap(fn, x), X)
rmap(fn, X::Tuple{}, Y::Tuple{}) = ()
rmap(fn::F, X::Tuple{}, Y::Tuple{}) where {F} = ()
rmap(fn::F, X::Tuple{}, Y::Tuple) where {F} = ()
rmap(fn::F, X::Tuple, Y::Tuple{}) where {F} = ()
rmap(fn::F, X::Tuple, Y::Tuple) where {F} =
(rmap(fn, first(X), first(Y)), rmap(fn, Base.tail(X), Base.tail(Y))...)
rmap(fn::F, X::NamedTuple{names}) where {F, names} =
NamedTuple{names}(rmap(fn, Tuple(X)))
rmap(fn::F, X::NamedTuple{names}, Y::NamedTuple{names}) where {F, names} =
NamedTuple{names}(rmap(fn, Tuple(X), Tuple(Y)))

Expand All @@ -32,17 +65,45 @@ rmax(X, Y) = rmap(max, X, Y)

"""
rmaptype(fn, T)
rmaptype(fn, T1, T2)

The return type of `rmap(fn, X::T)`.
Recursively apply `fn` to each type parameter of the type `T`, or to each type
parameter of the types `T1` and `T2`, where `fn` returns a type.
"""
rmaptype(fn::F, ::Type{T}) where {F, T} = fn(T)
rmaptype(fn::F, ::Type{T}) where {F, T <: Tuple} =
Tuple{map(fn, tuple(T.parameters...))...}
Tuple{rmaptype_Tuple(fn, T)...}
rmaptype(fn::F, ::Type{T}) where {F, names, Tup, T <: NamedTuple{names, Tup}} =
NamedTuple{names, rmaptype(fn, Tup)}

rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1, T2} = fn(T1, T2)
rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1 <: Tuple, T2 <: Tuple} =
Tuple{rmaptype_Tuple(fn, T1, T2)...}
rmaptype(
fn::F,
::Type{T},
) where {F, T <: NamedTuple{names, tup}} where {names, tup} =
NamedTuple{names, rmaptype(fn, tup)}
::Type{T1},
::Type{T2},
) where {
F,
names,
Tup1,
Tup2,
T1 <: NamedTuple{names, Tup1},
T2 <: NamedTuple{names, Tup2},
} = NamedTuple{names, rmaptype(fn, Tup1, Tup2)}

"""
rzero(T)

Recursively compute the zero value of type `T`.
"""
rzero(::Type{T}) where {T} = zero(T)
rzero(::Type{Tuple{}}) = ()
rzero(::Type{T}) where {E, T <: Tuple{E}} = (rzero(E),)
rzero(::Type{T}) where {T <: Tuple} =
(rzero(first_param(T)), rzero(tail_params(T))...)
rzero(::Type{Tup}) where {names, T, Tup <: NamedTuple{names, T}} =
NamedTuple{names}(rzero(T))

"""
rmul(X, Y)
Expand Down
7 changes: 3 additions & 4 deletions src/Spaces/dss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ function dss_local_vertices!(
sum_data = mapreduce(
⊞,
vertex;
init = RecursiveApply.rmap(zero, slab(perimeter_data, 1, 1)[1]),
init = RecursiveApply.rzero(eltype(slab(perimeter_data, 1, 1))),
) do (lidx, vert)
ip = Topologies.perimeter_vertex_node_index(vert)
perimeter_slab = slab(perimeter_data, level, lidx)
Expand Down Expand Up @@ -906,9 +906,8 @@ function dss_local_ghost!(
sum_data = mapreduce(
⊞,
vertex;
init = RecursiveApply.rmap(
zero,
slab(perimeter_data, 1, 1)[1],
init = RecursiveApply.rzero(
eltype(slab(perimeter_data, 1, 1)),
),
) do (isghost, idx, vert)
ip = Topologies.perimeter_vertex_node_index(vert)
Expand Down
4 changes: 2 additions & 2 deletions test/Operators/spectralelement/benchmark_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,14 @@ using JET
p = @allocated kernel_complicated_field_dss!(kernel_args)
@test p == 0
p = @allocated kernel_complicated_field2_dss!(kernel_args)
@test_broken p == 0
@test p == 0
# Inference tests
JET.@test_opt kernel_scalar_dss!(kernel_args)
JET.@test_opt kernel_vector_dss!(kernel_args)
JET.@test_opt kernel_field_dss!(kernel_args)
JET.@test_opt kernel_ntuple_field_dss!(kernel_args)
JET.@test_opt kernel_ntuple_floats_dss!(kernel_args)
JET.@test_opt kernel_complicated_field_dss!(kernel_args)
# JET.@test_opt kernel_complicated_field2_dss!(kernel_args) # fails
JET.@test_opt kernel_complicated_field2_dss!(kernel_args)
end
end
46 changes: 46 additions & 0 deletions test/RecursiveApply/recursive_apply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,49 @@ end
RecursiveApply.rmul(x, FT(2))
end
end

@testset "Highly nested types" begin
FT = Float64
nested_types = [
FT,
Tuple{FT, FT},
NamedTuple{(:ϕ, :ψ), Tuple{FT, FT}},
Tuple{
NamedTuple{(:ϕ, :ψ), Tuple{FT, FT}},
NamedTuple{(:ϕ, :ψ), Tuple{FT, FT}},
},
Tuple{FT, FT},
NamedTuple{
(:ρ, :uₕ, :ρe_tot, :ρq_tot, :sgs⁰, :sgsʲs),
Tuple{
FT,
Tuple{FT, FT},
FT,
FT,
NamedTuple{(:ρatke,), Tuple{FT}},
Tuple{NamedTuple{(:ρa, :ρae_tot, :ρaq_tot), Tuple{FT, FT, FT}}},
},
},
NamedTuple{
(:u₃, :sgsʲs),
Tuple{Tuple{FT}, Tuple{NamedTuple{(:u₃,), Tuple{Tuple{FT}}}}},
},
]
for nt in nested_types
rz = RecursiveApply.rmap(RecursiveApply.rzero, nt)
@test typeof(rz) == nt
@inferred RecursiveApply.rmap(RecursiveApply.rzero, nt)

rz = RecursiveApply.rmap((x, y) -> RecursiveApply.rzero(x), nt, nt)
@test typeof(rz) == nt
@inferred RecursiveApply.rmap((x, y) -> RecursiveApply.rzero(x), nt, nt)

rz = RecursiveApply.rmaptype(identity, nt)
@test rz == nt
@inferred RecursiveApply.rmaptype(zero, nt)

rz = RecursiveApply.rmaptype((x, y) -> identity(x), nt, nt)
@test rz == nt
@inferred RecursiveApply.rmaptype((x, y) -> zero(x), nt, nt)
end
end
8 changes: 4 additions & 4 deletions test/aqua.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ using Aqua
# If the number of ambiguities is less than the limit below,
# then please lower the limit based on the new number of ambiguities.
# We're trying to drive this number down to zero to reduce latency.
@test length(ambs) ≤ 15
# Uncomment for debugging:
# for method_ambiguity in ambs
# @show method_ambiguity
# end
for method_ambiguity in ambs
@show method_ambiguity
end
@test length(ambs) ≤ 16
end

@testset "Aqua tests (additional)" begin
Expand Down