From 8e951871297adda64af9a6be87b9156b0b58d5da Mon Sep 17 00:00:00 2001 From: sriharshakandala Date: Thu, 27 Jul 2023 18:56:41 -0700 Subject: [PATCH 1/2] Implement thread-per-node stencil operator kernels --- src/Operators/finitedifference.jl | 72 +++++++++++++++++++++++++------ 1 file changed, 60 insertions(+), 12 deletions(-) diff --git a/src/Operators/finitedifference.jl b/src/Operators/finitedifference.jl index 79b4d0c599..b89b50a322 100644 --- a/src/Operators/finitedifference.jl +++ b/src/Operators/finitedifference.jl @@ -3420,7 +3420,6 @@ function strip_space(bc::StencilBroadcasted{Style}, parent_space) where {Style} ) end - function Base.copyto!( out::Field, bc::Union{ @@ -3437,23 +3436,72 @@ function Base.copyto!( Nq = 1 Nh = 1 end - bounds = window_bounds(space, bc) - # executed - @cuda threads = (Nq, Nq) blocks = (Nh,) copyto_stencil_kernel!( - strip_space(out, space), - strip_space(bc, space), + (li, lw, rw, ri) = bounds = window_bounds(space, bc) + ninteriornodes = rw - lw + 1 + + max_threads = 256 + nitemsbdy = Nq * Nq * Nh # # of independent boundary items + nitemsint = ninteriornodes * Nq * Nq * Nh # # of independent interior items + (nthreadsbdy, nblocksbdy) = Spaces._configure_threadblock(nitemsbdy) + (nthreadsint, nblocksint) = Spaces._configure_threadblock(nitemsint) + isnotperiodic = !Topologies.isperiodic(Spaces.vertical_topology(space)) + strip_space_out = strip_space(out, space) + strip_space_bc = strip_space(bc, space) + # left and right windows, if applicable + isnotperiodic && + @cuda threads = (nthreadsbdy,) blocks = (nblocksbdy,) copyto_stencil_bdy_kernel!( + strip_space_out, + strip_space_bc, + axes(out), + bounds, + Nq, + Nh, + ) + # interior nodes + @cuda threads = (nthreadsint,) blocks = (nblocksint,) copyto_stencil_interior_kernel!( + strip_space_out, + strip_space_bc, axes(out), bounds, + ninteriornodes, + Nq, + Nh, ) return out end -function copyto_stencil_kernel!(out, bc, space, bds) - i = threadIdx().x - j = threadIdx().y - h = blockIdx().x - hidx = (i, j, h) - apply_stencil!(space, out, bc, hidx, bds) +function copyto_stencil_bdy_kernel!(out, bc, space, bds, Nq, Nh) + gid = threadIdx().x + (blockIdx().x - 1) * blockDim().x + if gid ≤ Nq * Nq * Nh + (li, lw, rw, ri) = bds + hidx = Spaces._get_idx((Nq, Nq, Nh), gid) + lbw = LeftBoundaryWindow{Spaces.left_boundary_name(space)}() + rbw = RightBoundaryWindow{Spaces.right_boundary_name(space)}() + @inbounds for idx in li:(lw - 1) + setidx!(space, out, idx, hidx, getidx(space, bc, lbw, idx, hidx)) + end + @inbounds for idx in (rw + 1):ri + setidx!(space, out, idx, hidx, getidx(space, bc, rbw, idx, hidx)) + end + end + return nothing +end + +function copyto_stencil_interior_kernel!(out, bc, space, bds, nnodes, Nq, Nh) + gid = threadIdx().x + (blockIdx().x - 1) * blockDim().x + if gid ≤ nnodes * Nq * Nq * Nh + (_, lw, rw, _) = bds + (ndidx, i, j, h) = Spaces._get_idx((nnodes, Nq, Nq, Nh), gid) + hidx = (i, j, h) + ndidx += lw - 1 + setidx!( + space, + out, + ndidx, + hidx, + getidx(space, bc, Interior(), ndidx, hidx), + ) + end return nothing end From 8d49db37c40046c5cb4bf55e616c6df956403a42 Mon Sep 17 00:00:00 2001 From: sriharshakandala Date: Fri, 28 Jul 2023 16:47:51 -0700 Subject: [PATCH 2/2] Update _configure_threadblock --- src/Spaces/dss_cuda.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/Spaces/dss_cuda.jl b/src/Spaces/dss_cuda.jl index f32a0d63ab..b649a56bc9 100644 --- a/src/Spaces/dss_cuda.jl +++ b/src/Spaces/dss_cuda.jl @@ -1,12 +1,15 @@ _max_threads_cuda() = 256 -function _configure_threadblock(nitems) - nthreads = min(_max_threads_cuda(), nitems) +function _configure_threadblock(max_threads, nitems) + nthreads = min(max_threads, nitems) nblocks = cld(nitems, nthreads) return (nthreads, nblocks) end +_configure_threadblock(nitems) = + _configure_threadblock(_max_threads_cuda(), nitems) + function dss_load_perimeter_data!( ::ClimaComms.CUDADevice, dss_buffer::DSSBuffer,