From f977bfc2080ec0b2fe1fae07397e1f7694d1e788 Mon Sep 17 00:00:00 2001 From: Gabriele Bozzola Date: Sun, 1 Sep 2024 16:49:27 -0700 Subject: [PATCH] Make GWP GPU-compatible --- .buildkite/pipeline.yml | 12 + Project.toml | 2 + .../sphere_nonorographic_gravity_wave.yml | 3 + examples/Manifest.toml | 15 +- perf/Manifest.toml | 15 +- regression_tests/ref_counter.jl | 5 +- .../non_orographic_gravity_wave.jl | 492 ++++++++---------- src/solver/types.jl | 2 +- .../nogw_test_3d.jl | 5 +- .../nogw_test_mima.jl | 8 +- .../nogw_test_single_column.jl | 64 ++- 11 files changed, 313 insertions(+), 310 deletions(-) create mode 100644 config/model_configs/sphere_nonorographic_gravity_wave.yml diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index e0d0ceac9b5..fae7612d9ad 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -761,6 +761,18 @@ steps: - group: "GPU" steps: + - label: "GPU:Gravity waves" + command: > + julia --color=yes --project=examples examples/hybrid/driver.jl + --config_file $CONFIG_PATH/sphere_nonorographic_gravity_wave.yml + --job_id sphere_nonorographic_gravity_wave + artifact_paths: "sphere_nonorographic_gravity_wave/*" + env: + CLIMACOMMS_DEVICE: "CUDA" + agents: + slurm_gpus: 1 + slurm_mem: 16G + - label: "GPU: baroclinic wave" key: "sphere_baroclinic_wave_rhoe_gpu" command: > diff --git a/Project.toml b/Project.toml index adf90ef5e9d..599c5ca30c0 100644 --- a/Project.toml +++ b/Project.toml @@ -32,6 +32,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" SurfaceFluxes = "49b00bb7-8bd4-4f2b-b78c-51cd0450215f" Thermodynamics = "b60c26fb-14c3-4610-9d3e-2d17fe7ff00c" +UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" [compat] @@ -63,5 +64,6 @@ StaticArrays = "1.7" Statistics = "1" SurfaceFluxes = "0.11, 0.12" Thermodynamics = "0.12.4" +UnrolledUtilities = "0.1.4" YAML = "0.4" julia = "1.9" diff --git 
a/config/model_configs/sphere_nonorographic_gravity_wave.yml b/config/model_configs/sphere_nonorographic_gravity_wave.yml new file mode 100644 index 00000000000..89b93e04080 --- /dev/null +++ b/config/model_configs/sphere_nonorographic_gravity_wave.yml @@ -0,0 +1,3 @@ +t_end: "1500secs" +dt: "400secs" +non_orographic_gravity_wave: true diff --git a/examples/Manifest.toml b/examples/Manifest.toml index 987e8e04e2c..02378fd8754 100644 --- a/examples/Manifest.toml +++ b/examples/Manifest.toml @@ -315,7 +315,7 @@ version = "0.5.7" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" [[deps.ClimaAtmos]] -deps = ["Adapt", "ArgParse", "ArtifactWrappers", "Artifacts", "AtmosphericProfilesLibrary", "ClimaComms", "ClimaCore", "ClimaDiagnostics", "ClimaParams", "ClimaTimeSteppers", "ClimaUtilities", "CloudMicrophysics", "Dates", "DiffEqBase", "FastGaussQuadrature", "Insolation", "Interpolations", "LazyArtifacts", "LinearAlgebra", "Logging", "NCDatasets", "NVTX", "RRTMGP", "SciMLBase", "StaticArrays", "Statistics", "SurfaceFluxes", "Thermodynamics", "YAML"] +deps = ["Adapt", "ArgParse", "ArtifactWrappers", "Artifacts", "AtmosphericProfilesLibrary", "ClimaComms", "ClimaCore", "ClimaDiagnostics", "ClimaParams", "ClimaTimeSteppers", "ClimaUtilities", "CloudMicrophysics", "Dates", "DiffEqBase", "FastGaussQuadrature", "Insolation", "Interpolations", "LazyArtifacts", "LinearAlgebra", "Logging", "NCDatasets", "NVTX", "RRTMGP", "SciMLBase", "StaticArrays", "Statistics", "SurfaceFluxes", "Thermodynamics", "UnrolledUtilities", "YAML"] path = ".." 
uuid = "b2c96348-7fb7-4fe0-8da9-78d88439e717" version = "0.27.5" @@ -332,9 +332,9 @@ weakdeps = ["CUDA", "MPI"] [[deps.ClimaCore]] deps = ["Adapt", "BandedMatrices", "BlockArrays", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "MultiBroadcastFusion", "NVTX", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "StaticArrays", "Statistics", "Unrolled"] -git-tree-sha1 = "ffd27299555f968f96e348060146228c6259bb4b" +git-tree-sha1 = "806e8490ff1aa664ca579544d798f8addfa1b07d" uuid = "d414da3d-4745-48bb-8d80-42e94e092884" -version = "0.14.14" +version = "0.14.15" weakdeps = ["CUDA", "Krylov"] [deps.ClimaCore.extensions] @@ -2434,6 +2434,15 @@ git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b" uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8" version = "0.1.5" +[[deps.UnrolledUtilities]] +git-tree-sha1 = "d0b2aa2d71fa2f4494cb3cf69719a6807ea0df40" +uuid = "0fe1646c-419e-43be-ac14-22321958931b" +version = "0.1.4" +weakdeps = ["StaticArrays"] + + [deps.UnrolledUtilities.extensions] + UnrolledUtilitiesStaticArraysExt = "StaticArrays" + [[deps.UnsafeAtomics]] git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" diff --git a/perf/Manifest.toml b/perf/Manifest.toml index ea212e3255d..04808383b73 100644 --- a/perf/Manifest.toml +++ b/perf/Manifest.toml @@ -326,7 +326,7 @@ version = "0.5.6" GeoMakie = "db073c08-6b98-4ee5-b6a4-5efafb3259c6" [[deps.ClimaAtmos]] -deps = ["Adapt", "ArgParse", "ArtifactWrappers", "Artifacts", "AtmosphericProfilesLibrary", "ClimaComms", "ClimaCore", "ClimaDiagnostics", "ClimaParams", "ClimaTimeSteppers", "ClimaUtilities", "CloudMicrophysics", "Dates", "DiffEqBase", "FastGaussQuadrature", "Insolation", "Interpolations", "LazyArtifacts", "LinearAlgebra", "Logging", "NCDatasets", "NVTX", "RRTMGP", "SciMLBase", "StaticArrays", "Statistics", 
"SurfaceFluxes", "Thermodynamics", "YAML"] +deps = ["Adapt", "ArgParse", "ArtifactWrappers", "Artifacts", "AtmosphericProfilesLibrary", "ClimaComms", "ClimaCore", "ClimaDiagnostics", "ClimaParams", "ClimaTimeSteppers", "ClimaUtilities", "CloudMicrophysics", "Dates", "DiffEqBase", "FastGaussQuadrature", "Insolation", "Interpolations", "LazyArtifacts", "LinearAlgebra", "Logging", "NCDatasets", "NVTX", "RRTMGP", "SciMLBase", "StaticArrays", "Statistics", "SurfaceFluxes", "Thermodynamics", "UnrolledUtilities", "YAML"] path = ".." uuid = "b2c96348-7fb7-4fe0-8da9-78d88439e717" version = "0.27.5" @@ -343,9 +343,9 @@ weakdeps = ["CUDA", "MPI"] [[deps.ClimaCore]] deps = ["Adapt", "BandedMatrices", "BlockArrays", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "MultiBroadcastFusion", "NVTX", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "StaticArrays", "Statistics", "Unrolled"] -git-tree-sha1 = "ffd27299555f968f96e348060146228c6259bb4b" +git-tree-sha1 = "806e8490ff1aa664ca579544d798f8addfa1b07d" uuid = "d414da3d-4745-48bb-8d80-42e94e092884" -version = "0.14.14" +version = "0.14.15" weakdeps = ["CUDA", "Krylov"] [deps.ClimaCore.extensions] @@ -2563,6 +2563,15 @@ git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b" uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8" version = "0.1.5" +[[deps.UnrolledUtilities]] +git-tree-sha1 = "d0b2aa2d71fa2f4494cb3cf69719a6807ea0df40" +uuid = "0fe1646c-419e-43be-ac14-22321958931b" +version = "0.1.4" +weakdeps = ["StaticArrays"] + + [deps.UnrolledUtilities.extensions] + UnrolledUtilitiesStaticArraysExt = "StaticArrays" + [[deps.UnsafeAtomics]] git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" diff --git a/regression_tests/ref_counter.jl b/regression_tests/ref_counter.jl index 0dc76f15890..efc6e6547bb 100644 --- 
a/regression_tests/ref_counter.jl +++ b/regression_tests/ref_counter.jl @@ -1,6 +1,9 @@ -176 +177 #= +177: +- change numerics of non-orographic gravity waves + 176: - Switch to hyperbolic tangent stretching diff --git a/src/parameterized_tendencies/gravity_wave_drag/non_orographic_gravity_wave.jl b/src/parameterized_tendencies/gravity_wave_drag/non_orographic_gravity_wave.jl index 14ee475fd87..b81326b5603 100644 --- a/src/parameterized_tendencies/gravity_wave_drag/non_orographic_gravity_wave.jl +++ b/src/parameterized_tendencies/gravity_wave_drag/non_orographic_gravity_wave.jl @@ -1,11 +1,11 @@ ##### ##### Non-orographic gravity wave parameterization ##### - +using UnrolledUtilities import ClimaCore.Spaces as Spaces import ClimaCore.Fields as Fields import ClimaCore.Geometry as Geometry -import ClimaCore.Operators as Operators +import ClimaCore.Operators as Operators non_orographic_gravity_wave_cache(Y, atmos::AtmosModel) = non_orographic_gravity_wave_cache( @@ -27,7 +27,7 @@ function non_orographic_gravity_wave_cache( (; source_height, Bw, Bn, Bt_0, dc, cmax, c0, nk, cw, cn) = gw nc = Int(floor(FT(2 * cmax / dc + 1))) - c = [FT((n - 1) * dc - cmax) for n in 1:nc] + c = ntuple(n -> FT((n - 1) * dc - cmax), Val(nc)) source_ρ_z_u_v_level = similar(Fields.level(Y.c.ρ, 1), Tuple{FT, FT, FT, FT, FT}) ᶜlevel = similar(Y.c.ρ, FT) @@ -40,7 +40,6 @@ function non_orographic_gravity_wave_cache( gw_source_ampl = Bt_0 .* ones(FT, axes(Fields.level(Y.c.ρ, 1))), gw_Bw = Bw .* ones(FT, axes(Fields.level(Y.c.ρ, 1))), gw_Bn = Bn .* ones(FT, axes(Fields.level(Y.c.ρ, 1))), - gw_B0 = similar(c), gw_c = c, gw_dc = dc, gw_cmax = cmax, @@ -74,7 +73,7 @@ function non_orographic_gravity_wave_cache( (; ϕ0_s, ϕ0_n, dϕ_n, dϕ_s, dc, cmax, c0, nk, cw, cw_tropics, cn) = gw nc = Int(floor(FT(2 * cmax / dc + 1))) - c = [FT((n - 1) * dc - cmax) for n in 1:nc] + c = ntuple(n -> FT((n - 1) * dc - cmax), Val(nc)) ᶜlocal_geometry = Fields.local_geometry_field(Fields.level(Y.c, 1)) lat = 
ᶜlocal_geometry.coordinates.lat @@ -121,7 +120,6 @@ function non_orographic_gravity_wave_cache( gw_source_ampl = source_ampl, gw_Bw = gw_Bw, gw_Bn = gw_Bn, - gw_B0 = similar(c), gw_c = c, gw_cw = gw_cw, gw_cn = gw_cn, @@ -169,13 +167,6 @@ function non_orographic_gravity_wave_tendency!( ) = p.non_orographic_gravity_wave (; model_config) = p.atmos - if model_config isa SingleColumnModel - (; gw_source_height, source_ρ_z_u_v_level) = - p.non_orographic_gravity_wave - elseif model_config isa SphericalModel - (; gw_source_pressure, gw_damp_pressure, source_p_ρ_z_u_v_level) = - p.non_orographic_gravity_wave - end ᶜρ = Y.c.ρ ᶜz = Fields.coordinate_field(Y.c).z FT = Spaces.undertype(axes(Y.c)) @@ -202,6 +193,8 @@ function non_orographic_gravity_wave_tendency!( if model_config isa SingleColumnModel # source level: the index of the level that is closest to the source height + (; gw_source_height, source_ρ_z_u_v_level) = + p.non_orographic_gravity_wave input = Base.Broadcast.broadcasted(tuple, ᶜρ, ᶜz, ᶜu, ᶜv, ᶜlevel) Operators.column_reduce!( @@ -219,11 +212,14 @@ function non_orographic_gravity_wave_tendency!( ᶜu_source = source_ρ_z_u_v_level.:3 ᶜv_source = source_ρ_z_u_v_level.:4 source_level = source_ρ_z_u_v_level.:5 + # get the ρ,u,v value on the source level fill!(damp_level, Spaces.nlevels(axes(ᶜz))) elseif model_config isa SphericalModel (; ᶜp) = p.precomputed + (; gw_source_pressure, gw_damp_pressure, source_p_ρ_z_u_v_level) = + p.non_orographic_gravity_wave # source level: the index of the highest level whose pressure is higher than source pressure input = Base.Broadcast.broadcasted(tuple, ᶜp, ᶜρ, ᶜz, ᶜu, ᶜv, ᶜlevel) @@ -243,6 +239,7 @@ function non_orographic_gravity_wave_tendency!( ᶜu_source = source_p_ρ_z_u_v_level.:4 ᶜv_source = source_p_ρ_z_u_v_level.:5 source_level = source_p_ρ_z_u_v_level.:6 + # get the ρ,u,v value on the source level # damp level: the index of the lowest level whose pressure is lower than the damp pressure @@ -311,6 +308,7 @@ function 
non_orographic_gravity_wave_forcing( v_waveforcing, p, ) where {nc} + # unpack parameters (; gw_source_ampl, gw_Bw, @@ -322,13 +320,17 @@ function non_orographic_gravity_wave_forcing( gw_c0, gw_nk, ) = p.non_orographic_gravity_wave + + # Temporary scratch fields for shifting levels up ᶜρ_p1 = p.scratch.ᶜtemp_scalar ᶜz_p1 = p.scratch.ᶜtemp_scalar_2 ᶜu_p1 = p.scratch.ᶜtemp_scalar_3 ᶜv_p1 = p.scratch.ᶜtemp_scalar_4 ᶜbf_p1 = p.scratch.ᶜtemp_scalar_5 - nci = get_nc(gw_ncval) - FT = eltype(ᶜρ) + + FT = eltype(ᶜρ) # Define the floating point type + + # Using the interpolation operators, generate the fields of ρ, u, v, z with one level shifted up ρ_endlevel = Fields.level(ᶜρ, Spaces.nlevels(axes(ᶜρ))) ρ_endlevel_m1 = Fields.level(ᶜρ, Spaces.nlevels(axes(ᶜρ)) - 1) Boundary_value = Fields.Field( @@ -368,12 +370,16 @@ function non_orographic_gravity_wave_forcing( ) field_shiftlevel_up!(ᶜz, ᶜz_p1, Boundary_value) - mask = BitVector(ones(nc)) + mask_u = StaticBitVector{nc}(_ -> true) + mask_v = StaticBitVector{nc}(_ -> true) + # We use StaticBitVector here because the unrolled_reduce function in Julia can cause memory allocation issues when the mask has more than 32 elements. + # StaticBitVector stores 8 boolean values in a UInt8, allowing efficient storage for up to 256 gravity wave break data. 
level_end = Spaces.nlevels(axes(ᶜρ)) - B1 = ntuple(i -> 0.0, Val(nc)) + # loop over all wave lengths for ink in 1:gw_nk + # Collect all required fields in a broadcasted object input_u = Base.Broadcast.broadcasted( tuple, ᶜu_p1, @@ -389,13 +395,11 @@ function non_orographic_gravity_wave_forcing( gw_Bn, gw_cw, gw_cn, - gw_c, gw_flag, ᶜlevel, gw_source_ampl, ) - input_v = Base.Broadcast.broadcasted( tuple, ᶜv_p1, @@ -411,270 +415,37 @@ function non_orographic_gravity_wave_forcing( gw_Bn, gw_cw, gw_cn, - gw_c, gw_flag, ᶜlevel, gw_source_ampl, ) - Operators.column_accumulate!( + # Accumulate zonal wave forcing in every column + waveforcing_column_accumulate!( u_waveforcing, - input_u; - init = (FT(0.0), mask, 0.0, B1), - transform = first, - ) do (wave_forcing, mask, Bsum, B0), - ( - u_kp1, - u_source, - bf_kp1, - ρ_k, - ρ_kp1, - ρ_source, - z_kp1, - z_k, - source_level, - Bw, - Bn, - cw, - cn, - c, - flag, - level, - source_ampl, + mask_u, + input_u, + gw_c, + gw_c0, + gw_nk, + ink, + level_end, + gw_ncval, ) - if level >= (source_level - 1) - FT1 = typeof(u_kp1) - kwv = 2.0 * π / ((30.0 * (10.0^ink)) * 1.e3) - k2 = kwv * kwv - fac = FT1(0.5) * (ρ_kp1 / ρ_source) * kwv / bf_kp1 - Hb = (z_kp1 - z_k) / log(ρ_k / ρ_kp1) # density scale height - alp2 = FT1(0.25) / (Hb * Hb) - ω_r = sqrt((bf_kp1 * bf_kp1 * k2) / (k2 + alp2)) # omc: (critical frequency that marks total internal reflection) - fm = 0.0 - if level == (source_level - 1) - mask .= 1 - Bsum = 0.0 - B0 = ntuple( - n -> - sign(c[n] - u_source) * ( - Bw * exp( - -log(2.0) * - ( - ( - c[n] * flag + - (c[n] - u_source) * (1 - flag) - - gw_c0 - ) / cw - )^2, - ) + - Bn * exp( - -log(2.0) * - ( - ( - c[n] * flag + - (c[n] - u_source) * (1 - flag) - - gw_c0 - ) / cn - )^2, - ) - ), - Val(nc), - ) - Bsum = sum(abs.(B0)) - end - for n in 1:nci - # check only those waves which are still propagating, i.e., mask = 1.0 - if (mask[n]) == 1 - c_hat = c[n] - u_kp1 # c0mu - # f phase speed matches the wind speed, remove c(n) from 
the set of propagating waves. - if c_hat == 0.0 - mask[n] = 0 - else - c_hat0 = c[n] - u_source - # define the criterion which determines if wave is reflected at this level (test). - test = abs(c_hat) * kwv - ω_r - if test >= 0.0 - # wave has undergone total internal reflection. remove it from the propagating set. - mask[n] = 0 - else - if level == level_end - # this is added in MiMA implementation: - # all momentum flux that escapes across the model top - # is deposited to the extra level being added so that - # momentum flux is conserved - mask[n] = 0 - if level >= source_level - fm = fm + B0[n] - end - else - # if wave is not reflected at this level, determine if it is - # breaking at this level (Foc >= 0), or if wave speed relative to - # windspeed has changed sign from its value at the source level - # (c_hat0[n] * c_hat <= 0). if it is above the source level and is - # breaking, then add its momentum flux to the accumulated sum at - # this level. - # set mask=0.0 to remove phase speed band c[n] from the set of active - # waves moving upwards to the next level. - Foc = B0[n] / FT1((c_hat)^3) - fac - if Foc >= 0.0 || (c_hat0 * c_hat <= 0.0) - mask[n] = 0 - if level >= source_level - fm = fm + B0[n] - end - end - end - end # (test >= 0.0) - end #(c_hat == 0.0) - end # mask = 0 - - end # nc: phase speed loop - - # compute the gravity wave momentum flux forcing - # obtained across the entire wave spectrum at this level. 
- eps = calc_intermitency(ρ_source, source_ampl, gw_nk, FT1(Bsum)) - if level >= source_level - rbh = sqrt(ρ_k * ρ_kp1) - wave_forcing = - (ρ_source / rbh) * FT1(fm) * eps / (z_kp1 - z_k) - else - wave_forcing = FT1(0.0) - end - end - return (wave_forcing, mask, Bsum, B0) - - end - - - - Operators.column_accumulate!( + # Accumulate meridional wave forcing in every column + waveforcing_column_accumulate!( v_waveforcing, - input_v; - init = (FT(0.0), mask, 0.0, B1), - transform = first, - ) do (wave_forcing, mask, Bsum, B0), - ( - u_kp1, - u_source, - bf_kp1, - ρ_k, - ρ_kp1, - ρ_source, - z_kp1, - z_k, - source_level, - Bw, - Bn, - cw, - cn, - c, - flag, - level, - source_ampl, + mask_v, + input_v, + gw_c, + gw_c0, + gw_nk, + ink, + level_end, + gw_ncval, ) - if level >= (source_level - 1) - FT2 = typeof(u_kp1) - kwv = 2.0 * π / ((30.0 * (10.0^ink)) * 1.e3) - k2 = kwv * kwv - fac = FT2(0.5) * (ρ_kp1 / ρ_source) * kwv / bf_kp1 - Hb = (z_kp1 - z_k) / log(ρ_k / ρ_kp1) # density scale height - alp2 = FT2(0.25) / (Hb * Hb) - ω_r = sqrt((bf_kp1 * bf_kp1 * k2) / (k2 + alp2)) # omc: (critical frequency that marks total internal reflection) - - fm = 0.0 - if level == (source_level - 1) - mask .= 1 - Bsum = 0.0 - B0 = ntuple( - n -> - sign((c[n] - u_source)) * ( - Bw * exp( - -log(2.0) * - ( - ( - c[n] * flag + - (c[n] - u_source) * (1 - flag) - - gw_c0 - ) / cw - )^2, - ) + - Bn * exp( - -log(2.0) * - ( - ( - c[n] * flag + - (c[n] - u_source) * (1 - flag) - - gw_c0 - ) / cn - )^2, - ) - ), - Val(nc), - ) - Bsum = sum(abs.(B0)) - end - for n in 1:nci - # check only those waves which are still propagating, i.e., mask = 1.0 - if (mask[n]) == 1 - c_hat = c[n] - u_kp1 # c0mu - # f phase speed matches the wind speed, remove c(n) from the set of propagating waves. - if c_hat == 0.0 - mask[n] = 0 - else - c_hat0 = c[n] - u_source - # define the criterion which determines if wave is reflected at this level (test). 
- test = abs(c_hat) * kwv - ω_r - if test >= 0.0 - # wave has undergone total internal reflection. remove it from the propagating set. - mask[n] = 0 - else - if level == level_end - # this is added in MiMA implementation: - # all momentum flux that escapes across the model top - # is deposited to the extra level being added so that - # momentum flux is conserved - mask[n] = 0 - if level >= source_level - fm = fm + B0[n] - end - else - # if wave is not reflected at this level, determine if it is - # breaking at this level (Foc >= 0), or if wave speed relative to - # windspeed has changed sign from its value at the source level - # (c_hat0[n] * c_hat <= 0). if it is above the source level and is - # breaking, then add its momentum flux to the accumulated sum at - # this level. - # set mask=0.0 to remove phase speed band c[n] from the set of active - # waves moving upwards to the next level. - Foc = B0[n] / FT2((c_hat)^3) - fac - if Foc >= 0.0 || (c_hat0 * c_hat <= 0.0) - mask[n] = 0 - if level >= source_level - fm = fm + B0[n] - end - end - end - end # (test >= 0.0) - end #(c_hat == 0.0) - end # mask = 0 - - end # nc: phase speed loop - - # compute the gravity wave momentum flux forcing - # obtained across the entire wave spectrum at this level. - eps = calc_intermitency(ρ_source, source_ampl, gw_nk, FT2(Bsum)) - if level >= source_level - rbh = sqrt(ρ_k * ρ_kp1) - wave_forcing = - (ρ_source / rbh) * FT2(fm) * eps / (z_kp1 - z_k) - else - wave_forcing = FT2(0.0) - end - end - return (wave_forcing, mask, Bsum, B0) - - end - + #extract the momentum flux outside the model top. 
u_waveforcing_top = p.scratch.temp_field_level copyto!( Fields.field_values(u_waveforcing_top), @@ -690,9 +461,6 @@ function non_orographic_gravity_wave_forcing( 0, ) - # v_waveforcing_top = similar( - # Fields.level(v_waveforcing, Spaces.nlevels(axes(v_waveforcing))), - # ) v_waveforcing_top = p.scratch.temp_field_level copyto!( Fields.field_values(v_waveforcing_top), @@ -708,9 +476,11 @@ function non_orographic_gravity_wave_forcing( 0, ) + # interpolate the waveforcing from center to face gw_average!(u_waveforcing, p.scratch.ᶜtemp_scalar) gw_average!(v_waveforcing, p.scratch.ᶜtemp_scalar) + # The momentum flux outside the model top will be evenly deposited onto the levels between the damp level and the model top. @. u_waveforcing = gw_deposit( u_waveforcing_top, u_waveforcing, @@ -726,6 +496,7 @@ function non_orographic_gravity_wave_forcing( level_end, ) + # update gravity wave forcing @. uforcing = uforcing + u_waveforcing @. vforcing = vforcing + v_waveforcing @@ -733,17 +504,143 @@ function non_orographic_gravity_wave_forcing( return nothing end -# calculate the intermittency factor eps -> assuming constant Δc. +# Using column_accumulate function, calculate the gravity wave forcing at each point. +function waveforcing_column_accumulate!( + waveforcing, + mask, + input, + c, + c0, + nk, + ink, + level_end, + gw_ncval::Val{nc}, +) where {nc} + FT = eltype(waveforcing) + # Here we use column_accumulate function to pass the variable B0 and mask through different levels, and calculate waveforcing at each level. 
+ Operators.column_accumulate!( + waveforcing, + input; + init = (FT(0.0), mask, FT(NaN), ntuple(i -> FT(NaN), Val(nc))), + transform = first, + ) do (wave_forcing, mask, Bsum_or_NaN, B0_or_NaNs), + ( + u_kp1, + u_source, + bf_kp1, + ρ_k, + ρ_kp1, + ρ_source, + z_kp1, + z_k, + source_level, + Bw, + Bn, + cw, + cn, + flag, + level, + source_ampl, + ) + + FT1 = typeof(u_kp1) + kwv = 2.0 * π / ((30.0 * (10.0^ink)) * 1.e3) # wave number of gravity waves + k2 = kwv * kwv + + fac = FT1(0.5) * (ρ_kp1 / ρ_source) * kwv / bf_kp1 + Hb = (z_kp1 - z_k) / log(ρ_k / ρ_kp1) # density scale height + alp2 = FT1(0.25) / (Hb * Hb) + ω_r = sqrt((bf_kp1 * bf_kp1 * k2) / (k2 + alp2)) # omc: (critical frequency that marks total internal reflection) + + # calculate momentum flux carried by gravity waves with different phase speeds. + B0, Bsum = if level == 1 + mask = StaticBitVector{nc}(_ -> true) + B1 = + wave_source(c, u_source, Bw, Bn, cw, cn, c0, flag, gw_ncval) + Bsum1 = sum(abs, B1) + B1, Bsum1 + else + B0_or_NaNs, Bsum_or_NaN + end + + if level >= source_level - 1 + # check break condition for each gravity wave and calculate momentum flux of breaking gravity waves at each level + # We use the unrolled_reduce function here because it performs better for parallel execution on the GPU, avoiding type instabilities. + (mask, fm) = + unrolled_reduce(Val(nc), (mask, FT1(0.0))) do (mask, fm), (n) + if (mask[n]) == true + c_hat = c[n] - u_kp1 # c0mu + # if phase speed matches the wind speed, remove c[n] from the set of propagating waves. + if c_hat == 0.0 + mask = Base.setindex(mask, false, n) + else + c_hat0 = c[n] - u_source + # define the criterion which determines if wave is reflected at this level (test). + test = abs(c_hat) * kwv - ω_r + if test >= 0.0 + # wave has undergone total internal reflection. remove it from the propagating set. 
+ mask = Base.setindex(mask, false, n) + else + if level == level_end + # this is added in MiMA implementation: + # all momentum flux that escapes across the model top + # is deposited to the extra level being added so that + # momentum flux is conserved + mask = Base.setindex(mask, false, n) + if level >= source_level + fm = fm + B0[n] + end + else + # if wave is not reflected at this level, determine if it is + # breaking at this level (Foc >= 0), or if wave speed relative to + # windspeed has changed sign from its value at the source level + # (c_hat0[n] * c_hat <= 0). if it is above the source level and is + # breaking, then add its momentum flux to the accumulated sum at + # this level. + # set mask=0.0 to remove phase speed band c[n] from the set of active + # waves moving upwards to the next level. + Foc = B0[n] / (c_hat)^3 - fac + if Foc >= 0.0 || (c_hat0 * c_hat <= 0.0) + mask = Base.setindex(mask, false, n) + if level >= source_level + fm = fm + B0[n] + end + end + end + end # (test >= 0.0) + + end #(c_hat == 0.0) + end # mask = 0 + return (mask, fm) + end + + # compute the gravity wave momentum flux forcing + # obtained across the entire wave spectrum at this level. + eps = calc_intermitency(ρ_source, source_ampl, nk, FT1(Bsum)) + #calculate intermittency factor + if level >= source_level + rbh = sqrt(ρ_k * ρ_kp1) + wave_forcing = (ρ_source / rbh) * FT1(fm) * eps / (z_kp1 - z_k) + else + wave_forcing = FT1(0.0) + end + end + return (wave_forcing, mask, Bsum, B0) + + end +end +# calculate the intermittency factor eps -> assuming constant Δc. function calc_intermitency(ρ_source_level, source_ampl, nk, Bsum) return (source_ampl / ρ_source_level / nk) / Bsum end function gw_average!(wave_forcing, wave_forcing_m1) - L1 = Operators.LeftBiasedC2F(; bottom = Operators.SetValue(0.0)) + FT = eltype(wave_forcing) + L1 = Operators.LeftBiasedC2F(; bottom = Operators.SetValue(FT(0.0))) L2 = Operators.LeftBiasedF2C(;) wave_forcing_m1 .= L2.(L1.(wave_forcing)) - @. 
wave_forcing = 0.5 * (wave_forcing + wave_forcing_m1) + @. wave_forcing = FT(0.5) * (wave_forcing + wave_forcing_m1) end @@ -755,12 +652,41 @@ function gw_deposit(wave_forcing_top, wave_forcing, damp_level, level, height) return wave_forcing end -function get_nc(::Val{nc}) where {nc} - nc -end - function field_shiftlevel_up!(ᶜexample_field, ᶜshifted_field, Boundary_value) R1 = Operators.RightBiasedC2F(; top = Operators.SetValue(Boundary_value)) R2 = Operators.RightBiasedF2C(;) ᶜshifted_field .= R2.(R1.(ᶜexample_field)) end + +function wave_source( + c, + u_source, + Bw, + Bn, + cw, + cn, + c0, + flag, + gw_ncval::Val{nc}, +) where {nc} + ntuple( + n -> + sign((c[n] - u_source)) * ( + Bw * exp( + -log(2.0f0) * + ( + (c[n] * flag + (c[n] - u_source) * (1 - flag) - c0) / + cw + )^2, + ) + + Bn * exp( + -log(2.0f0) * + ( + (c[n] * flag + (c[n] - u_source) * (1 - flag) - c0) / + cn + )^2, + ) + ), + Val(nc), + ) +end diff --git a/src/solver/types.jl b/src/solver/types.jl index f887e308028..6405a54377f 100644 --- a/src/solver/types.jl +++ b/src/solver/types.jl @@ -87,7 +87,7 @@ Base.@kwdef struct NonOrographyGravityWave{FT} <: AbstractGravityWave source_height::FT = 15000 Bw::FT = 1.0 Bn::FT = 1.0 - dc::FT = 0.6 + dc::FT = 0.8 cmax::FT = 99.6 c0::FT = 0 nk::FT = 1 diff --git a/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_3d.jl b/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_3d.jl index bee52d3ccd0..47244688a13 100644 --- a/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_3d.jl +++ b/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_3d.jl @@ -133,6 +133,7 @@ center_u_zonalave = mean(center_u, dims = 1)[1, :, :, :] center_bf_zonalave = mean(center_bf, dims = 1)[1, :, :, :] center_ρ_zonalave = mean(center_ρ, dims = 1)[1, :, :, :] +#generate domain, space and field column_domain = ClimaCore.Domains.IntervalDomain( 
ClimaCore.Geometry.ZPoint(0.0) .. ClimaCore.Geometry.ZPoint(47000), boundary_names = (:bottom, :top), @@ -153,7 +154,7 @@ gw_ncval = Val(500) ᶜv = copy(ᶜz) ᶜbf = copy(ᶜz) ᶜlevel = similar(ᶜρ, FT) -u_waveforcing = similar(ᶜu) +u_waveforcing = similar(ᶜv) v_waveforcing = similar(ᶜv) for i in 1:Spaces.nlevels(axes(ᶜρ)) fill!(Fields.level(ᶜlevel, i), i) @@ -170,9 +171,9 @@ scratch = (; temp_field_level = similar(Fields.level(ᶜz, 1), FT), ) +# create input parameter params = (; non_orographic_gravity_wave, scratch) - # Jan month = Dates.month.(time) diff --git a/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_mima.jl b/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_mima.jl index 24f39cbd9c5..ad8680ef54f 100644 --- a/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_mima.jl +++ b/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_mima.jl @@ -128,7 +128,6 @@ dTdz = zeros(size(T)) bf = @. (grav / T) * (dTdz + grav / cp_d) bf = @. ifelse(bf < 2.5e-5, sqrt(2.5e-5), sqrt(abs(bf))) -# compute u/v forcings from convective gravity waves param = non_orographic_gravity_wave_param(lat, FT) # nogw forcing @@ -139,6 +138,7 @@ k_damp = findlast(pfull * 100 .< param.gw_damp_pressure) uforcing = zeros(size(lev)) vforcing = zeros(size(lev)) +#generate domain, space and field column_domain = ClimaCore.Domains.IntervalDomain( ClimaCore.Geometry.ZPoint(FT(z[1, 1, end, 1])) .. 
ClimaCore.Geometry.ZPoint(FT(z[1, 1, 1, 1])), @@ -162,8 +162,9 @@ gw_ncval = Val(333) ᶜv = copy(ᶜz) ᶜbf = copy(ᶜz) ᶜlevel = similar(ᶜρ, FT) +# waveforcing = similar(ᶜu, Tuple{FT, FT}) u_waveforcing = similar(ᶜu) -v_waveforcing = similar(ᶜv) +v_waveforcing = similar(ᶜu) for i in 1:Spaces.nlevels(axes(ᶜρ)) fill!(Fields.level(ᶜlevel, i), i) end @@ -180,10 +181,9 @@ scratch = (; temp_field_level = similar(Fields.level(ᶜz, 1), FT), ) -#params = (; non_orographic_gravity_wave, scratch) - for j in 1:length(lat) non_orographic_gravity_wave = non_orographic_gravity_wave_param(lat[j], FT) + # create input parameters at each level params = (; non_orographic_gravity_wave, scratch) for i in 1:length(lon) for it in 1:length(time) diff --git a/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_single_column.jl b/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_single_column.jl index b419b743da1..b7edeba8d75 100644 --- a/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_single_column.jl +++ b/test/parameterized_tendencies/gravity_wave/non_orographic_gravity_wave/nogw_test_single_column.jl @@ -11,6 +11,8 @@ import ClimaCore.Spaces as Spaces import ClimaCore.Fields as Fields import ClimaCore.Geometry as Geometry import ClimaCore.Operators as Operators +import ClimaCore.Domains as Domains +import ClimaCore: InputOutput, Meshes, Spaces, Quadratures include("../gw_plotutils.jl") @@ -82,9 +84,9 @@ end (; lon, lat, lev, time, gZ, T, u) = nt # compute density and buoyancy frequency -R_d = 287.0 -grav = 9.8 -cp_d = 1004.0 +R_d = FT(287.0) +grav = FT(9.8) +cp_d = FT(1004.0) Z = gZ ./ grav ρ = ones(size(T)) .* reshape(lev, (1, length(lev), 1)) ./ T / R_d @@ -133,21 +135,56 @@ center_u_mean = mean(center_u, dims = 1)[1, :, :] center_bf_mean = mean(center_bf, dims = 1)[1, :, :] center_ρ_mean = mean(center_ρ, dims = 1)[1, :, :] -# monthly ave Jan, April, July, Oct +# Generate domain, space and field +Δx 
= FT(1) # Note: This value shouldn't matter, since we only have 1 column. +quad = Quadratures.GL{1}() + +x_domain = Domains.IntervalDomain( + Geometry.XPoint(zero(Δx)), + Geometry.XPoint(Δx); + periodic = true, +) +y_domain = Domains.IntervalDomain( + Geometry.YPoint(zero(Δx)), + Geometry.YPoint(Δx); + periodic = true, +) +domain = Domains.RectangleDomain(x_domain, y_domain) +horizontal_mesh = Meshes.RectilinearMesh(domain, 1, 1) -column_domain = ClimaCore.Domains.IntervalDomain( - ClimaCore.Geometry.ZPoint(0.0) .. ClimaCore.Geometry.ZPoint(50000.0), +comms_ctx = ClimaComms.SingletonCommsContext{ClimaComms.CPUSingleThreaded}( + ClimaComms.CPUSingleThreaded(), +) +topology = ClimaCore.Topologies.Topology2D( + comms_ctx, + horizontal_mesh, + ClimaCore.Topologies.spacefillingcurve(horizontal_mesh), +) +h_space = Spaces.SpectralElementSpace2D(topology, quad;) + +h_grid = Spaces.grid(h_space) +z_domain = Domains.IntervalDomain( + Geometry.ZPoint(FT(0.0)), + Geometry.ZPoint(FT(50000.0)); boundary_names = (:bottom, :top), ) +z_stretch = Meshes.Uniform() +z_mesh = Meshes.IntervalMesh(z_domain, z_stretch; nelems = 50) -column_mesh = ClimaCore.Meshes.IntervalMesh(column_domain, nelems = 50) +device = ClimaComms.device(h_space) +z_topology = ClimaCore.Topologies.IntervalTopology( + ClimaComms.SingletonCommsContext(device), + z_mesh, +) +z_grid = ClimaCore.Grids.FiniteDifferenceGrid(z_topology) +hypsography = ClimaCore.Hypsography.Flat() +grid = + ClimaCore.Grids.ExtrudedFiniteDifferenceGrid(h_grid, z_grid, hypsography;) -column_center_space = ClimaCore.Spaces.CenterFiniteDifferenceSpace(column_mesh) -# construct the face space from the center one -column_face_space = - ClimaCore.Spaces.FaceFiniteDifferenceSpace(column_center_space) +center_space = Spaces.CenterExtrudedFiniteDifferenceSpace(grid) +face_space = Spaces.FaceExtrudedFiniteDifferenceSpace(grid) -coord = ClimaCore.Fields.coordinate_field(column_center_space) +coord = ClimaCore.Fields.coordinate_field(center_space) 
gw_ncval = Val(501) ᶜz = coord.z @@ -157,7 +194,7 @@ gw_ncval = Val(501) ᶜbf = copy(ᶜz) ᶜlevel = similar(ᶜρ, FT) u_waveforcing = similar(ᶜu) -v_waveforcing = similar(ᶜv) +v_waveforcing = similar(ᶜu) for i in 1:Spaces.nlevels(axes(ᶜρ)) fill!(Fields.level(ᶜlevel, i), i) end @@ -183,6 +220,7 @@ scratch = (; temp_field_level = similar(Fields.level(ᶜz, 1), FT), ) +# create input parameters params = (; non_orographic_gravity_wave, scratch) # Jan