Skip to content

Commit

Permalink
Fix high res threading config for SEM kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Sep 23, 2024
1 parent bd20629 commit 95e4b45
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 31 deletions.
39 changes: 39 additions & 0 deletions ext/cuda/data_layouts_threadblock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,42 @@ end
end
@inline multiple_field_solve_is_valid_index(I::CI5, us::UniversalSize) =
1 I[5] DataLayouts.get_Nh(us)

##### spectral kernel partition
@inline function spectral_partition(
us::DataLayouts.UniversalSize,
n_max_threads::Integer = 256;
)
(Nij, _, _, Nv, Nh) = DataLayouts.universal_size(us)
Nvthreads = min(fld(n_max_threads, Nq * Nq), maximum_allowable_threads()[3])
Nvblocks = cld(Nv, Nvthreads)
@assert prod((Nq, Nq, Nvthreads)) n_max_threads "threads,n_max_threads=($(prod((Nq, Nq, Nvthreads))),$n_max_threads)"
@assert Nq * Nq n_max_threads
return (; threads = (Nq, Nq, Nvthreads), blocks = (Nh, Nvblocks), Nvthreads)
end
@inline function spectral_universal_index(
space::Spaces.AbstractSpectralElementSpace,
)
i = threadIdx().x
j = threadIdx().y
k = threadIdx().z
h = blockIdx().x
vid = k + (blockIdx().y - 1) * blockDim().z
if space isa Spaces.AbstractSpectralElementSpace
v = nothing
elseif space isa Spaces.FaceExtrudedFiniteDifferenceSpace
v = vid - half
elseif space isa Spaces.CenterExtrudedFiniteDifferenceSpace
v = vid
else
error("Invalid space")
end
ij = CartesianIndex((i, j))
slabidx = Fields.SlabIndex(v, h)
return (ij, slabidx)
end
@inline spectral_is_valid_index(
space::Spaces.AbstractSpectralElementSpace,
ij,
slabidx,
) = Operators.is_valid_index(space, ij, slabidx)
39 changes: 8 additions & 31 deletions ext/cuda/operators_spectral_element.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,26 +34,20 @@ function Base.copyto!(
},
)
space = axes(out)
QS = Spaces.quadrature_style(space)
Nq = Quadratures.degrees_of_freedom(QS)
Nh = Topologies.nlocalelems(Spaces.topology(space))
Nv = Spaces.nlevels(space)
max_threads = 256
@assert Nq * Nq max_threads
Nvthreads = fld(max_threads, Nq * Nq)
Nvblocks = cld(Nv, Nvthreads)
us = UniversalSize(Fields.field_values(out))
# executed
p = spectral_partition(us)
args = (
strip_space(out, space),
strip_space(sbc, space),
space,
Val(Nvthreads),
Val(p.Nvthreads),
)
auto_launch!(
copyto_spectral_kernel!,
args;
threads_s = (Nq, Nq, Nvthreads),
blocks_s = (Nh, Nvblocks),
threads_s = p.threads,
blocks_s = p.blocks,
)
return out
end
Expand All @@ -66,32 +60,15 @@ function copyto_spectral_kernel!(
::Val{Nvt},
) where {Nvt}
@inbounds begin
i = threadIdx().x
j = threadIdx().y
k = threadIdx().z
h = blockIdx().x
vid = k + (blockIdx().y - 1) * blockDim().z
# allocate required shmem

sbc_reconstructed =
Operators.reconstruct_placeholder_broadcasted(space, sbc)
sbc_shmem = allocate_shmem(Val(Nvt), sbc_reconstructed)


# can loop over blocks instead?
if space isa Spaces.AbstractSpectralElementSpace
v = nothing
elseif space isa Spaces.FaceExtrudedFiniteDifferenceSpace
v = vid - half
elseif space isa Spaces.CenterExtrudedFiniteDifferenceSpace
v = vid
else
error("Invalid space")
end
ij = CartesianIndex((i, j))
slabidx = Fields.SlabIndex(v, h)
# v may potentially be out-of-range: any time memory is accessed, it
# should be checked by a call to is_valid_index(space, ij, slabidx)
(ij, slabidx) = spectral_universal_index(space)
# v in `slabidx` may potentially be out-of-range: any time memory is
# accessed, it should be checked by a call to is_valid_index(space, ij, slabidx)

# resolve_shmem! needs to be called even when out of range, so that
# sync_threads() is invoked collectively
Expand Down

0 comments on commit 95e4b45

Please sign in to comment.