Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix high res threading config for SEM kernels #2001

Merged
merged 2 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions ext/cuda/data_layouts_threadblock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,40 @@ end
end
@inline multiple_field_solve_is_valid_index(I::CI5, us::UniversalSize) =
1 ≤ I[5] ≤ DataLayouts.get_Nh(us)

##### spectral kernel partition
@inline function spectral_partition(
us::DataLayouts.UniversalSize,
n_max_threads::Integer = 256;
)
(Nq, _, _, Nv, Nh) = DataLayouts.universal_size(us)
Nvthreads = min(fld(n_max_threads, Nq * Nq), maximum_allowable_threads()[3])
Nvblocks = cld(Nv, Nvthreads)
@assert prod((Nq, Nq, Nvthreads)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nq, Nq, Nvthreads))),$n_max_threads)"
@assert Nq * Nq ≤ n_max_threads
return (; threads = (Nq, Nq, Nvthreads), blocks = (Nh, Nvblocks), Nvthreads)
end
@inline function spectral_universal_index(space::Spaces.AbstractSpace)
i = threadIdx().x
j = threadIdx().y
k = threadIdx().z
h = blockIdx().x
vid = k + (blockIdx().y - 1) * blockDim().z
if space isa Spaces.AbstractSpectralElementSpace
v = nothing
elseif space isa Spaces.FaceExtrudedFiniteDifferenceSpace
v = vid - half
elseif space isa Spaces.CenterExtrudedFiniteDifferenceSpace
v = vid
else
error("Invalid space")
end
ij = CartesianIndex((i, j))
slabidx = Fields.SlabIndex(v, h)
return (ij, slabidx)
end
@inline spectral_is_valid_index(
space::Spaces.AbstractSpectralElementSpace,
ij,
slabidx,
) = Operators.is_valid_index(space, ij, slabidx)
39 changes: 8 additions & 31 deletions ext/cuda/operators_spectral_element.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,26 +34,20 @@ function Base.copyto!(
},
)
space = axes(out)
QS = Spaces.quadrature_style(space)
Nq = Quadratures.degrees_of_freedom(QS)
Nh = Topologies.nlocalelems(Spaces.topology(space))
Nv = Spaces.nlevels(space)
max_threads = 256
@assert Nq * Nq ≤ max_threads
Nvthreads = fld(max_threads, Nq * Nq)
Nvblocks = cld(Nv, Nvthreads)
us = UniversalSize(Fields.field_values(out))
# executed
p = spectral_partition(us)
args = (
strip_space(out, space),
strip_space(sbc, space),
space,
Val(Nvthreads),
Val(p.Nvthreads),
)
auto_launch!(
copyto_spectral_kernel!,
args;
threads_s = (Nq, Nq, Nvthreads),
blocks_s = (Nh, Nvblocks),
threads_s = p.threads,
blocks_s = p.blocks,
)
return out
end
Expand All @@ -66,32 +60,15 @@ function copyto_spectral_kernel!(
::Val{Nvt},
) where {Nvt}
@inbounds begin
i = threadIdx().x
j = threadIdx().y
k = threadIdx().z
h = blockIdx().x
vid = k + (blockIdx().y - 1) * blockDim().z
# allocate required shmem

sbc_reconstructed =
Operators.reconstruct_placeholder_broadcasted(space, sbc)
sbc_shmem = allocate_shmem(Val(Nvt), sbc_reconstructed)


# can loop over blocks instead?
if space isa Spaces.AbstractSpectralElementSpace
v = nothing
elseif space isa Spaces.FaceExtrudedFiniteDifferenceSpace
v = vid - half
elseif space isa Spaces.CenterExtrudedFiniteDifferenceSpace
v = vid
else
error("Invalid space")
end
ij = CartesianIndex((i, j))
slabidx = Fields.SlabIndex(v, h)
# v may potentially be out-of-range: any time memory is accessed, it
# should be checked by a call to is_valid_index(space, ij, slabidx)
(ij, slabidx) = spectral_universal_index(space)
# v in `slabidx` may potentially be out-of-range: any time memory is
# accessed, it should be checked by a call to is_valid_index(space, ij, slabidx)

# resolve_shmem! needs to be called even when out of range, so that
# sync_threads() is invoked collectively
Expand Down
2 changes: 1 addition & 1 deletion test/Spaces/distributed_cuda/ddss4.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ pid, nprocs = ClimaComms.init(context)
end
p = @allocated Spaces.weighted_dss!(y0, dss_buffer)
if pid == 1
@test p ≤ 7776
@test p ≤ 410296
end

end
Loading