diff --git a/ext/cuda/matrix_fields_single_field_solve.jl b/ext/cuda/matrix_fields_single_field_solve.jl index b1ca849733..98f51ed3d8 100644 --- a/ext/cuda/matrix_fields_single_field_solve.jl +++ b/ext/cuda/matrix_fields_single_field_solve.jl @@ -44,20 +44,52 @@ function single_field_solve_kernel!(device, cache, x, A, b, us) return nothing end +unrolled_unzip_tuple_field_values(data) = + unrolled_unzip_tuple_field_values(data, propertynames(data)) +unrolled_unzip_tuple_field_values(data, pn::Tuple) = ( + getproperty(data, Val(first(pn))), + unrolled_unzip_tuple_field_values(data, Base.tail(pn))..., +) +unrolled_unzip_tuple_field_values(data, pn::Tuple{Any}) = + (getproperty(data, Val(first(pn))),) +unrolled_unzip_tuple_field_values(data, pn::Tuple{}) = () + +function _single_field_solve_diag_matrix_row!( + device::ClimaComms.CUDADevice, + cache, + x, + A, + b, +) + Aⱼs = unrolled_unzip_tuple_field_values(Fields.field_values(A.entries)) + (A₀,) = Aⱼs + vi = vindex + x_data = Fields.field_values(x) + b_data = Fields.field_values(b) + Nv = DataLayouts.nlevels(x_data) + @inbounds for v in 1:Nv + x_data[vi(v)] = inv(A₀[vi(v)]) ⊠ b_data[vi(v)] + end +end + function _single_field_solve!( - ::ClimaComms.CUDADevice, - cache::Fields.ColumnField, - x::Fields.ColumnField, - A::Fields.ColumnField, - b::Fields.ColumnField, + device::ClimaComms.CUDADevice, + cache::Fields.Field, + x::Fields.Field, + A::Fields.Field, + b::Fields.Field, ) - band_matrix_solve_local_mem!( - eltype(A), - unzip_tuple_field_values(Fields.field_values(cache)), - Fields.field_values(x), - unzip_tuple_field_values(Fields.field_values(A.entries)), - Fields.field_values(b), - ) + if eltype(A) <: MatrixFields.DiagonalMatrixRow + _single_field_solve_diag_matrix_row!(device, cache, x, A, b) + else + band_matrix_solve_local_mem!( + eltype(A), + unzip_tuple_field_values(Fields.field_values(cache)), + Fields.field_values(x), + unzip_tuple_field_values(Fields.field_values(A.entries)), + Fields.field_values(b), + ) + end end function _single_field_solve!(