diff --git a/src/Utilities/SingleStackUtils/single_stack_diagnostics.jl b/src/Utilities/SingleStackUtils/single_stack_diagnostics.jl index 33ae71ff2f4..4429acb5fcb 100644 --- a/src/Utilities/SingleStackUtils/single_stack_diagnostics.jl +++ b/src/Utilities/SingleStackUtils/single_stack_diagnostics.jl @@ -1,4 +1,12 @@ using ..Orientations +import ..VariableTemplates: flattened_named_tuple +using ..VariableTemplates + +# Sometimes `NodalStack` returns local states +# that is `nothing`. Here, we return `nothing` +# to preserve the keys (e.g., `hyperdiff`) +# when misssing. +flattened_named_tuple(v::Nothing, ft::FlattenType = FlattenArr()) = nothing """ single_stack_diagnostics( diff --git a/src/Utilities/VariableTemplates/VariableTemplates.jl b/src/Utilities/VariableTemplates/VariableTemplates.jl index f0b296f5cb1..d4a1611a6b0 100644 --- a/src/Utilities/VariableTemplates/VariableTemplates.jl +++ b/src/Utilities/VariableTemplates/VariableTemplates.jl @@ -3,6 +3,7 @@ module VariableTemplates export varsize, Vars, Grad, @vars, varsindex, varsindices using StaticArrays +using LinearAlgebra """ varsindex(S, p::Symbol, [sp::Symbol...]) @@ -429,10 +430,12 @@ function getpropertyorindex end # Redirect to Base getproperty/getindex: Base.@propagate_inbounds getpropertyorindex(t::Tuple, ::Val{i}) where {i} = Base.getindex(t, i) -Base.@propagate_inbounds getpropertyorindex(a::SArray, ::Val{i}) where {i} = - Base.getindex(a, i) -Base.@propagate_inbounds getpropertyorindex(nt::AbstractVars, s::Symbol) = - Base.getproperty(nt, s) +Base.@propagate_inbounds getpropertyorindex( + a::AbstractArray, + ::Val{i}, +) where {i} = Base.getindex(a, i) +Base.@propagate_inbounds getpropertyorindex(v::AbstractVars, s::Symbol) = + Base.getproperty(v, s) Base.@propagate_inbounds getpropertyorindex( v::AbstractVars, ::Val{i}, @@ -443,8 +446,10 @@ Base.@propagate_inbounds getpropertyorindex( v::AbstractVars, t::Tuple{A}, ) where {A} = getpropertyorindex(v, t[1]) -Base.@propagate_inbounds getpropertyorindex(a::SArray, t::Tuple{A}) where {A} = - getpropertyorindex(a, t[1]) +Base.@propagate_inbounds getpropertyorindex( + a::AbstractArray, + t::Tuple{A}, +) where {A} = getpropertyorindex(a, t[1]) # Peel first element from tuple and recurse: Base.@propagate_inbounds getpropertyorindex(v::AbstractVars, t::Tuple) = diff --git a/src/Utilities/VariableTemplates/flattened_tup_chain.jl b/src/Utilities/VariableTemplates/flattened_tup_chain.jl index b6240d2be92..d9aae6ea942 100644 --- a/src/Utilities/VariableTemplates/flattened_tup_chain.jl +++ b/src/Utilities/VariableTemplates/flattened_tup_chain.jl @@ -1,5 +1,7 @@ +using LinearAlgebra + export flattened_tup_chain, flattened_named_tuple -export FlattenArr, RetainArr +export FlattenType, FlattenArr, RetainArr abstract type FlattenType end @@ -19,16 +21,22 @@ and `flattened_named_tuple`. """ struct RetainArr <: FlattenType end +# The Vars instance has many empty entries. +# Keeping all of the keys results in many +# duplicated values. So, it's best we +# "prune" the tree by removing the keys: flattened_tup_chain( ::Type{NamedTuple{(), Tuple{}}}, ::FlattenType = FlattenArr(); prefix = (Symbol(),), ) = () + flattened_tup_chain( ::Type{T}, ::FlattenType; prefix = (Symbol(),), ) where {T <: Real} = (prefix,) + flattened_tup_chain( ::Type{T}, ::RetainArr; @@ -42,9 +50,27 @@ flattened_tup_chain( flattened_tup_chain( ::Type{T}, - ::FlattenType; + ::RetainArr; prefix = (Symbol(),), ) where {T <: SHermitianCompact} = (prefix,) +flattened_tup_chain( + ::Type{T}, + ::FlattenType; + prefix = (Symbol(),), +) where {T <: SHermitianCompact} = + ntuple(i -> (prefix..., i), length(StaticArrays.lowertriangletype(T))) + +flattened_tup_chain( + ::Type{T}, + ::RetainArr; + prefix = (Symbol(),), +) where {N, TA, T <: Diagonal{N, TA}} = (prefix,) +flattened_tup_chain( + ::Type{T}, + ::FlattenArr; + prefix = (Symbol(),), +) where {N, TA, T <: Diagonal{N, TA}} = ntuple(i -> (prefix..., i), length(TA)) + flattened_tup_chain(::Type{T}, ::FlattenType; prefix = (Symbol(),)) where {T} = (prefix,) @@ -79,9 +105,10 @@ flattened_tup_chain( ) where {S} = flattened_tup_chain(S, ft) """ - flattened_named_tuple(v::AbstractVars, ::FlattenType) + flattened_named_tuple -A flattened NamedTuple, given a `Vars` instance. +A flattened NamedTuple, given a +`Vars` or nested `NamedTuple` instance. # Example: @@ -106,23 +133,40 @@ function flattened_named_tuple(v::AbstractVars, ft::FlattenType = FlattenArr()) ftc = flattened_tup_chain(v, ft) keys_ = Symbol.(join.(ftc, :_)) vals = map(x -> getproperty(v, wrap_val.(x)), ftc) + length(keys_) == length(vals) || error("key-value mismatch") return (; zip(keys_, vals)...) end -flattened_named_tuple(v::Nothing, ft::FlattenType = FlattenArr()) = NamedTuple() function flattened_named_tuple(nt::NamedTuple, ft::FlattenType = FlattenArr()) ftc = flattened_tup_chain(typeof(nt), ft) keys_ = Symbol.(join.(ftc, :_)) - vals = flattened_nt_vals(nt) + vals = flattened_nt_vals(ft, nt) + length(keys_) == length(vals) || error("key-value mismatch") return (; zip(keys_, vals)...) end -flattened_nt_vals(a::NamedTuple) = flattened_nt_vals(Tuple(a)) -flattened_nt_vals(a::NamedTuple{(), Tuple{}}) = (nothing,) -flattened_nt_vals(a) = (a,) -flattened_nt_vals(a::NamedTuple, b...) = - tuple(flattened_nt_vals(a)..., flattened_nt_vals(b...)...) -flattened_nt_vals(a::NamedTuple{(), Tuple{}}, b...) = - tuple(nothing, flattened_nt_vals(b...)...) -flattened_nt_vals(a, b...) = tuple(values(a), flattened_nt_vals(b...)...) -flattened_nt_vals(x::Tuple) = flattened_nt_vals(x...) + +flattened_nt_vals(::FlattenArr, a::AbstractArray) = tuple(a...) +flattened_nt_vals(::RetainArr, a::AbstractArray) = tuple(a) + +flattened_nt_vals(::FlattenArr, a::Diagonal) = tuple(a.diag...) +flattened_nt_vals(::RetainArr, a::Diagonal) = tuple(a.diag) + +flattened_nt_vals(::FlattenArr, a::SHermitianCompact) = + tuple(a.lowertriangle...) +flattened_nt_vals(::RetainArr, a::SHermitianCompact) = tuple(a.lowertriangle) + +# when we splat an empty tuple `b` into `flattened_nt_vals(ft, b...)` +flattened_nt_vals(::FlattenType) = () + +# for structs +flattened_nt_vals(::FlattenType, a) = (a,) + +# Divide and concur: +flattened_nt_vals(ft::FlattenType, a, b...) = + tuple(flattened_nt_vals(ft, a)..., flattened_nt_vals(ft, b...)...) + +flattened_nt_vals(ft::FlattenType, a::Tuple) = flattened_nt_vals(ft, a...) + +flattened_nt_vals(ft::FlattenType, a::NamedTuple) = + flattened_nt_vals(ft, Tuple(a)) diff --git a/test/Utilities/VariableTemplates/test_complex_models.jl b/test/Utilities/VariableTemplates/test_complex_models.jl index f3226186fd8..89e1d8de35c 100644 --- a/test/Utilities/VariableTemplates/test_complex_models.jl +++ b/test/Utilities/VariableTemplates/test_complex_models.jl @@ -2,6 +2,8 @@ using Test using StaticArrays using ClimateMachine.VariableTemplates using ClimateMachine.VariableTemplates: wrap_val +import ClimateMachine.VariableTemplates +VT = VariableTemplates @testset "Test complex models" begin include("complex_models.jl") @@ -91,6 +93,12 @@ using ClimateMachine.VariableTemplates: wrap_val end @test fn[j] === "scalar_model.x" + # flattened_tup_chain - empty/generic cases + struct Foo end + @test flattened_tup_chain(NamedTuple{(), Tuple{}}) == () + @test flattened_tup_chain(Foo, RetainArr()) == ((Symbol(),),) + @test flattened_tup_chain(Foo, FlattenArr()) == ((Symbol(),),) + # flattened_tup_chain - Retain arrays ftc = flattened_tup_chain(st, RetainArr()) @@ -230,4 +238,65 @@ using ClimateMachine.VariableTemplates: wrap_val @test fnt.vector_model_x_3 == 23.0 @test fnt.scalar_model_x == 24.0f0 + struct Foo end + nt = (; + nest = (; + v = SVector(1, 2, 3), + nt = (; + shc = SHermitianCompact{3, FT, 6}(collect(1:6)), + f = FT(1.0), + ), + d = SDiagonal(collect(1:3)...), + tt = (Foo(), Foo()), + t = Foo(), + ), + ) + # Test flattened_nt_vals: + + @test VT.flattened_nt_vals(RetainArr(), NamedTuple()) == () + @test VT.flattened_nt_vals(FlattenArr(), NamedTuple()) == () + @test VT.flattened_nt_vals(RetainArr(), Tuple(NamedTuple())) == () + @test VT.flattened_nt_vals(FlattenArr(), Tuple(NamedTuple())) == () + + ft = FlattenArr() + @test VT.flattened_nt_vals(ft, nt.nest.nt.f) == (1.0f0,) + @test VT.flattened_nt_vals(ft, nt.nest.nt) == + (1.0f0, 2.0f0, 3.0f0, 4.0f0, 5.0f0, 6.0f0, 1.0f0) + @test VT.flattened_nt_vals(ft, nt.nest.d) == (1, 2, 3) + @test VT.flattened_nt_vals(ft, nt.nest.t) == (Foo(),) + @test VT.flattened_nt_vals(ft, nt.nest.tt) == (Foo(), Foo()) + + ft = RetainArr() + @test VT.flattened_nt_vals(ft, nt.nest.nt.f) == (1.0f0,) + @test VT.flattened_nt_vals(ft, nt.nest.nt)[1] == + nt.nest.nt.shc.lowertriangle + @test VT.flattened_nt_vals(ft, nt.nest.nt)[2] == 1.0f0 + @test VT.flattened_nt_vals(ft, nt.nest.d) == (nt.nest.d.diag,) + @test VT.flattened_nt_vals(ft, nt.nest.t) == (Foo(),) + @test VT.flattened_nt_vals(ft, nt.nest.tt) == (Foo(), Foo()) + + # Test flattened_named_tuple for NamedTuples + fnt = flattened_named_tuple(nt, FlattenArr()) + @test fnt.nest_v_1 == 1 + @test fnt.nest_v_2 == 2 + @test fnt.nest_v_3 == 3 + @test fnt.nest_nt_shc_1 == 1.0 + @test fnt.nest_nt_shc_2 == 2.0 + @test fnt.nest_nt_shc_3 == 3.0 + @test fnt.nest_nt_shc_4 == 4.0 + @test fnt.nest_nt_shc_5 == 5.0 + @test fnt.nest_nt_shc_6 == 6.0 + @test fnt.nest_nt_f == 1.0 + @test fnt.nest_tt_1 == Foo() + @test fnt.nest_tt_2 == Foo() + @test fnt.nest_t == Foo() + + fnt = flattened_named_tuple(nt, RetainArr()) + @test fnt.nest_v == SVector(1, 2, 3) + @test fnt.nest_nt_shc == nt.nest.nt.shc.lowertriangle + @test fnt.nest_nt_f == 1.0 + @test fnt.nest_tt_1 == Foo() + @test fnt.nest_tt_2 == Foo() + @test fnt.nest_t == Foo() + end