From a540321a2e73fabdbf06b486859e4f5fb2d089c9 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Thu, 26 Sep 2024 10:03:01 -0400 Subject: [PATCH] Hoist UniversalSize computation outside of kernels --- ext/cuda/data_layouts_threadblock.jl | 4 ++-- ext/cuda/matrix_fields_multiple_field_solve.jl | 8 ++++---- ext/cuda/matrix_fields_single_field_solve.jl | 2 +- ext/cuda/operators_integral.jl | 13 ++++++++----- ext/cuda/operators_thomas_algorithm.jl | 6 +++--- src/Operators/thomas_algorithm.jl | 1 + 6 files changed, 19 insertions(+), 15 deletions(-) diff --git a/ext/cuda/data_layouts_threadblock.jl b/ext/cuda/data_layouts_threadblock.jl index c7ef4681fb..02a0aeff6c 100644 --- a/ext/cuda/data_layouts_threadblock.jl +++ b/ext/cuda/data_layouts_threadblock.jl @@ -185,7 +185,7 @@ end @assert prod((Nij, Nij, Nh_thread)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nij, Nij, Nh_thread))),$n_max_threads)" return (; threads = (Nij, Nij, Nh_thread), blocks = (Nh_blocks,)) end -@inline function columnwise_universal_index() +@inline function columnwise_universal_index(us::UniversalSize) (i, j, th) = CUDA.threadIdx() (bh,) = CUDA.blockIdx() h = th + (bh - 1) * CUDA.blockDim().z @@ -207,7 +207,7 @@ end @assert prod((Nij, Nij, Nnames)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nij, Nij, Nnames))),$n_max_threads)" return (; threads = (Nij, Nij, Nnames), blocks = (Nh,)) end -@inline function multiple_field_solve_universal_index() +@inline function multiple_field_solve_universal_index(us::UniversalSize) (i, j, iname) = CUDA.threadIdx() (h,) = CUDA.blockIdx() return (CartesianIndex((i, j, 1, 1, h)), iname) diff --git a/ext/cuda/matrix_fields_multiple_field_solve.jl b/ext/cuda/matrix_fields_multiple_field_solve.jl index 57a7c5c0cb..3955aabaa7 100644 --- a/ext/cuda/matrix_fields_multiple_field_solve.jl +++ b/ext/cuda/matrix_fields_multiple_field_solve.jl @@ -33,9 +33,9 @@ NVTX.@annotate function multiple_field_solve!( device = ClimaComms.device(x[first(names)]) - args = (device, caches, xs, As, bs, x1, Val(Nnames)) - us = UniversalSize(Fields.field_values(x1)) + args = (device, caches, xs, As, bs, x1, us, Val(Nnames)) + nitems = Ni * Nj * Nh * Nnames threads = threads_via_occupancy(multiple_field_solve_kernel!, args) n_max_threads = min(threads, nitems) @@ -85,11 +85,11 @@ function multiple_field_solve_kernel!( As, bs, x1, + us::UniversalSize, ::Val{Nnames}, ) where {Nnames} @inbounds begin - us = UniversalSize(Fields.field_values(x1)) - (I, iname) = multiple_field_solve_universal_index() + (I, iname) = multiple_field_solve_universal_index(us) if multiple_field_solve_is_valid_index(I, us) (i, j, _, _, h) = I.I generated_single_field_solve!( diff --git a/ext/cuda/matrix_fields_single_field_solve.jl b/ext/cuda/matrix_fields_single_field_solve.jl index 59c9e1424a..b1ca849733 100644 --- a/ext/cuda/matrix_fields_single_field_solve.jl +++ b/ext/cuda/matrix_fields_single_field_solve.jl @@ -30,7 +30,7 @@ function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b) end function single_field_solve_kernel!(device, cache, x, A, b, us) - I = columnwise_universal_index() + I = columnwise_universal_index(us) if columnwise_is_valid_index(I, us) (i, j, _, _, h) = I.I _single_field_solve!( diff --git a/ext/cuda/operators_integral.jl b/ext/cuda/operators_integral.jl index ef116daa0d..7b344b511e 100644 --- a/ext/cuda/operators_integral.jl +++ b/ext/cuda/operators_integral.jl @@ -19,6 +19,7 @@ function column_reduce_device!( space, ) where {F, T} Ni, Nj, _, _, Nh = size(Fields.field_values(output)) + us = UniversalSize(Fields.field_values(output)) args = ( single_column_reduce!, f, @@ -27,8 +28,8 @@ function column_reduce_device!( strip_space(input, space), init, space, + us, ) - us = UniversalSize(Fields.field_values(output)) nitems = Ni * Nj * Nh threads = threads_via_occupancy(bycolumn_kernel!, args) n_max_threads = min(threads, nitems) @@ -50,7 +51,8 @@ function column_accumulate_device!( init, space, ) where {F, T} - us = UniversalSize(Fields.field_values(output)) + out_fv = Fields.field_values(output) + us = UniversalSize(out_fv) args = ( single_column_accumulate!, f, @@ -59,8 +61,9 @@ function column_accumulate_device!( strip_space(input, space), init, space, + us, ) - Ni, Nj, _, _, Nh = size(Fields.field_values(output)) + (Ni, Nj, _, _, Nh) = DataLayouts.universal_size(us) nitems = Ni * Nj * Nh threads = threads_via_occupancy(bycolumn_kernel!, args) n_max_threads = min(threads, nitems) @@ -81,12 +84,12 @@ bycolumn_kernel!( input, init, space, + us::DataLayouts.UniversalSize, ) where {S, F, T} = if space isa Spaces.FiniteDifferenceSpace single_column_function!(f, transform, output, input, init, space) else - I = columnwise_universal_index() - us = UniversalSize(Fields.field_values(output)) + I = columnwise_universal_index(us) if columnwise_is_valid_index(I, us) (i, j, _, _, h) = I.I single_column_function!( diff --git a/ext/cuda/operators_thomas_algorithm.jl b/ext/cuda/operators_thomas_algorithm.jl index 11eea95fed..d25e26fd3d 100644 --- a/ext/cuda/operators_thomas_algorithm.jl +++ b/ext/cuda/operators_thomas_algorithm.jl @@ -6,7 +6,7 @@ import CUDA using CUDA: @cuda function column_thomas_solve!(::ClimaComms.CUDADevice, A, b) us = UniversalSize(Fields.field_values(A)) - args = (A, b) + args = (A, b, us) Ni, Nj, _, _, Nh = size(Fields.field_values(A)) threads = threads_via_occupancy(thomas_algorithm_kernel!, args) nitems = Ni * Nj * Nh @@ -23,9 +23,9 @@ end function thomas_algorithm_kernel!( A::Fields.ExtrudedFiniteDifferenceField, b::Fields.ExtrudedFiniteDifferenceField, + us::DataLayouts.UniversalSize, ) - I = columnwise_universal_index() - us = UniversalSize(Fields.field_values(A)) + I = columnwise_universal_index(us) if columnwise_is_valid_index(I, us) (i, j, _, _, h) = I.I thomas_algorithm!(Spaces.column(A, i, j, h), Spaces.column(b, i, j, h)) diff --git a/src/Operators/thomas_algorithm.jl b/src/Operators/thomas_algorithm.jl index 791d3194fa..3ea18cb1ff 100644 --- a/src/Operators/thomas_algorithm.jl +++ b/src/Operators/thomas_algorithm.jl @@ -17,6 +17,7 @@ column_thomas_solve!(::ClimaComms.AbstractCPUDevice, A, b) = thomas_algorithm_kernel!( A::Fields.FiniteDifferenceField, b::Fields.FiniteDifferenceField, + us::DataLayouts.UniversalSize, ) = thomas_algorithm!(A, b) function thomas_algorithm!(