diff --git a/docs/src/APIs/Utilities/VariableTemplates.md b/docs/src/APIs/Utilities/VariableTemplates.md index 7b0ba6058c1..a28252d784f 100644 --- a/docs/src/APIs/Utilities/VariableTemplates.md +++ b/docs/src/APIs/Utilities/VariableTemplates.md @@ -19,6 +19,8 @@ vuntuple unroll_map flattened_tup_chain flattened_named_tuple +FlattenArr +RetainArr varsize varsindices varsindex diff --git a/src/Utilities/SingleStackUtils/single_stack_diagnostics.jl b/src/Utilities/SingleStackUtils/single_stack_diagnostics.jl index d9fd79888b5..33ae71ff2f4 100644 --- a/src/Utilities/SingleStackUtils/single_stack_diagnostics.jl +++ b/src/Utilities/SingleStackUtils/single_stack_diagnostics.jl @@ -54,31 +54,16 @@ function single_stack_diagnostics( z = altitude(bl, aux) - # TODO: Use flatten_named_tuple = true, and flatten arrays - flatten_named_tuple = false - if flatten_named_tuple - nt = (; - z = altitude(bl, aux), - prog = flattened_named_tuple(prog), # Vars -> flattened NamedTuples - aux = flattened_named_tuple(aux), # Vars -> flattened NamedTuples - ∇flux = flattened_named_tuple(∇flux), # Vars -> flattened NamedTuples - hyperdiff = flattened_named_tuple(hyperdiff), # Vars -> flattened NamedTuples - cache = cache, - ) - # Flatten top level: - flattened_named_tuple(nt) - else - nt = (; - z = altitude(bl, aux), - prog = prog, - aux = aux, - ∇flux = ∇flux, - hyperdiff = hyperdiff, - cache = cache, - ) - nt - end - nt + nt = (; + z = altitude(bl, aux), + prog = flattened_named_tuple(prog), # Vars -> flattened NamedTuples + aux = flattened_named_tuple(aux), # Vars -> flattened NamedTuples + ∇flux = flattened_named_tuple(∇flux), # Vars -> flattened NamedTuples + hyperdiff = flattened_named_tuple(hyperdiff), # Vars -> flattened NamedTuples + cache = cache, + ) + # Flatten top level: + flattened_named_tuple(nt) end for local_states in NodalStack(bl, grid; kwargs...) ] end diff --git a/src/Utilities/VariableTemplates/VariableTemplates.jl b/src/Utilities/VariableTemplates/VariableTemplates.jl index 4cfa07f4e57..f0b296f5cb1 100644 --- a/src/Utilities/VariableTemplates/VariableTemplates.jl +++ b/src/Utilities/VariableTemplates/VariableTemplates.jl @@ -61,22 +61,30 @@ wrap_val(i::Int) = Val(i) # This means that users _must_ wrap `sym` # in `Val`, which can be done with `wrap_val` # above. -unval(::Val{i}) where {i} = i Base.@propagate_inbounds function varsindex( ::Type{S}, - sym, + sym::Symbol, rest..., ) where {S <: Union{NamedTuple, Tuple}} - if sym isa Symbol - vi = varsindex(fieldtype(S, sym), rest...) - return varsindex(S, sym)[vi] - else - i = unval(sym) - et = eltype(S) - offset = (i - 1) * varsize(et) - vi = varsindex(et, rest...) - return (vi.start + offset):(vi.stop + offset) - end + vi = varsindex(fieldtype(S, sym), rest...) + return varsindex(S, sym)[vi] +end +Base.@propagate_inbounds function varsindex( + ::Type{S}, + ::Val{i}, + rest..., +) where {i, S <: Union{NamedTuple, Tuple}} + et = eltype(S) + offset = (i - 1) * varsize(et) + vi = varsindex(et, rest...) + return (vi.start + offset):(vi.stop + offset) +end + +Base.@propagate_inbounds function varsindex( + ::Type{S}, + ::Val{i}, +) where {i, S <: SArray} + return i:i end """ @@ -391,70 +399,62 @@ vuntuple(f::F, N::Int) where {F} = ntuple(i -> f(Val(i)), Val(N)) # Inside unroll_map expressions, all indexes `i` # are wrapped in `Val`, so we must redirect # these methods: -Base.getindex(t::Tuple, ::Val{i}) where {i} = Base.getindex(t, i) -Base.getindex(a::SArray, ::Val{i}) where {i} = Base.getindex(a, i) - -# Somehow needed for GPU... -Base.@propagate_inbounds Base.getindex(v::AbstractVars, i::Int) = - Base.getindex(v, Val(i)) +Base.@propagate_inbounds Base.getindex(t::Tuple, ::Val{i}) where {i} = + Base.getindex(t, i) +Base.@propagate_inbounds Base.getindex(a::SArray, ::Val{i}) where {i} = + Base.getindex(a, i) Base.@propagate_inbounds function Base.getindex( - v::AbstractVars{NTuple{N, T}, A, offset}, + v::Vars{NTuple{N, T}, A, offset}, ::Val{i}, -) where {N, T, A, offset, i} - # 1 <= i <= N - array = parent(v) - if v isa Vars - return Vars{T, A, offset + (i - 1) * varsize(T)}(array) - else - return Grad{T, A, offset + (i - 1) * varsize(T)}(array) - end -end - -Base.@propagate_inbounds function Base.getproperty( - v::AbstractVars, - tup_chain::Tuple{S}, -) where {S <: Symbol} - return Base.getproperty(v, tup_chain[1]) +) where {N, T, A, offset, i} # 1 <= i <= N + return Vars{T, A, offset + (i - 1) * varsize(T)}(parent(v)) end Base.@propagate_inbounds function Base.getindex( - v::AbstractVars, - tup_chain::Tuple{S}, -) where {S <: Int} - return Base.getindex(v, Val(tup_chain[1])) + v::Grad{NTuple{N, T}, A, offset}, + ::Val{i}, +) where {N, T, A, offset, i} # 1 <= i <= N + return Grad{T, A, offset + (i - 1) * varsize(T)}(parent(v)) end -Base.@propagate_inbounds function Base.getproperty( +""" + getpropertyorindex + +An interchangeably and nested-friendly +`getproperty`/`getindex`. +""" +function getpropertyorindex end + +# Redirect to Base getproperty/getindex: +Base.@propagate_inbounds getpropertyorindex(t::Tuple, ::Val{i}) where {i} = + Base.getindex(t, i) +Base.@propagate_inbounds getpropertyorindex(a::SArray, ::Val{i}) where {i} = + Base.getindex(a, i) +Base.@propagate_inbounds getpropertyorindex(nt::AbstractVars, s::Symbol) = + Base.getproperty(nt, s) +Base.@propagate_inbounds getpropertyorindex( v::AbstractVars, - tup_chain::Tuple, -) - if tup_chain[1] isa Int - p = Base.getindex(v, Val(tup_chain[1])) - else - p = Base.getproperty(v, tup_chain[1]) - end - if tup_chain[2] isa Int - return Base.getindex(p, tup_chain[2:end]) - else - return Base.getproperty(p, tup_chain[2:end]) - end -end -Base.@propagate_inbounds function Base.getindex( + ::Val{i}, +) where {i} = Base.getindex(v, Val(i)) + +# Only one element left: +Base.@propagate_inbounds getpropertyorindex( v::AbstractVars, - tup_chain::Tuple, -) - if tup_chain[1] isa Int - p = Base.getindex(v, Val(tup_chain[1])) - else - p = Base.getproperty(v, tup_chain[1]) - end - if tup_chain[2] isa Int - return Base.getindex(p, tup_chain[2:end]) - else - return Base.getproperty(p, tup_chain[2:end]) - end -end + t::Tuple{A}, +) where {A} = getpropertyorindex(v, t[1]) +Base.@propagate_inbounds getpropertyorindex(a::SArray, t::Tuple{A}) where {A} = + getpropertyorindex(a, t[1]) + +# Peel first element from tuple and recurse: +Base.@propagate_inbounds getpropertyorindex(v::AbstractVars, t::Tuple) = + getpropertyorindex(getpropertyorindex(v, t[1]), Tuple(t[2:end])) + +# Redirect to getpropertyorindex: +Base.@propagate_inbounds Base.getproperty(v::AbstractVars, tup_chain::Tuple) = + getpropertyorindex(v, tup_chain) +Base.@propagate_inbounds Base.getindex(v::AbstractVars, tup_chain::Tuple) = + getpropertyorindex(v, tup_chain) include("flattened_tup_chain.jl") diff --git a/src/Utilities/VariableTemplates/flattened_tup_chain.jl b/src/Utilities/VariableTemplates/flattened_tup_chain.jl index 78eec29132f..37b2343b98d 100644 --- a/src/Utilities/VariableTemplates/flattened_tup_chain.jl +++ b/src/Utilities/VariableTemplates/flattened_tup_chain.jl @@ -1,15 +1,38 @@ export flattened_tup_chain, flattened_named_tuple +export FlattenArr, RetainArr # should this be exported? -flattened_tup_chain(::Type{NamedTuple{(), Tuple{}}}; prefix = (Symbol(),)) = () -flattened_tup_chain(::Type{T}; prefix = (Symbol(),)) where {T <: Real} = - (prefix,) -flattened_tup_chain(::Type{T}; prefix = (Symbol(),)) where {T <: SArray} = - (prefix,) +abstract type FlattenType end +struct FlattenArr <: FlattenType end +struct RetainArr <: FlattenType end + +flattened_tup_chain( + ::Type{NamedTuple{(), Tuple{}}}, + ::FlattenType = FlattenArr(); + prefix = (Symbol(),), +) = () +flattened_tup_chain( + ::Type{T}, + ::FlattenType; + prefix = (Symbol(),), +) where {T <: Real} = (prefix,) flattened_tup_chain( - ::Type{T}; + ::Type{T}, + ::RetainArr; + prefix = (Symbol(),), +) where {T <: SArray} = (prefix,) +flattened_tup_chain( + ::Type{T}, + ::FlattenArr; + prefix = (Symbol(),), +) where {T <: SArray} = ntuple(i -> (prefix..., i), length(T)) + +flattened_tup_chain( + ::Type{T}, + ::FlattenType; prefix = (Symbol(),), ) where {T <: SHermitianCompact} = (prefix,) -flattened_tup_chain(::Type{T}; prefix = (Symbol(),)) where {T} = (prefix,) +flattened_tup_chain(::Type{T}, ::FlattenType; prefix = (Symbol(),)) where {T} = + (prefix,) """ flattened_tup_chain(::Type{T}) where {T <: Union{NamedTuple,NTuple}} @@ -19,7 +42,8 @@ and integers for every combination of each field in the `Vars` array. """ function flattened_tup_chain( - ::Type{T}; + ::Type{T}, + ft::FlattenType = FlattenArr(); prefix = (Symbol(),), ) where {T <: Union{NamedTuple, NTuple}} map(1:fieldcount(T)) do i @@ -28,16 +52,20 @@ function flattened_tup_chain( sname = name isa Int ? name : Symbol(name) flattened_tup_chain( Ti, + ft; prefix = prefix == (Symbol(),) ? (sname,) : (prefix..., sname), ) end |> Iterators.flatten |> collect end -flattened_tup_chain(::AbstractVars{S}) where {S} = flattened_tup_chain(S) +flattened_tup_chain( + ::AbstractVars{S}, + ft::FlattenType = FlattenArr(), +) where {S} = flattened_tup_chain(S, ft) """ - flattened_named_tuple(v::AbstractVars) + flattened_named_tuple(v::AbstractVars, ::FlattenType) A flattened NamedTuple, given a `Vars` instance. @@ -60,16 +88,16 @@ fnt = flattened_named_tuple(nt); """ function flattened_named_tuple end -function flattened_named_tuple(v::AbstractVars) - ftc = flattened_tup_chain(v) +function flattened_named_tuple(v::AbstractVars, ft::FlattenType = FlattenArr()) + ftc = flattened_tup_chain(v, ft) keys_ = Symbol.(join.(ftc, :_)) - vals = map(x -> getproperty(v, x), ftc) + vals = map(x -> getproperty(v, wrap_val.(x)), ftc) return (; zip(keys_, vals)...) end -flattened_named_tuple(v::Nothing) = NamedTuple() +flattened_named_tuple(v::Nothing, ft::FlattenType = FlattenArr()) = NamedTuple() -function flattened_named_tuple(nt::NamedTuple) - ftc = flattened_tup_chain(typeof(nt)) +function flattened_named_tuple(nt::NamedTuple, ft::FlattenType = FlattenArr()) + ftc = flattened_tup_chain(typeof(nt), ft) keys_ = Symbol.(join.(ftc, :_)) vals = flattened_nt_vals(nt) return (; zip(keys_, vals)...) diff --git a/test/Atmos/EDMF/closures/entr_detr.jl b/test/Atmos/EDMF/closures/entr_detr.jl index 7e071ea23c5..7651323ba55 100644 --- a/test/Atmos/EDMF/closures/entr_detr.jl +++ b/test/Atmos/EDMF/closures/entr_detr.jl @@ -9,7 +9,7 @@ function entr_detr( env, buoy, ) where {FT} - EΔ_up = ntuple(n_updrafts(bl.turbconv)) do i + EΔ_up = vuntuple(n_updrafts(bl.turbconv)) do i entr_detr(bl, bl.turbconv.entr_detr, state, aux, ts_up, ts_en, env, buoy, i) end E_dyn, Δ_dyn, E_trb = ntuple(i -> map(x -> x[i], EΔ_up), 3) diff --git a/test/Atmos/EDMF/closures/pressure.jl b/test/Atmos/EDMF/closures/pressure.jl index 29047b4b4df..4216b9ca1f4 100644 --- a/test/Atmos/EDMF/closures/pressure.jl +++ b/test/Atmos/EDMF/closures/pressure.jl @@ -1,7 +1,7 @@ #### Pressure model kernels function perturbation_pressure(bl::AtmosModel{FT}, args, env, buoy) where {FT} - dpdz = ntuple(n_updrafts(bl.turbconv)) do i + dpdz = vuntuple(n_updrafts(bl.turbconv)) do i perturbation_pressure(bl, bl.turbconv.pressure, args, env, buoy, i) end return dpdz diff --git a/test/Atmos/EDMF/compute_mse.jl b/test/Atmos/EDMF/compute_mse.jl index 3997fb067ce..e63f5d01c33 100644 --- a/test/Atmos/EDMF/compute_mse.jl +++ b/test/Atmos/EDMF/compute_mse.jl @@ -29,8 +29,6 @@ PyCLES_output_dataset = ArtifactWrapper( ], ) PyCLES_output_dataset_path = get_data_folder(PyCLES_output_dataset) -data_files = Dict() -data_files[:Bomex] = Dataset(joinpath(PyCLES_output_dataset_path, "Bomex.nc"), "r") # data_files[:Rico] = Dataset(joinpath(PyCLES_output_dataset_path, "Rico.nc"), "r") # data_files[:Gabls] = Dataset(joinpath(PyCLES_output_dataset_path, "Gabls.nc"), "r") # data_files[:DYCOMS_RF01] = Dataset(joinpath(PyCLES_output_dataset_path, "DYCOMS_RF01.nc"), "r") @@ -222,3 +220,21 @@ function compute_mse( return mse end + +sufficient_mse(computed_mse, best_mse) = computed_mse <= best_mse + sqrt(eps()) + +function test_mse(computed_mse, best_mse, key) + mse_not_regressed = sufficient_mse(computed_mse[key], best_mse[key]) + @test mse_not_regressed + mse_not_regressed || @show key +end + +function dons(diag_vs_z) + return Dict(map(keys(first(diag_vs_z))) do k + string(k) => [getproperty(ca, k) for ca in diag_vs_z] + end) +end + +get_dons_arr(diag_arr) = [dons(diag_vs_z) for diag_vs_z in diag_arr] + +dons_arr = get_dons_arr(diag_arr) diff --git a/test/Atmos/EDMF/report_mse_bomex.jl b/test/Atmos/EDMF/report_mse_bomex.jl index 7524686efe8..82233bfc39a 100644 --- a/test/Atmos/EDMF/report_mse_bomex.jl +++ b/test/Atmos/EDMF/report_mse_bomex.jl @@ -9,72 +9,45 @@ end include(joinpath(@__DIR__, "compute_mse.jl")) -#! format: off -best_mse = Dict() -best_mse[:Bomex] = Dict() -best_mse[:Bomex]["ρ"] = 3.4917543567416755e-02 -best_mse[:Bomex]["ρu[1]"] = 3.0715061616086027e+03 -best_mse[:Bomex]["ρu[2]"] = 1.2895273328644972e-03 -best_mse[:Bomex]["moisture.ρq_tot"] = 4.1330591681441348e-02 -best_mse[:Bomex]["turbconv.environment.ρatke"] = 6.6415719930880925e+02 -best_mse[:Bomex]["turbconv.environment.ρaθ_liq_cv"] = 8.5667223192888514e+01 -best_mse[:Bomex]["turbconv.environment.ρaq_tot_cv"] = 1.6435555167634794e+02 -best_mse[:Bomex]["turbconv.updraft[1].ρa"] = 7.9564915645201182e+01 -best_mse[:Bomex]["turbconv.updraft[1].ρaw"] = 8.4288782126742318e-02 -best_mse[:Bomex]["turbconv.updraft[1].ρaθ_liq"] = 9.0095910670762631e+00 -best_mse[:Bomex]["turbconv.updraft[1].ρaq_tot"] = 1.0768554319447651e+01 -#! format: on - -sufficient_mse(computed_mse, best_mse) = computed_mse <= best_mse + sqrt(eps()) - -function test_mse(computed_mse, best_mse, key) - mse_not_regressed = sufficient_mse(computed_mse[key], best_mse[key]) - @test mse_not_regressed - mse_not_regressed || @show key -end +data_file = Dataset(joinpath(PyCLES_output_dataset_path, "Bomex.nc"), "r") #! format: off -dons(diag_vs_z) = Dict( - "z" => [ca.z for ca in diag_vs_z], - "ρ" => [ca.prog.ρ for ca in diag_vs_z], - "ρu[1]" => [ca.prog.ρu[1] for ca in diag_vs_z], - "moisture.ρq_tot" => [ca.prog.moisture.ρq_tot for ca in diag_vs_z], - "turbconv.updraft[1].ρa" => [ca.prog.turbconv.updraft[1].ρa for ca in diag_vs_z], - "turbconv.updraft[1].ρaw" => [ca.prog.turbconv.updraft[1].ρaw for ca in diag_vs_z], - "turbconv.updraft[1].ρaθ_liq" => [ca.prog.turbconv.updraft[1].ρaθ_liq for ca in diag_vs_z], - "turbconv.updraft[1].ρaq_tot" => [ca.prog.turbconv.updraft[1].ρaq_tot for ca in diag_vs_z], - "turbconv.environment.ρatke" => [ca.prog.turbconv.environment.ρatke for ca in diag_vs_z], - "turbconv.environment.ρaθ_liq_cv" => [ca.prog.turbconv.environment.ρaθ_liq_cv for ca in diag_vs_z], - "turbconv.environment.ρaq_tot_cv" => [ca.prog.turbconv.environment.ρaq_tot_cv for ca in diag_vs_z], -) +best_mse = Dict() +best_mse["prog_ρ"] = 3.4917543567416755e-02 +best_mse["prog_ρu_1"] = 3.0715061616086027e+03 +best_mse["prog_ρu_2"] = 1.2895273328644972e-03 +best_mse["prog_moisture_ρq_tot"] = 4.1330591681441348e-02 +best_mse["prog_turbconv_environment_ρatke"] = 6.6415719930880925e+02 +best_mse["prog_turbconv_environment_ρaθ_liq_cv"] = 8.5667223192888514e+01 +best_mse["prog_turbconv_environment_ρaq_tot_cv"] = 1.6435555167634794e+02 +best_mse["prog_turbconv_updraft_1_ρa"] = 7.9564915645201182e+01 +best_mse["prog_turbconv_updraft_1_ρaw"] = 8.4288782126742318e-02 +best_mse["prog_turbconv_updraft_1_ρaθ_liq"] = 9.0095910670762631e+00 +best_mse["prog_turbconv_updraft_1_ρaq_tot"] = 1.0768554319447651e+01 #! format: on -dons_arr = [dons(diag_vs_z) for diag_vs_z in diag_arr] - -computed_mse = Dict( - k => compute_mse( - solver_config.dg.grid, - solver_config.dg.balance_law, - time_data, - dons_arr, - data_files[k], - k, - best_mse[k], - plot_dir, - ) for k in keys(data_files) +computed_mse = compute_mse( + solver_config.dg.grid, + solver_config.dg.balance_law, + time_data, + dons_arr, + data_file, + "Bomex", + best_mse, + plot_dir, ) @testset "BOMEX EDMF Solution Quality Assurance (QA) tests" begin #! format: off - test_mse(computed_mse[:Bomex], best_mse[:Bomex], "ρ") - test_mse(computed_mse[:Bomex], best_mse[:Bomex], "ρu[1]") - test_mse(computed_mse[:Bomex], best_mse[:Bomex], "moisture.ρq_tot") - test_mse(computed_mse[:Bomex], best_mse[:Bomex], "turbconv.updraft[1].ρa") - test_mse(computed_mse[:Bomex], best_mse[:Bomex], "turbconv.updraft[1].ρaw") - test_mse(computed_mse[:Bomex], best_mse[:Bomex], "turbconv.updraft[1].ρaθ_liq") - test_mse(computed_mse[:Bomex], best_mse[:Bomex], "turbconv.updraft[1].ρaq_tot") - test_mse(computed_mse[:Bomex], best_mse[:Bomex], "turbconv.environment.ρatke") - test_mse(computed_mse[:Bomex], best_mse[:Bomex], "turbconv.environment.ρaθ_liq_cv") - test_mse(computed_mse[:Bomex], best_mse[:Bomex], "turbconv.environment.ρaq_tot_cv") + test_mse(computed_mse, best_mse, "prog_ρ") + test_mse(computed_mse, best_mse, "prog_ρu_1") + test_mse(computed_mse, best_mse, "prog_moisture_ρq_tot") + test_mse(computed_mse, best_mse, "prog_turbconv_updraft_1_ρa") + test_mse(computed_mse, best_mse, "prog_turbconv_updraft_1_ρaw") + test_mse(computed_mse, best_mse, "prog_turbconv_updraft_1_ρaθ_liq") + test_mse(computed_mse, best_mse, "prog_turbconv_updraft_1_ρaq_tot") + test_mse(computed_mse, best_mse, "prog_turbconv_environment_ρatke") + test_mse(computed_mse, best_mse, "prog_turbconv_environment_ρaθ_liq_cv") + test_mse(computed_mse, best_mse, "prog_turbconv_environment_ρaq_tot_cv") #! format: on end diff --git a/test/Atmos/EDMF/variable_map.jl b/test/Atmos/EDMF/variable_map.jl index 1df6ac5e44e..4a5fcc863e7 100644 --- a/test/Atmos/EDMF/variable_map.jl +++ b/test/Atmos/EDMF/variable_map.jl @@ -2,15 +2,15 @@ var_map(s::String) = var_map(Val(Symbol(s))) var_map(::Val{T}) where {T} = nothing -var_map(::Val{Symbol("ρ")}) = ("rho", ()) -var_map(::Val{Symbol("ρu[1]")}) = ("u_mean", (:ρ,)) -var_map(::Val{Symbol("ρu[2]")}) = ("v_mean", (:ρ,)) -var_map(::Val{Symbol("moisture.ρq_tot")}) = ("qt_mean", (:ρ,)) -var_map(::Val{Symbol("turbconv.updraft[1].ρa")}) = ("updraft_fraction", (:ρ,)) -var_map(::Val{Symbol("turbconv.updraft[1].ρaw")}) = ("updraft_w", (:ρ, :a)) -var_map(::Val{Symbol("turbconv.updraft[1].ρaq_tot")}) = ("updraft_qt", (:ρ, :a)) -var_map(::Val{Symbol("turbconv.updraft[1].ρaθ_liq")}) = ("updraft_thetali", (:ρ, :a)) -var_map(::Val{Symbol("turbconv.environment.ρatke")}) = ("tke_mean", (:ρ, :a)) -var_map(::Val{Symbol("turbconv.environment.ρaθ_liq_cv")}) = ("env_thetali2", (:ρ, :a)) -var_map(::Val{Symbol("turbconv.environment.ρaq_tot_cv")}) = ("env_qt2", (:ρ, :a)) +var_map(::Val{:prog_ρ}) = ("rho", ()) +var_map(::Val{:prog_ρu_1}) = ("u_mean", (:ρ,)) +var_map(::Val{:prog_ρu_2}) = ("v_mean", (:ρ,)) +var_map(::Val{:prog_moisture_ρq_tot}) = ("qt_mean", (:ρ,)) +var_map(::Val{:prog_turbconv_updraft_1_ρa}) = ("updraft_fraction", (:ρ,)) +var_map(::Val{:prog_turbconv_updraft_1_ρaw}) = ("updraft_w", (:ρ, :a)) +var_map(::Val{:prog_turbconv_updraft_1_ρaq_tot}) = ("updraft_qt", (:ρ, :a)) +var_map(::Val{:prog_turbconv_updraft_1_ρaθ_liq}) = ("updraft_thetali", (:ρ, :a)) +var_map(::Val{:prog_turbconv_environment_ρatke}) = ("tke_mean", (:ρ, :a)) +var_map(::Val{:prog_turbconv_environment_ρaθ_liq_cv}) = ("env_thetali2", (:ρ, :a)) +var_map(::Val{:prog_turbconv_environment_ρaq_tot_cv}) = ("env_qt2", (:ρ, :a)) #! format: on diff --git a/test/Utilities/VariableTemplates/test_complex_models.jl b/test/Utilities/VariableTemplates/test_complex_models.jl index 84019df6679..f3226186fd8 100644 --- a/test/Utilities/VariableTemplates/test_complex_models.jl +++ b/test/Utilities/VariableTemplates/test_complex_models.jl @@ -4,7 +4,6 @@ using ClimateMachine.VariableTemplates using ClimateMachine.VariableTemplates: wrap_val @testset "Test complex models" begin - include("complex_models.jl") FT = Float32 @@ -92,8 +91,9 @@ using ClimateMachine.VariableTemplates: wrap_val end @test fn[j] === "scalar_model.x" - # test flattened_tup_chain - ftc = flattened_tup_chain(st) + # flattened_tup_chain - Retain arrays + + ftc = flattened_tup_chain(st, RetainArr()) j = 1 for i in 1:N @test ftc[j] === (:ntuple_model, i, :scalar_model, :x) @@ -124,7 +124,7 @@ using ClimateMachine.VariableTemplates: wrap_val # test that getproperty matches varsindex ntuple(N) do i i_ϕ = varsindex(st, wrap_val.(ftc[i])...) - ϕ = getproperty(v, ftc[i]) + ϕ = getproperty(v, wrap_val.(ftc[i])) @test all(parent(v)[i_ϕ] .≈ ϕ) end @@ -132,15 +132,15 @@ using ClimateMachine.VariableTemplates: wrap_val @unroll_map(N) do i @test v.scalar_model.x == getproperty(v, (:scalar_model, :x)) @test v.vector_model.x == getproperty(v, (:vector_model, :x)) - @test v.ntuple_model[i] == getproperty(v, (:ntuple_model, unval(i))) + @test v.ntuple_model[i] == getproperty(v, (:ntuple_model, i)) @test v.ntuple_model[i].scalar_model == - getproperty(v, (:ntuple_model, unval(i), :scalar_model)) + getproperty(v, (:ntuple_model, i, :scalar_model)) @test v.ntuple_model[i].scalar_model.x == - getproperty(v, (:ntuple_model, unval(i), :scalar_model, :x)) + getproperty(v, (:ntuple_model, i, :scalar_model, :x)) end # Test converting to flattened NamedTuple - fnt = flattened_named_tuple(v) + fnt = flattened_named_tuple(v, RetainArr()) @test fnt.ntuple_model_1_scalar_model_x == 1.0f0 @test fnt.ntuple_model_1_vector_model_x == Float32[2.0, 3.0, 4.0] @test fnt.ntuple_model_2_scalar_model_x == 5.0f0 @@ -154,4 +154,80 @@ using ClimateMachine.VariableTemplates: wrap_val @test fnt.vector_model_x == Float32[21.0, 22.0, 23.0] @test fnt.scalar_model_x == 24.0f0 + # flattened_tup_chain - Flatten arrays + + ftc = flattened_tup_chain(st, FlattenArr()) + j = 1 + for i in 1:N + @test ftc[j] === (:ntuple_model, i, :scalar_model, :x) + j += 1 + for k in 1:Nv + @test ftc[j] === (:ntuple_model, i, :vector_model, :x, k) + j += 1 + end + end + for i in 1:Nv + @test ftc[j] === (:vector_model, :x, i) + j += 1 + end + @test ftc[j] === (:scalar_model, :x) + + # test varsindex (flatten arrays) + ntuple(N) do i + i_val = Val(i) + i_sm = varsindex(st, :ntuple_model, i_val, :scalar_model, :x) + nt_offset = (Nv + 1) - 1 + + i_sm_correct = (i + nt_offset * (i - 1)):(i + nt_offset * (i - 1)) + @test i_sm == i_sm_correct + + for j in 1:Nv + i_vm = + varsindex(st, :ntuple_model, i_val, :vector_model, :x, Val(j)) + offset = 1 + i_start = i + nt_offset * (i - 1) + offset + i_vm_correct = i_start + j - 1 + @test i_vm == i_vm_correct:i_vm_correct + end + end + + # test that getproperty matches varsindex + ntuple(N) do i + i_ϕ = varsindex(st, wrap_val.(ftc[i])...) + ϕ = getproperty(v, wrap_val.(ftc[i])) + @test all(parent(v)[i_ϕ] .≈ ϕ) + end + + # test getproperty with tup-chain + for k in 1:Nv + @test v.vector_model.x[k] == getproperty(v, (:vector_model, :x, Val(k))) + end + + # Test converting to flattened NamedTuple + fnt = flattened_named_tuple(v, FlattenArr()) + @test fnt.ntuple_model_1_scalar_model_x == 1.0f0 + @test fnt.ntuple_model_1_vector_model_x_1 == 2.0 + @test fnt.ntuple_model_1_vector_model_x_2 == 3.0 + @test fnt.ntuple_model_1_vector_model_x_3 == 4.0 + @test fnt.ntuple_model_2_scalar_model_x == 5.0f0 + @test fnt.ntuple_model_2_vector_model_x_1 == 6.0 + @test fnt.ntuple_model_2_vector_model_x_2 == 7.0 + @test fnt.ntuple_model_2_vector_model_x_3 == 8.0 + @test fnt.ntuple_model_3_scalar_model_x == 9.0f0 + @test fnt.ntuple_model_3_vector_model_x_1 == 10.0 + @test fnt.ntuple_model_3_vector_model_x_2 == 11.0 + @test fnt.ntuple_model_3_vector_model_x_3 == 12.0 + @test fnt.ntuple_model_4_scalar_model_x == 13.0f0 + @test fnt.ntuple_model_4_vector_model_x_1 == 14.0 + @test fnt.ntuple_model_4_vector_model_x_2 == 15.0 + @test fnt.ntuple_model_4_vector_model_x_3 == 16.0 + @test fnt.ntuple_model_5_scalar_model_x == 17.0f0 + @test fnt.ntuple_model_5_vector_model_x_1 == 18.0 + @test fnt.ntuple_model_5_vector_model_x_2 == 19.0 + @test fnt.ntuple_model_5_vector_model_x_3 == 20.0 + @test fnt.vector_model_x_1 == 21.0 + @test fnt.vector_model_x_2 == 22.0 + @test fnt.vector_model_x_3 == 23.0 + @test fnt.scalar_model_x == 24.0f0 + end