Skip to content

Commit

Permalink
Add support for mult field solve for diag mat row
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Oct 3, 2024
1 parent f7c1dbb commit d1e4426
Showing 1 changed file with 34 additions and 12 deletions.
46 changes: 34 additions & 12 deletions ext/cuda/matrix_fields_single_field_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,42 @@ function single_field_solve_kernel!(device, cache, x, A, b, us)
return nothing
end

function _single_field_solve_diag_matrix_row!(
device::ClimaComms.CUDADevice,
cache,
x,
A,
b,
)
Aⱼs = unzip_tuple_field_values(Fields.field_values(A.entries))
(A₀,) = Aⱼs
vi = vindex
x_data = Fields.field_values(x)
b_data = Fields.field_values(b)
Nv = DataLayouts.nlevels(x_data)
@inbounds for v in 1:Nv
x_data[vi(v)] = inv(A₀[vi(v)]) b_data[vi(v)]
end
end

function _single_field_solve!(
::ClimaComms.CUDADevice,
cache::Fields.ColumnField,
x::Fields.ColumnField,
A::Fields.ColumnField,
b::Fields.ColumnField,
device::ClimaComms.CUDADevice,
cache::Fields.Field,
x::Fields.Field,
A::Fields.Field,
b::Fields.Field,
)
band_matrix_solve_local_mem!(
eltype(A),
unzip_tuple_field_values(Fields.field_values(cache)),
Fields.field_values(x),
unzip_tuple_field_values(Fields.field_values(A.entries)),
Fields.field_values(b),
)
if eltype(A) <: MatrixFields.DiagonalMatrixRow
_single_field_solve_diag_matrix_row!(device, cache, x, A, b)
else
band_matrix_solve_local_mem!(
eltype(A),
unzip_tuple_field_values(Fields.field_values(cache)),
Fields.field_values(x),
unzip_tuple_field_values(Fields.field_values(A.entries)),
Fields.field_values(b),
)
end
end

function _single_field_solve!(
Expand Down

0 comments on commit d1e4426

Please sign in to comment.