diff --git a/ext/cuda/data_layouts_threadblock.jl b/ext/cuda/data_layouts_threadblock.jl
index de94621009..c7ef4681fb 100644
--- a/ext/cuda/data_layouts_threadblock.jl
+++ b/ext/cuda/data_layouts_threadblock.jl
@@ -214,3 +214,40 @@ end
 end
 @inline multiple_field_solve_is_valid_index(I::CI5, us::UniversalSize) =
     1 ≤ I[5] ≤ DataLayouts.get_Nh(us)
+
+##### spectral kernel partition
+@inline function spectral_partition(
+    us::DataLayouts.UniversalSize,
+    n_max_threads::Integer = 256;
+)
+    (Nq, _, _, Nv, Nh) = DataLayouts.universal_size(us)
+    Nvthreads = min(fld(n_max_threads, Nq * Nq), maximum_allowable_threads()[3])
+    Nvblocks = cld(Nv, Nvthreads)
+    @assert prod((Nq, Nq, Nvthreads)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nq, Nq, Nvthreads))),$n_max_threads)"
+    @assert Nq * Nq ≤ n_max_threads
+    return (; threads = (Nq, Nq, Nvthreads), blocks = (Nh, Nvblocks), Nvthreads)
+end
+@inline function spectral_universal_index(space::Spaces.AbstractSpace)
+    i = threadIdx().x
+    j = threadIdx().y
+    k = threadIdx().z
+    h = blockIdx().x
+    vid = k + (blockIdx().y - 1) * blockDim().z
+    if space isa Spaces.AbstractSpectralElementSpace
+        v = nothing
+    elseif space isa Spaces.FaceExtrudedFiniteDifferenceSpace
+        v = vid - half
+    elseif space isa Spaces.CenterExtrudedFiniteDifferenceSpace
+        v = vid
+    else
+        error("Invalid space")
+    end
+    ij = CartesianIndex((i, j))
+    slabidx = Fields.SlabIndex(v, h)
+    return (ij, slabidx)
+end
+@inline spectral_is_valid_index(
+    space::Spaces.AbstractSpectralElementSpace,
+    ij,
+    slabidx,
+) = Operators.is_valid_index(space, ij, slabidx)
diff --git a/ext/cuda/operators_spectral_element.jl b/ext/cuda/operators_spectral_element.jl
index 47c6cb1c82..658b3440e6 100644
--- a/ext/cuda/operators_spectral_element.jl
+++ b/ext/cuda/operators_spectral_element.jl
@@ -34,26 +34,20 @@ function Base.copyto!(
     },
 )
     space = axes(out)
-    QS = Spaces.quadrature_style(space)
-    Nq = Quadratures.degrees_of_freedom(QS)
-    Nh = Topologies.nlocalelems(Spaces.topology(space))
-    Nv = Spaces.nlevels(space)
-    max_threads = 256
-    @assert Nq * Nq ≤ max_threads
-    Nvthreads = fld(max_threads, Nq * Nq)
-    Nvblocks = cld(Nv, Nvthreads)
+    us = UniversalSize(Fields.field_values(out)) # executed
+    p = spectral_partition(us)
 
     args = (
         strip_space(out, space),
         strip_space(sbc, space),
         space,
-        Val(Nvthreads),
+        Val(p.Nvthreads),
     )
     auto_launch!(
         copyto_spectral_kernel!,
         args;
-        threads_s = (Nq, Nq, Nvthreads),
-        blocks_s = (Nh, Nvblocks),
+        threads_s = p.threads,
+        blocks_s = p.blocks,
     )
     return out
 end
@@ -66,32 +60,15 @@ function copyto_spectral_kernel!(
     ::Val{Nvt},
 ) where {Nvt}
     @inbounds begin
-        i = threadIdx().x
-        j = threadIdx().y
-        k = threadIdx().z
-        h = blockIdx().x
-        vid = k + (blockIdx().y - 1) * blockDim().z
         # allocate required shmem
-        sbc_reconstructed = Operators.reconstruct_placeholder_broadcasted(space, sbc)
         sbc_shmem = allocate_shmem(Val(Nvt), sbc_reconstructed)
-        # can loop over blocks instead?
-        if space isa Spaces.AbstractSpectralElementSpace
-            v = nothing
-        elseif space isa Spaces.FaceExtrudedFiniteDifferenceSpace
-            v = vid - half
-        elseif space isa Spaces.CenterExtrudedFiniteDifferenceSpace
-            v = vid
-        else
-            error("Invalid space")
-        end
-        ij = CartesianIndex((i, j))
-        slabidx = Fields.SlabIndex(v, h)
-        # v may potentially be out-of-range: any time memory is accessed, it
-        # should be checked by a call to is_valid_index(space, ij, slabidx)
+        (ij, slabidx) = spectral_universal_index(space)
+        # v in `slabidx` may potentially be out-of-range: any time memory is
+        # accessed, it should be checked by a call to is_valid_index(space, ij, slabidx)
         # resolve_shmem! needs to be called even when out of range, so that
         # sync_threads() is invoked collectively
diff --git a/test/Spaces/distributed_cuda/ddss4.jl b/test/Spaces/distributed_cuda/ddss4.jl
index d127bdbf79..c46bf49242 100644
--- a/test/Spaces/distributed_cuda/ddss4.jl
+++ b/test/Spaces/distributed_cuda/ddss4.jl
@@ -100,7 +100,7 @@ pid, nprocs = ClimaComms.init(context)
     end
     p = @allocated Spaces.weighted_dss!(y0, dss_buffer)
     if pid == 1
-        @test p ≤ 7776
+        @test p ≤ 410296
     end
 end
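
For reference, a minimal CPU-side sketch (not part of the patch) of the launch arithmetic
that the new spectral_partition helper performs. The function name and the sizes
(Nq = 4, Nv = 63, Nh = 5400) are hypothetical, and max_z_threads = 64 stands in for
maximum_allowable_threads()[3], whose value is device-dependent.

# Sketch only: mirrors the partition arithmetic with hypothetical inputs.
function sketch_spectral_partition(Nq, Nv, Nh; n_max_threads = 256, max_z_threads = 64)
    # vertical levels handled per block, capped by the device's max z thread dimension
    Nvthreads = min(fld(n_max_threads, Nq * Nq), max_z_threads)
    # number of blocks along the vertical direction needed to cover all Nv levels
    Nvblocks = cld(Nv, Nvthreads)
    @assert Nq * Nq * Nvthreads ≤ n_max_threads
    return (; threads = (Nq, Nq, Nvthreads), blocks = (Nh, Nvblocks), Nvthreads)
end

p = sketch_spectral_partition(4, 63, 5400)
# p.threads == (4, 4, 16): one (i, j) thread per quadrature point, 16 levels per block
# p.blocks  == (5400, 4): one block per element in the first grid dimension,
#                         cld(63, 16) == 4 blocks along the vertical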