diff --git a/ext/cuda/data_layouts_threadblock.jl b/ext/cuda/data_layouts_threadblock.jl index de94621009..77fbb00d1e 100644 --- a/ext/cuda/data_layouts_threadblock.jl +++ b/ext/cuda/data_layouts_threadblock.jl @@ -170,6 +170,25 @@ end ##### Custom partitions ##### +##### linear partition +@inline function linear_partition( + us::DataLayouts.UniversalSize, + n_max_threads::Integer, +) + nitems = prod(DataLayouts.universal_size(us)) + threads = min(nitems, n_max_threads) + blocks = cld(nitems, threads) + return (; threads, blocks) +end +@inline function linear_universal_index(us::UniversalSize) + i = (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x + CUDA.threadIdx().x + inds = DataLayouts.universal_size(us) + CI = CartesianIndices(map(x -> Base.OneTo(x), inds)) + return (CI[i], i) +end +@inline linear_is_valid_index(i::Integer, us::UniversalSize) = + 1 ≤ i ≤ DataLayouts.get_N(us) + ##### Column-wise @inline function columnwise_partition( us::DataLayouts.UniversalSize, diff --git a/ext/cuda/operators_finite_difference.jl b/ext/cuda/operators_finite_difference.jl index c93b4e0797..677ac7691d 100644 --- a/ext/cuda/operators_finite_difference.jl +++ b/ext/cuda/operators_finite_difference.jl @@ -21,12 +21,13 @@ function Base.copyto!( bounds = Operators.window_bounds(space, bc) out_fv = Fields.field_values(out) us = DataLayouts.UniversalSize(out_fv) + nitems = prod(DataLayouts.universal_size(us)) args = (strip_space(out, space), strip_space(bc, space), axes(out), bounds, us) threads = threads_via_occupancy(copyto_stencil_kernel!, args) - n_max_threads = min(threads, get_N(us)) - p = partition(out_fv, n_max_threads) + n_max_threads = min(threads, nitems) + p = linear_partition(us, n_max_threads) auto_launch!( copyto_stencil_kernel!, @@ -40,9 +41,8 @@ import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nij, get_Nh function copyto_stencil_kernel!(out, bc, space, bds, us) @inbounds begin - out_fv = Fields.field_values(out) - I = universal_index(out_fv) - if is_valid_index(out_fv, I, us) + (I, i_linear) = linear_universal_index(us) + if linear_is_valid_index(i_linear, us) (li, lw, rw, ri) = bds (i, j, _, v, h) = I.I hidx = (i, j, h) diff --git a/src/DataLayouts/DataLayouts.jl b/src/DataLayouts/DataLayouts.jl index 4b681164b3..9daaaf5fa0 100644 --- a/src/DataLayouts/DataLayouts.jl +++ b/src/DataLayouts/DataLayouts.jl @@ -87,8 +87,7 @@ is excluded and is returned as 1. Statically returns `prod((Ni, Nj, Nv, Nh))` """ -@inline get_N(::UniversalSize{Ni, Nj, Nv, Nh}) where {Ni, Nj, Nv, Nh} = - prod((Ni, Nj, Nv, Nh)) +@inline get_N(us::UniversalSize) = prod(universal_size(us)) """ get_Nv(::UniversalSize)