Skip to content
This repository has been archived by the owner on Mar 1, 2023. It is now read-only.

Commit

Permalink
Extend flattened_tup_chain, update VarTemplates
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jan 15, 2021
1 parent 65a7e65 commit b0c5231
Show file tree
Hide file tree
Showing 10 changed files with 269 additions and 189 deletions.
2 changes: 2 additions & 0 deletions docs/src/APIs/Utilities/VariableTemplates.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ vuntuple
unroll_map
flattened_tup_chain
flattened_named_tuple
FlattenArr
RetainArr
varsize
varsindices
varsindex
Expand Down
35 changes: 10 additions & 25 deletions src/Utilities/SingleStackUtils/single_stack_diagnostics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,31 +54,16 @@ function single_stack_diagnostics(

z = altitude(bl, aux)

# TODO: Use flatten_named_tuple = true, and flatten arrays
flatten_named_tuple = false
if flatten_named_tuple
nt = (;
z = altitude(bl, aux),
prog = flattened_named_tuple(prog), # Vars -> flattened NamedTuples
aux = flattened_named_tuple(aux), # Vars -> flattened NamedTuples
∇flux = flattened_named_tuple(∇flux), # Vars -> flattened NamedTuples
hyperdiff = flattened_named_tuple(hyperdiff), # Vars -> flattened NamedTuples
cache = cache,
)
# Flatten top level:
flattened_named_tuple(nt)
else
nt = (;
z = altitude(bl, aux),
prog = prog,
aux = aux,
∇flux = ∇flux,
hyperdiff = hyperdiff,
cache = cache,
)
nt
end
nt
nt = (;
z = altitude(bl, aux),
prog = flattened_named_tuple(prog), # Vars -> flattened NamedTuples
aux = flattened_named_tuple(aux), # Vars -> flattened NamedTuples
∇flux = flattened_named_tuple(∇flux), # Vars -> flattened NamedTuples
hyperdiff = flattened_named_tuple(hyperdiff), # Vars -> flattened NamedTuples
cache = cache,
)
# Flatten top level:
flattened_named_tuple(nt)
end for local_states in NodalStack(bl, grid; kwargs...)
]
end
132 changes: 66 additions & 66 deletions src/Utilities/VariableTemplates/VariableTemplates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,30 @@ wrap_val(i::Int) = Val(i)
# This means that users _must_ wrap `sym`
# in `Val`, which can be done with `wrap_val`
# above.
unval(::Val{i}) where {i} = i
Base.@propagate_inbounds function varsindex(
::Type{S},
sym,
sym::Symbol,
rest...,
) where {S <: Union{NamedTuple, Tuple}}
if sym isa Symbol
vi = varsindex(fieldtype(S, sym), rest...)
return varsindex(S, sym)[vi]
else
i = unval(sym)
et = eltype(S)
offset = (i - 1) * varsize(et)
vi = varsindex(et, rest...)
return (vi.start + offset):(vi.stop + offset)
end
vi = varsindex(fieldtype(S, sym), rest...)
return varsindex(S, sym)[vi]
end
Base.@propagate_inbounds function varsindex(
::Type{S},
::Val{i},
rest...,
) where {i, S <: Union{NamedTuple, Tuple}}
et = eltype(S)
offset = (i - 1) * varsize(et)
vi = varsindex(et, rest...)
return (vi.start + offset):(vi.stop + offset)
end

Base.@propagate_inbounds function varsindex(
::Type{S},
::Val{i},
) where {i, S <: SArray}
return i:i
end

"""
Expand Down Expand Up @@ -391,70 +399,62 @@ vuntuple(f::F, N::Int) where {F} = ntuple(i -> f(Val(i)), Val(N))
# Inside unroll_map expressions, all indexes `i`
# are wrapped in `Val`, so we must redirect
# these methods:
Base.getindex(t::Tuple, ::Val{i}) where {i} = Base.getindex(t, i)
Base.getindex(a::SArray, ::Val{i}) where {i} = Base.getindex(a, i)

# Somehow needed for GPU...
Base.@propagate_inbounds Base.getindex(v::AbstractVars, i::Int) =
Base.getindex(v, Val(i))
Base.@propagate_inbounds Base.getindex(t::Tuple, ::Val{i}) where {i} =
Base.getindex(t, i)
Base.@propagate_inbounds Base.getindex(a::SArray, ::Val{i}) where {i} =
Base.getindex(a, i)

Base.@propagate_inbounds function Base.getindex(
v::AbstractVars{NTuple{N, T}, A, offset},
v::Vars{NTuple{N, T}, A, offset},
::Val{i},
) where {N, T, A, offset, i}
# 1 <= i <= N
array = parent(v)
if v isa Vars
return Vars{T, A, offset + (i - 1) * varsize(T)}(array)
else
return Grad{T, A, offset + (i - 1) * varsize(T)}(array)
end
end

Base.@propagate_inbounds function Base.getproperty(
v::AbstractVars,
tup_chain::Tuple{S},
) where {S <: Symbol}
return Base.getproperty(v, tup_chain[1])
) where {N, T, A, offset, i} # 1 <= i <= N
return Vars{T, A, offset + (i - 1) * varsize(T)}(parent(v))
end

Base.@propagate_inbounds function Base.getindex(
v::AbstractVars,
tup_chain::Tuple{S},
) where {S <: Int}
return Base.getindex(v, Val(tup_chain[1]))
v::Grad{NTuple{N, T}, A, offset},
::Val{i},
) where {N, T, A, offset, i} # 1 <= i <= N
return Grad{T, A, offset + (i - 1) * varsize(T)}(parent(v))
end

Base.@propagate_inbounds function Base.getproperty(
"""
getpropertyorindex
An interchangeably and nested-friendly
`getproperty`/`getindex`.
"""
function getpropertyorindex end

# Redirect to Base getproperty/getindex:
Base.@propagate_inbounds getpropertyorindex(t::Tuple, ::Val{i}) where {i} =
Base.getindex(t, i)
Base.@propagate_inbounds getpropertyorindex(a::SArray, ::Val{i}) where {i} =
Base.getindex(a, i)
Base.@propagate_inbounds getpropertyorindex(nt::AbstractVars, s::Symbol) =
Base.getproperty(nt, s)
Base.@propagate_inbounds getpropertyorindex(
v::AbstractVars,
tup_chain::Tuple,
)
if tup_chain[1] isa Int
p = Base.getindex(v, Val(tup_chain[1]))
else
p = Base.getproperty(v, tup_chain[1])
end
if tup_chain[2] isa Int
return Base.getindex(p, tup_chain[2:end])
else
return Base.getproperty(p, tup_chain[2:end])
end
end
Base.@propagate_inbounds function Base.getindex(
::Val{i},
) where {i} = Base.getindex(v, Val(i))

# Only one element left:
Base.@propagate_inbounds getpropertyorindex(
v::AbstractVars,
tup_chain::Tuple,
)
if tup_chain[1] isa Int
p = Base.getindex(v, Val(tup_chain[1]))
else
p = Base.getproperty(v, tup_chain[1])
end
if tup_chain[2] isa Int
return Base.getindex(p, tup_chain[2:end])
else
return Base.getproperty(p, tup_chain[2:end])
end
end
t::Tuple{A},
) where {A} = getpropertyorindex(v, t[1])
Base.@propagate_inbounds getpropertyorindex(a::SArray, t::Tuple{A}) where {A} =
getpropertyorindex(a, t[1])

# Peel first element from tuple and recurse:
Base.@propagate_inbounds getpropertyorindex(v::AbstractVars, t::Tuple) =
getpropertyorindex(getpropertyorindex(v, t[1]), Tuple(t[2:end]))

# Redirect to getpropertyorindex:
Base.@propagate_inbounds Base.getproperty(v::AbstractVars, tup_chain::Tuple) =
getpropertyorindex(v, tup_chain)
Base.@propagate_inbounds Base.getindex(v::AbstractVars, tup_chain::Tuple) =
getpropertyorindex(v, tup_chain)

include("flattened_tup_chain.jl")

Expand Down
60 changes: 44 additions & 16 deletions src/Utilities/VariableTemplates/flattened_tup_chain.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,38 @@
export flattened_tup_chain, flattened_named_tuple
export FlattenArr, RetainArr # should this be exported?

flattened_tup_chain(::Type{NamedTuple{(), Tuple{}}}; prefix = (Symbol(),)) = ()
flattened_tup_chain(::Type{T}; prefix = (Symbol(),)) where {T <: Real} =
(prefix,)
flattened_tup_chain(::Type{T}; prefix = (Symbol(),)) where {T <: SArray} =
(prefix,)
abstract type FlattenType end
struct FlattenArr <: FlattenType end
struct RetainArr <: FlattenType end

flattened_tup_chain(
::Type{NamedTuple{(), Tuple{}}},
::FlattenType = FlattenArr();
prefix = (Symbol(),),
) = ()
flattened_tup_chain(
::Type{T},
::FlattenType;
prefix = (Symbol(),),
) where {T <: Real} = (prefix,)
flattened_tup_chain(
::Type{T};
::Type{T},
::RetainArr;
prefix = (Symbol(),),
) where {T <: SArray} = (prefix,)
flattened_tup_chain(
::Type{T},
::FlattenArr;
prefix = (Symbol(),),
) where {T <: SArray} = ntuple(i -> (prefix..., i), length(T))

flattened_tup_chain(
::Type{T},
::FlattenType;
prefix = (Symbol(),),
) where {T <: SHermitianCompact} = (prefix,)
flattened_tup_chain(::Type{T}; prefix = (Symbol(),)) where {T} = (prefix,)
flattened_tup_chain(::Type{T}, ::FlattenType; prefix = (Symbol(),)) where {T} =
(prefix,)

"""
flattened_tup_chain(::Type{T}) where {T <: Union{NamedTuple,NTuple}}
Expand All @@ -19,7 +42,8 @@ and integers for every combination of
each field in the `Vars` array.
"""
function flattened_tup_chain(
::Type{T};
::Type{T},
ft::FlattenType = FlattenArr();
prefix = (Symbol(),),
) where {T <: Union{NamedTuple, NTuple}}
map(1:fieldcount(T)) do i
Expand All @@ -28,16 +52,20 @@ function flattened_tup_chain(
sname = name isa Int ? name : Symbol(name)
flattened_tup_chain(
Ti,
ft;
prefix = prefix == (Symbol(),) ? (sname,) : (prefix..., sname),
)
end |>
Iterators.flatten |>
collect
end
flattened_tup_chain(::AbstractVars{S}) where {S} = flattened_tup_chain(S)
flattened_tup_chain(
::AbstractVars{S},
ft::FlattenType = FlattenArr(),
) where {S} = flattened_tup_chain(S, ft)

"""
flattened_named_tuple(v::AbstractVars)
flattened_named_tuple(v::AbstractVars, ::FlattenType)
A flattened NamedTuple, given a `Vars` instance.
Expand All @@ -60,16 +88,16 @@ fnt = flattened_named_tuple(nt);
"""
function flattened_named_tuple end

function flattened_named_tuple(v::AbstractVars)
ftc = flattened_tup_chain(v)
function flattened_named_tuple(v::AbstractVars, ft::FlattenType = FlattenArr())
ftc = flattened_tup_chain(v, ft)
keys_ = Symbol.(join.(ftc, :_))
vals = map(x -> getproperty(v, x), ftc)
vals = map(x -> getproperty(v, wrap_val.(x)), ftc)
return (; zip(keys_, vals)...)
end
flattened_named_tuple(v::Nothing) = NamedTuple()
flattened_named_tuple(v::Nothing, ft::FlattenType = FlattenArr()) = NamedTuple()

function flattened_named_tuple(nt::NamedTuple)
ftc = flattened_tup_chain(typeof(nt))
function flattened_named_tuple(nt::NamedTuple, ft::FlattenType = FlattenArr())
ftc = flattened_tup_chain(typeof(nt), ft)
keys_ = Symbol.(join.(ftc, :_))
vals = flattened_nt_vals(nt)
return (; zip(keys_, vals)...)
Expand Down
2 changes: 1 addition & 1 deletion test/Atmos/EDMF/closures/entr_detr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ function entr_detr(
env,
buoy,
) where {FT}
EΔ_up = ntuple(n_updrafts(bl.turbconv)) do i
EΔ_up = vuntuple(n_updrafts(bl.turbconv)) do i
entr_detr(bl, bl.turbconv.entr_detr, state, aux, ts_up, ts_en, env, buoy, i)
end
E_dyn, Δ_dyn, E_trb = ntuple(i -> map(x -> x[i], EΔ_up), 3)
Expand Down
2 changes: 1 addition & 1 deletion test/Atmos/EDMF/closures/pressure.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#### Pressure model kernels

function perturbation_pressure(bl::AtmosModel{FT}, args, env, buoy) where {FT}
dpdz = ntuple(n_updrafts(bl.turbconv)) do i
dpdz = vuntuple(n_updrafts(bl.turbconv)) do i
perturbation_pressure(bl, bl.turbconv.pressure, args, env, buoy, i)
end
return dpdz
Expand Down
20 changes: 18 additions & 2 deletions test/Atmos/EDMF/compute_mse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ PyCLES_output_dataset = ArtifactWrapper(
],
)
PyCLES_output_dataset_path = get_data_folder(PyCLES_output_dataset)
data_files = Dict()
data_files[:Bomex] = Dataset(joinpath(PyCLES_output_dataset_path, "Bomex.nc"), "r")
# data_files[:Rico] = Dataset(joinpath(PyCLES_output_dataset_path, "Rico.nc"), "r")
# data_files[:Gabls] = Dataset(joinpath(PyCLES_output_dataset_path, "Gabls.nc"), "r")
# data_files[:DYCOMS_RF01] = Dataset(joinpath(PyCLES_output_dataset_path, "DYCOMS_RF01.nc"), "r")
Expand Down Expand Up @@ -222,3 +220,21 @@ function compute_mse(

return mse
end

sufficient_mse(computed_mse, best_mse) = computed_mse <= best_mse + sqrt(eps())

function test_mse(computed_mse, best_mse, key)
mse_not_regressed = sufficient_mse(computed_mse[key], best_mse[key])
@test mse_not_regressed
mse_not_regressed || @show key
end

function dons(diag_vs_z)
return Dict(map(keys(first(diag_vs_z))) do k
string(k) => [getproperty(ca, k) for ca in diag_vs_z]
end)
end

get_dons_arr(diag_arr) = [dons(diag_vs_z) for diag_vs_z in diag_arr]

dons_arr = get_dons_arr(diag_arr)
Loading

0 comments on commit b0c5231

Please sign in to comment.