Skip to content

Commit

Permalink
Fix gpu column integrals
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jul 18, 2023
1 parent 0e7a5f9 commit 14ae299
Show file tree
Hide file tree
Showing 2 changed files with 708 additions and 656 deletions.
61 changes: 56 additions & 5 deletions src/Operators/integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,74 @@
The definite vertical column integral, `col∫field`, of field `field`.
"""
function column_integral_definite!(col∫field::Fields.Field, field::Fields.Field)
column_integral_definite!(col∫field::Fields.Field, field::Fields.Field) =
column_integral_definite!(
ClimaComms.device(axes(col∫field)),
col∫field,
field,
)

function column_integral_definite!(
::ClimaComms.CUDADevice,
col∫field::Fields.Field,
field::Fields.Field,
)
Ni, Nj, _, _, Nh = size(Fields.field_values(col∫field))
nthreads, nblocks = Spaces._configure_threadblock(Ni * Nj * Nh)
@cuda threads = nthreads blocks = nblocks column_integral_definite_kernel!(
col∫field,
field,
)
end

function column_integral_definite_kernel!(
col∫field::Fields.SpectralElementField,
field::Fields.ExtrudedFiniteDifferenceField,
)
idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
Ni, Nj, _, _, Nh = size(Fields.field_values(field))
if idx <= Ni * Nj * Nh
i, j, h = Spaces._get_idx((Ni, Nj, Nh), idx)
_column_integral_definite!(
Spaces.column(col∫field, i, j, h),
Spaces.column(field, i, j, h),
)
end
return nothing
end

column_integral_definite_kernel!(
col∫field::Fields.PointField,
field::Fields.FiniteDifferenceField,
) = _column_integral_definite!(col∫field, field)

function column_integral_definite!(
::ClimaComms.AbstractCPUDevice,
col∫field::Fields.SpectralElementField,
field::Fields.ExtrudedFiniteDifferenceField,
)
Fields.bycolumn(axes(field)) do colidx
column_integral_definite!(col∫field[colidx], field[colidx])
_column_integral_definite!(col∫field[colidx], field[colidx])
nothing
end
return nothing
end

function column_integral_definite!(
column_integral_definite!(
::ClimaComms.AbstractCPUDevice,
col∫field::Fields.PointField,
field::Fields.FiniteDifferenceField,
) = _column_integral_definite!(col∫field, field)

function _column_integral_definite!(
col∫field::Fields.PointField,
field::Fields.ColumnField,
)
@inbounds col∫field[] = column_integral_definite(field)
@inbounds col∫field[] = _column_integral_definite(field)
return nothing
end

function column_integral_definite(field::Fields.ColumnField)
function _column_integral_definite(field::Fields.ColumnField)
field_data = Fields.field_values(field)
Δz_data = Spaces.Δz_data(axes(field))
Nv = Spaces.nlevels(axes(field))
Expand Down
Loading

0 comments on commit 14ae299

Please sign in to comment.