Skip to content

Commit

Permalink
foo5
Browse files Browse the repository at this point in the history
  • Loading branch information
谢萧涯 authored and 谢萧涯 committed Sep 10, 2024
1 parent 77757d7 commit 1b4266b
Showing 1 changed file with 85 additions and 220 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -381,8 +381,7 @@ function non_orographic_gravity_wave_forcing(
#We use StaticBitVector here because the unrolled_reduce function in Julia can cause memory allocation issues when the mask has more than 32 elements。
#StaticBitVector stores 8 boolean values in a UInt8, allowing efficient storage for up to 256 gravity wave break data.
level_end = Spaces.nlevels(axes(ᶜρ))
c = gw_c


# loop over all wave lengths
for ink in 1:gw_nk

Expand Down Expand Up @@ -428,12 +427,78 @@ function non_orographic_gravity_wave_forcing(
)

# Accumulate zonal wave forcing in every column
Operators.column_accumulate!(
waveforcing_column_accumulate!(u_waveforcing,mask_u,input_u,gw_c,ink,gw_nk,level_end,gw_c0,gw_ncval)
# Accumulate meridional wave forcing in every column
waveforcing_column_accumulate!(v_waveforcing,mask_v,input_v,gw_c,ink,gw_nk,level_end,gw_c0,gw_ncval)

#extract the momentum flux outside the model top.
u_waveforcing_top = p.scratch.temp_field_level
copyto!(
Fields.field_values(u_waveforcing_top),
Fields.field_values(
Fields.level(
u_waveforcing,
Spaces.nlevels(axes(u_waveforcing)),
),
),
)
fill!(
Fields.level(u_waveforcing, Spaces.nlevels(axes(u_waveforcing))),
0,
)

v_waveforcing_top = p.scratch.temp_field_level
copyto!(
Fields.field_values(v_waveforcing_top),
Fields.field_values(
Fields.level(
v_waveforcing,
Spaces.nlevels(axes(v_waveforcing)),
),
),
)
fill!(
Fields.level(v_waveforcing, Spaces.nlevels(axes(v_waveforcing))),
0,
)

#interpolate the waveforcing from center to face
gw_average!(u_waveforcing, p.scratch.ᶜtemp_scalar)
gw_average!(v_waveforcing, p.scratch.ᶜtemp_scalar)

#The momentum flux outside the model top will be evenly deposited onto the levels between the damp level and the model top.
@. u_waveforcing = gw_deposit(
u_waveforcing_top,
u_waveforcing,
input_u;
init = (FT(0.0), mask_u, FT(NaN), ntuple(i -> FT(NaN), Val(nc))),
damp_level,
ᶜlevel,
level_end,
)
@. v_waveforcing = gw_deposit(
v_waveforcing_top,
v_waveforcing,
damp_level,
ᶜlevel,
level_end,
)

#update gravity wave forcing
@. uforcing = uforcing + u_waveforcing
@. vforcing = vforcing + v_waveforcing

end
return nothing
end

# calculate the intermittency factor eps -> assuming constant Δc.
function waveforcing_column_accumulate!(waveforcing,mask,input,c,ink,nk,level_end,c0,gw_ncval::Val{nc}) where {nc}
FT=eltype(waveforcing)
Operators.column_accumulate!(
waveforcing,
input;
init = (FT(0.0), mask, FT(NaN), ntuple(i -> FT(NaN), Val(nc))),
transform = first,
) do (wave_forcing_u, mask_u, Bsum_or_NaN_u, B0_or_NaNs_u),
) do (wave_forcing, mask, Bsum_or_NaN, B0_or_NaNs),
(
u_kp1,
u_source,
Expand Down Expand Up @@ -462,8 +527,8 @@ function non_orographic_gravity_wave_forcing(
alp2 = FT1(0.25) / (Hb * Hb)
ω_r = sqrt((bf_kp1 * bf_kp1 * k2) / (k2 + alp2)) # omc: (critical frequency that marks total internal reflection)

B0_u, Bsum_u = if level == 1
mask_u = StaticBitVector{nc}(_ -> true)
B0, Bsum = if level == 1
mask = StaticBitVector{nc}(_ -> true)
B1 = ntuple(
n ->
sign((c[n] - u_source)) * (
Expand All @@ -473,7 +538,7 @@ function non_orographic_gravity_wave_forcing(
(
c[n] * flag +
(c[n] - u_source) * (1 - flag) -
gw_c0
c0
) / cw
)^2,
) +
Expand All @@ -483,7 +548,7 @@ function non_orographic_gravity_wave_forcing(
(
c[n] * flag +
(c[n] - u_source) * (1 - flag) -
gw_c0
c0
) / cn
)^2,
)
Expand All @@ -493,16 +558,16 @@ function non_orographic_gravity_wave_forcing(
Bsum1 = sum(abs, B1)
B1, Bsum1
else
B0_or_NaNs_u, Bsum_or_NaN_u
B0_or_NaNs, Bsum_or_NaN
end
#calculate momentum flux carried by gravity waves with different phase speeds.

if level >= source_level - 1
fm = FT1(0.0)
#We use the unrolled_reduce function here because it performs better for parallel execution on the GPU. Additionally, it also helps speed up the setindex function
(mask_u, fm) = unrolled_reduce(
(mask, fm) = unrolled_reduce(
Val(nc),
(mask_u, fm),
(mask, fm),
) do (mask, fm), (n)
if (mask[n]) == true
c_hat = c[n] - u_kp1 # c0mu
Expand All @@ -524,7 +589,7 @@ function non_orographic_gravity_wave_forcing(
# momentum flux is conserved
mask = Base.setindex(mask, false, n)
if level >= source_level
fm = fm + B0_u[n]
fm = fm + B0[n]
end
else
# if wave is not reflected at this level, determine if it is
Expand All @@ -535,11 +600,11 @@ function non_orographic_gravity_wave_forcing(
# this level.
# set mask=0.0 to remove phase speed band c[n] from the set of active
# waves moving upwards to the next level.
Foc = B0_u[n] / (c_hat)^3 - fac
Foc = B0[n] / (c_hat)^3 - fac
if Foc >= 0.0 || (c_hat0 * c_hat <= 0.0)
mask = Base.setindex(mask, false, n)
if level >= source_level
fm = fm + B0_u[n]
fm = fm + B0[n]
end
end
end
Expand All @@ -553,220 +618,20 @@ function non_orographic_gravity_wave_forcing(
# compute the gravity wave momentum flux forcing
# obtained across the entire wave spectrum at this level.
eps =
calc_intermitency(ρ_source, source_ampl, gw_nk, FT1(Bsum_u))
calc_intermitency(ρ_source, source_ampl, nk, FT1(Bsum))
#calculate intermittency factor
if level >= source_level
rbh = sqrt(ρ_k * ρ_kp1)
wave_forcing_u =
wave_forcing =
(ρ_source / rbh) * FT1(fm) * eps / (z_kp1 - z_k)
else
wave_forcing_u = FT1(0.0)
wave_forcing = FT1(0.0)
end
end
return (wave_forcing_u, mask_u, Bsum_u, B0_u)
return (wave_forcing, mask, Bsum, B0)

end

# Accumulate meridional wave forcing in every column
Operators.column_accumulate!(
v_waveforcing,
input_v;
init = (FT(0.0), mask_u, FT(NaN), ntuple(i -> FT(NaN), Val(nc))),
transform = first,
) do (wave_forcing_v, mask_v, Bsum_or_NaN_v, B0_or_NaNs_v),
(
v_kp1,
v_source,
bf_kp1,
ρ_k,
ρ_kp1,
ρ_source,
z_kp1,
z_k,
source_level,
Bw,
Bn,
cw,
cn,
flag,
level,
source_ampl,
)

FT1 = typeof(v_kp1)
kwv = 2.0 * π / ((30.0 * (10.0^ink)) * 1.e3) #wave number of gravity waves
k2 = kwv * kwv

fac = FT1(0.5) * (ρ_kp1 / ρ_source) * kwv / bf_kp1
Hb = (z_kp1 - z_k) / log(ρ_k / ρ_kp1) # density scale height
alp2 = FT1(0.25) / (Hb * Hb)
ω_r = sqrt((bf_kp1 * bf_kp1 * k2) / (k2 + alp2)) # omc: (critical frequency that marks total internal reflection)

B0_v, Bsum_v = if level == 1
mask_v = StaticBitVector{nc}(_ -> true)
B1 = ntuple(
n ->
sign((c[n] - v_source)) * (
Bw * exp(
-log(2.0f0) *
(
(
c[n] * flag +
(c[n] - v_source) * (1 - flag) -
gw_c0
) / cw
)^2,
) +
Bn * exp(
-log(2.0f0) *
(
(
c[n] * flag +
(c[n] - v_source) * (1 - flag) -
gw_c0
) / cn
)^2,
)
),
Val(nc),
)
Bsum1 = sum(abs, B1)
B1, Bsum1
else
B0_or_NaNs_v, Bsum_or_NaN_v
end
#calculate momentum flux carried by gravity waves with different phase speeds.

if level >= source_level - 1
fm = FT1(0.0)
(mask_v, fm) = unrolled_reduce(
Val(nc),
(mask_v, fm),
) do (mask, fm), (n)
if (mask[n]) == true
c_hat = c[n] - v_kp1 # c0mu
# f phase speed matches the wind speed, remove c(n) from the set of propagating waves.
if c_hat == 0.0
mask = Base.setindex(mask, false, n)
else
c_hat0 = c[n] - v_source
# define the criterion which determines if wave is reflected at this level (test).
test = abs(c_hat) * kwv - ω_r
if test >= 0.0
# wave has undergone total internal reflection. remove it from the propagating set.
mask = Base.setindex(mask, false, n)
else
if level == level_end
# this is added in MiMA implementation:
# all momentum flux that escapes across the model top
# is deposited to the extra level being added so that
# momentum flux is conserved
mask = Base.setindex(mask, false, n)
if level >= source_level
fm = fm + B0_v[n]
end
else
# if wave is not reflected at this level, determine if it is
# breaking at this level (Foc >= 0), or if wave speed relative to
# windspeed has changed sign from its value at the source level
# (c_hat0[n] * c_hat <= 0). if it is above the source level and is
# breaking, then add its momentum flux to the accumulated sum at
# this level.
# set mask=0.0 to remove phase speed band c[n] from the set of active
# waves moving upwards to the next level.
Foc = B0_v[n] / (c_hat)^3 - fac
if Foc >= 0.0 || (c_hat0 * c_hat <= 0.0)
mask = Base.setindex(mask, false, n)
if level >= source_level
fm = fm + B0_v[n]
end
end
end
end # (test >= 0.0)

end #(c_hat == 0.0)
end # mask = 0
return (mask, fm)
end

# compute the gravity wave momentum flux forcing
# obtained across the entire wave spectrum at this level.
eps =
calc_intermitency(ρ_source, source_ampl, gw_nk, FT1(Bsum_v))
if level >= source_level
rbh = sqrt(ρ_k * ρ_kp1)
wave_forcing_v =
(ρ_source / rbh) * FT1(fm) * eps / (z_kp1 - z_k)
else
wave_forcing_v = FT1(0.0)
end

end
return (wave_forcing_v, mask_v, Bsum_v, B0_v)

end


u_waveforcing_top = p.scratch.temp_field_level
copyto!(
Fields.field_values(u_waveforcing_top),
Fields.field_values(
Fields.level(
u_waveforcing,
Spaces.nlevels(axes(u_waveforcing)),
),
),
)
fill!(
Fields.level(u_waveforcing, Spaces.nlevels(axes(u_waveforcing))),
0,
)
#extract the momentum flux outside the model top.

v_waveforcing_top = p.scratch.temp_field_level
copyto!(
Fields.field_values(v_waveforcing_top),
Fields.field_values(
Fields.level(
v_waveforcing,
Spaces.nlevels(axes(v_waveforcing)),
),
),
)
fill!(
Fields.level(v_waveforcing, Spaces.nlevels(axes(v_waveforcing))),
0,
)
#extract the momentum flux outside the model top.

gw_average!(u_waveforcing, p.scratch.ᶜtemp_scalar)
gw_average!(v_waveforcing, p.scratch.ᶜtemp_scalar)
#interpolate the waveforcing from center to face

@. u_waveforcing = gw_deposit(
u_waveforcing_top,
u_waveforcing,
damp_level,
ᶜlevel,
level_end,
)
@. v_waveforcing = gw_deposit(
v_waveforcing_top,
v_waveforcing,
damp_level,
ᶜlevel,
level_end,
)
#The momentum flux outside the model top will be evenly deposited onto the levels between the damp level and the model top.

@. uforcing = uforcing + u_waveforcing
@. vforcing = vforcing + v_waveforcing
#update gravity wave forcing

end
return nothing
end
# calculate the intermittency factor eps -> assuming constant Δc.

function calc_intermitency(ρ_source_level, source_ampl, nk, Bsum)
return (source_ampl / ρ_source_level / nk) / Bsum
Expand Down

0 comments on commit 1b4266b

Please sign in to comment.