From f15ba20c6af574a845d499e07d0a7d84c2963fd5 Mon Sep 17 00:00:00 2001 From: Dennis Yatunin Date: Fri, 26 May 2023 18:15:17 -0700 Subject: [PATCH] Add new MatrixFields module, along with unit tests and performance tests --- Project.toml | 1 + docs/Manifest.toml | 8 +- src/ClimaCore.jl | 1 + src/Fields/mapreduce.jl | 2 +- src/Geometry/axistensors.jl | 23 +- src/MatrixFields/MatrixFields.jl | 35 + src/MatrixFields/band_matrix_row.jl | 153 ++++ src/MatrixFields/matrix_field_utils.jl | 85 ++ src/MatrixFields/matrix_multiplication.jl | 338 ++++++++ src/MatrixFields/rmul_with_projection.jl | 118 +++ src/RecursiveApply/RecursiveApply.jl | 95 ++- .../MatrixFields/matrix_field_broadcasting.jl | 762 ++++++++++++++++++ test/MatrixFields/rmul_with_projection.jl | 127 +++ test/runtests.jl | 3 + 14 files changed, 1739 insertions(+), 12 deletions(-) create mode 100644 src/MatrixFields/MatrixFields.jl create mode 100644 src/MatrixFields/band_matrix_row.jl create mode 100644 src/MatrixFields/matrix_field_utils.jl create mode 100644 src/MatrixFields/matrix_multiplication.jl create mode 100644 src/MatrixFields/rmul_with_projection.jl create mode 100644 test/MatrixFields/matrix_field_broadcasting.jl create mode 100644 test/MatrixFields/rmul_with_projection.jl diff --git a/Project.toml b/Project.toml index 82837036ca..8f1272e457 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.10.39" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 990879d0c0..00021ea2b8 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -98,6 +98,12 @@ git-tree-sha1 = "dbf84058d0a8cbbadee18d25cf606934b22d7c66" uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" version = "0.4.2" +[[deps.BandedMatrices]] +deps = 
["ArrayLayouts", "FillArrays", "LinearAlgebra", "PrecompileTools", "SparseArrays"] +git-tree-sha1 = "9ad46355045491b12eab409dee73e9de46293aa2" +uuid = "aae01518-5342-5314-be14-df237901396f" +version = "0.17.28" + [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" @@ -216,7 +222,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" version = "0.4.2" [[deps.ClimaCore]] -deps = ["Adapt", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DiffEqBase", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "Rotations", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"] +deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DiffEqBase", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "Rotations", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"] path = ".." uuid = "d414da3d-4745-48bb-8d80-42e94e092884" version = "0.10.39" diff --git a/src/ClimaCore.jl b/src/ClimaCore.jl index b37bd8f397..72ada48e34 100644 --- a/src/ClimaCore.jl +++ b/src/ClimaCore.jl @@ -14,6 +14,7 @@ include("Topologies/Topologies.jl") include("Spaces/Spaces.jl") include("Fields/Fields.jl") include("Operators/Operators.jl") +include("MatrixFields/MatrixFields.jl") include("Hypsography/Hypsography.jl") include("Limiters/Limiters.jl") include("InputOutput/InputOutput.jl") diff --git a/src/Fields/mapreduce.jl b/src/Fields/mapreduce.jl index 1b9b058399..4ccf03dd49 100644 --- a/src/Fields/mapreduce.jl +++ b/src/Fields/mapreduce.jl @@ -1,4 +1,4 @@ -Base.map(fn, field::Field) = Base.broadcast(fn, field) +Base.map(fn, field::Field, rest::Field...) = Base.broadcast(fn, field, rest...) 
""" Fields.local_sum(v::Field) diff --git a/src/Geometry/axistensors.jl b/src/Geometry/axistensors.jl index a66b723708..9c678d0697 100644 --- a/src/Geometry/axistensors.jl +++ b/src/Geometry/axistensors.jl @@ -269,10 +269,30 @@ Base.propertynames(x::AxisVector) = symbols(axes(x, 1)) end end +const AdjointAxisTensor{T, N, A, S} = Adjoint{T, AxisTensor{T, N, A, S}} + +Base.show(io::IO, a::AdjointAxisTensor{T, N, A, S}) where {T, N, A, S} = + print(io, "adjoint($(a'))") + +components(a::AdjointAxisTensor) = components(parent(a))' + +Base.zero(a::AdjointAxisTensor) = zero(a')' +Base.zero(::Type{AdjointAxisTensor{T, N, A, S}}) where {T, N, A, S} = + zero(AxisTensor{T, N, A, S})' + +@inline +(a::AdjointAxisTensor) = (+a')' +@inline -(a::AdjointAxisTensor) = (-a')' +@inline +(a::AdjointAxisTensor, b::AdjointAxisTensor) = (a' + b')' +@inline -(a::AdjointAxisTensor, b::AdjointAxisTensor) = (a' - b')' +@inline *(a::Number, b::AdjointAxisTensor) = (a * b')' +@inline *(a::AdjointAxisTensor, b::Number) = (a' * b)' +@inline /(a::AdjointAxisTensor, b::Number) = (a' / b)' +@inline \(a::Number, b::AdjointAxisTensor) = (a \ b')' + +@inline (==)(a::AdjointAxisTensor, b::AdjointAxisTensor) = a' == b' const AdjointAxisVector{T, A1, S} = Adjoint{T, AxisVector{T, A1, S}} -components(va::AdjointAxisVector) = components(parent(va))' Base.@propagate_inbounds Base.getindex(va::AdjointAxisVector, i::Int) = getindex(components(va), i) Base.@propagate_inbounds Base.getindex(va::AdjointAxisVector, i::Int, j::Int) = @@ -286,7 +306,6 @@ Axis2Tensor( ) = AxisTensor(axes, components) const AdjointAxis2Tensor{T, A, S} = Adjoint{T, Axis2Tensor{T, A, S}} -components(va::AdjointAxis2Tensor) = components(parent(va))' const Axis2TensorOrAdj{T, A, S} = Union{Axis2Tensor{T, A, S}, AdjointAxis2Tensor{T, A, S}} diff --git a/src/MatrixFields/MatrixFields.jl b/src/MatrixFields/MatrixFields.jl new file mode 100644 index 0000000000..2e6c8b63b1 --- /dev/null +++ b/src/MatrixFields/MatrixFields.jl @@ -0,0 +1,35 @@ 
+module MatrixFields + +import LinearAlgebra: UniformScaling, Adjoint +import StaticArrays: SArray, SMatrix, SVector +import BandedMatrices: BandedMatrix, band +import ..Utilities: PlusHalf, half +import ..RecursiveApply: rmap, rmaptype, rzero, radd, rsub, rmul, rdiv +import ..Geometry +import ..Spaces +import ..Fields +import ..Operators: + FiniteDifferenceOperator, + BoundaryCondition, + LeftBoundaryWindow, + RightBoundaryWindow, + has_boundary, + get_boundary, + stencil_interior_width, + boundary_width, + return_eltype, + return_space, + stencil_interior, + stencil_left_boundary, + stencil_right_boundary, + left_idx, + right_idx, + getidx, + getidx_args + +include("band_matrix_row.jl") +include("rmul_with_projection.jl") +include("matrix_multiplication.jl") +include("matrix_field_utils.jl") + +end diff --git a/src/MatrixFields/band_matrix_row.jl b/src/MatrixFields/band_matrix_row.jl new file mode 100644 index 0000000000..d4eb2200ff --- /dev/null +++ b/src/MatrixFields/band_matrix_row.jl @@ -0,0 +1,153 @@ +""" + BandMatrixRow{ld}(entries...) + +Stores the nonzero entries in a row of a band matrix, starting with the lowest +diagonal, which has index `ld`. Supported operations include accessing the entry +on the diagonal with index `d` by calling `row[d]`, taking linear combinations +with other band matrix rows, and checking for equality with other band matrix +rows. There are several aliases defined for commonly-used subtypes of +`BandMatrixRow` (with `T` denoting the type of the row's entries): +- `DiagonalMatrixRow{T}` +- `BidiagonalMatrixRow{T}` +- `TridiagonalMatrixRow{T}` +- `QuaddiagonalMatrixRow{T}` +- `PentadiagonalMatrixRow{T}` +""" +struct BandMatrixRow{ld, bw, T} # bw is the bandwidth (the number of diagonals) + entries::NTuple{bw, T} +end +BandMatrixRow{ld}(entries::Vararg{Any, bw}) where {ld, bw} = + BandMatrixRow{ld, bw}(entries...) +function BandMatrixRow{ld, bw}(entries::Vararg{Any, bw}) where {ld, bw} + promoted_entries = promote(entries...) 
+ return BandMatrixRow{ld, bw, eltype(promoted_entries)}(promoted_entries) +end + +const DiagonalMatrixRow{T} = BandMatrixRow{0, 1, T} +const BidiagonalMatrixRow{T} = BandMatrixRow{-1 + half, 2, T} +const TridiagonalMatrixRow{T} = BandMatrixRow{-1, 3, T} +const QuaddiagonalMatrixRow{T} = BandMatrixRow{-2 + half, 4, T} +const PentadiagonalMatrixRow{T} = BandMatrixRow{-2, 5, T} + +function Base.show( + io::IO, + ::Type{BMR}, +) where {ld, bw, T, BMR <: BandMatrixRow{ld, bw, T}} + string = if BMR <: DiagonalMatrixRow + "DiagonalMatrixRow{$T}" + elseif BMR <: BidiagonalMatrixRow + "BidiagonalMatrixRow{$T}" + elseif BMR <: TridiagonalMatrixRow + "TridiagonalMatrixRow{$T}" + elseif BMR <: QuaddiagonalMatrixRow + "QuaddiagonalMatrixRow{$T}" + elseif BMR <: PentadiagonalMatrixRow + "PentadiagonalMatrixRow{$T}" + else + "BandMatrixRow{$ld, $bw, $T}" + end + print(io, string) +end + +""" + outer_diagonals(::Type{<:BandMatrixRow}) + +Gets the indices of the lower and upper diagonals, `ld` and `ud`, of the given +subtype of `BandMatrixRow`. +""" +outer_diagonals(::Type{<:BandMatrixRow{ld, bw}}) where {ld, bw} = + (ld, ld + bw - 1) + +""" + band_matrix_row_type(ld, ud, T) + +A shorthand for getting the subtype of `BandMatrixRow` that has entries of type +`T` on the diagonals with indices in the range `ld:ud`. +""" +band_matrix_row_type(ld, ud, T) = BandMatrixRow{ld, ud - ld + 1, T} + +Base.eltype(::Type{BandMatrixRow{ld, bw, T}}) where {ld, bw, T} = T + +Base.zero(::Type{BandMatrixRow{ld, bw, T}}) where {ld, bw, T} = + BandMatrixRow{ld}(ntuple(_ -> rzero(T), Val(bw))...) + +Base.@propagate_inbounds Base.getindex(row::BandMatrixRow{ld}, d) where {ld} = + row.entries[d - ld + 1] + +function Base.promote_rule( + ::Type{BMR1}, + ::Type{BMR2}, +) where {BMR1 <: BandMatrixRow, BMR2 <: BandMatrixRow} + ld1, ud1 = outer_diagonals(BMR1) + ld2, ud2 = outer_diagonals(BMR2) + typeof(ld1) == typeof(ld2) || error( + "Cannot promote the $(ld1 isa PlusHalf ? 
"non-" : "")square matrix \ + row type $BMR1 and the $(ld2 isa PlusHalf ? "non-" : "")square \ + matrix row type $BMR2 to a common type", + ) + T = promote_type(eltype(BMR1), eltype(BMR2)) + return band_matrix_row_type(min(ld1, ld2), max(ud1, ud2), T) +end + +Base.promote_rule( + ::Type{BMR}, + ::Type{US}, +) where {BMR <: BandMatrixRow, US <: UniformScaling} = + promote_rule(BMR, DiagonalMatrixRow{eltype(US)}) + +function Base.convert( + ::Type{BMR}, + row::BandMatrixRow, +) where {BMR <: BandMatrixRow} + old_ld, old_ud = outer_diagonals(typeof(row)) + new_ld, new_ud = outer_diagonals(BMR) + typeof(old_ld) == typeof(new_ld) || + error("Cannot convert a $(old_ld isa PlusHalf ? "non-" : "")square \ + matrix row of type $(typeof(row)) to the \ + $(new_ld isa PlusHalf ? "non-" : "")square matrix row type $BMR") + new_ld <= old_ld && new_ud >= old_ud || + error("Cannot convert a $(typeof(row)) to a $BMR, since that would \ + require dropping potentially non-zero row entries") + first_zeros = ntuple(_ -> rzero(eltype(BMR)), Val(old_ld - new_ld)) + entries = map(entry -> convert(eltype(BMR), entry), row.entries) + last_zeros = ntuple(_ -> rzero(eltype(BMR)), Val(new_ud - old_ud)) + return BandMatrixRow{new_ld}(first_zeros..., entries..., last_zeros...) +end + +Base.convert(::Type{BMR}, row::UniformScaling) where {BMR <: BandMatrixRow} = + convert(BMR, DiagonalMatrixRow(row.λ)) + +Base.:(==)(row1::BMR, row2::BMR) where {BMR <: BandMatrixRow} = + row1.entries == row2.entries +Base.:(==)(row1::BandMatrixRow, row2::BandMatrixRow) = + ==(promote(row1, row2)...) +Base.:(==)(row1::BandMatrixRow, row2::UniformScaling) = + ==(promote(row1, row2)...) +Base.:(==)(row1::UniformScaling, row2::BandMatrixRow) = + ==(promote(row1, row2)...) + +Base.map(f::F, rows::BandMatrixRow{ld}...) where {F, ld} = + BandMatrixRow{ld}(map(f, map(row -> row.entries, rows)...)...) + +# Define all necessary operations for computing linear combinations. 
Use +# methods from RecursiveApply to handle nested types. +Base.:+(row::BandMatrixRow) = map(radd, row) +Base.:+(row1::BandMatrixRow, row2::BandMatrixRow) = + map(radd, promote(row1, row2)...) +Base.:+(row1::BandMatrixRow, row2::UniformScaling) = + map(radd, promote(row1, row2)...) +Base.:+(row1::UniformScaling, row2::BandMatrixRow) = + map(radd, promote(row1, row2)...) +Base.:-(row::BandMatrixRow) = map(rsub, row) +Base.:-(row1::BandMatrixRow, row2::BandMatrixRow) = + map(rsub, promote(row1, row2)...) +Base.:-(row1::BandMatrixRow, row2::UniformScaling) = + map(rsub, promote(row1, row2)...) +Base.:-(row1::UniformScaling, row2::BandMatrixRow) = + map(rsub, promote(row1, row2)...) +Base.:*(row::BandMatrixRow, value::Number) = + map(entry -> rmul(entry, value), row) +Base.:*(value::Number, row::BandMatrixRow) = + map(entry -> rmul(value, entry), row) +Base.:/(row::BandMatrixRow, value::Number) = + map(entry -> rdiv(entry, value), row) diff --git a/src/MatrixFields/matrix_field_utils.jl b/src/MatrixFields/matrix_field_utils.jl new file mode 100644 index 0000000000..2ee34ea877 --- /dev/null +++ b/src/MatrixFields/matrix_field_utils.jl @@ -0,0 +1,85 @@ +""" + column_field2array(field) + +Converts a field defined on a `FiniteDifferenceSpace` into a `Vector` or a +`BandedMatrix`, depending on whether or not the elements of the field are +`BandMatrixRow`s. This involves copying the data stored in the field. +""" +function column_field2array(field) + space = axes(field) + space isa Spaces.FiniteDifferenceSpace || + error("column_field2array requires a field on a FiniteDifferenceSpace") + n_rows = Spaces.nlevels(space) + if eltype(field) <: BandMatrixRow # field represents a matrix + field_ld, field_ud = outer_diagonals(eltype(field)) + + # Find the amount by which the field's diagonal indices get shifted when + # it is converted into a matrix, as well as the diagonal index of the + # value that ends up in the bottom-right corner of the matrix. 
+ matrix_d_minus_field_d, bottom_corner_matrix_d = + if field_ld isa PlusHalf + if axes(field).staggering isa Spaces.CellCenter + half, 1 # field is a face-to-center matrix + else + -half, -1 # field is a center-to-face matrix + end + else + 0, 0 # field is either a center-to-center or face-to-face matrix + end + + matrix_ld = field_ld + matrix_d_minus_field_d + matrix_ud = field_ud + matrix_d_minus_field_d + matrix_ld <= 0 && matrix_ud >= 0 || + error("BandedMatrices.jl does not yet support matrices that have \ + diagonals with indices in the range $matrix_ld:$matrix_ud") + + n_cols = n_rows + bottom_corner_matrix_d + matrix = BandedMatrix{eltype(eltype(field))}( + undef, + (n_rows, n_cols), + (-matrix_ld, matrix_ud), + ) + for (index_of_field_entry, matrix_d) in enumerate(matrix_ld:matrix_ud) + # Find the rows for which field_diagonal[row] is inside the matrix. + # Note: The matrix index (1, 1) corresponds to the diagonal index 0, + # and the matrix index (n_rows, n_cols) corresponds to the diagonal + # index bottom_corner_matrix_d. + first_row = matrix_d < 0 ? 1 - matrix_d : 1 + last_row = + matrix_d < bottom_corner_matrix_d ? n_rows : n_cols - matrix_d + + # Copy the value in each row from field_diagonal to matrix_diagonal. + field_diagonal = field.entries.:($index_of_field_entry) + matrix_diagonal = view(matrix, band(matrix_d)) + for (index_along_diagonal, row) in enumerate(first_row:last_row) + matrix_diagonal[index_along_diagonal] = + Fields.field_values(field_diagonal)[row] + end + end + return matrix + else # field represents a vector + return map(i -> Fields.field_values(field)[i], 1:n_rows) + end +end + +""" + field2arrays(field) + +Converts a field defined on a `FiniteDifferenceSpace` or on an +`ExtrudedFiniteDifferenceSpace` into a tuple of arrays, each of which +corresponds to a column of the field. This is done by calling +`column_field2array` on each of the field's columns. 
+""" +function field2arrays(field) + space = axes(field) + column_indices = if space isa Spaces.FiniteDifferenceSpace + (((1, 1), 1),) + elseif space isa Spaces.ExtrudedFiniteDifferenceSpace + (Spaces.all_nodes(Spaces.horizontal_space(space))...,) + else + error("Invalid space type: $(typeof(space).name.wrapper)") + end + return map(column_indices) do ((i, j), h) + return column_field2array(Spaces.column(field, i, j, h)) + end +end diff --git a/src/MatrixFields/matrix_multiplication.jl b/src/MatrixFields/matrix_multiplication.jl new file mode 100644 index 0000000000..f5b1e6f920 --- /dev/null +++ b/src/MatrixFields/matrix_multiplication.jl @@ -0,0 +1,338 @@ +""" + MultiplyColumnwiseBandMatrixField + +An operator that multiplies a columnwise band matrix field (a field of +`BandMatrixRow`s) by a regular field or by another columnwise band matrix field, +i.e., matrix-vector or matrix-matrix multiplication. The `⋅` symbol is an alias +for `MultiplyColumnwiseBandMatrixField()`. +""" +struct MultiplyColumnwiseBandMatrixField <: FiniteDifferenceOperator end +const ⋅ = MultiplyColumnwiseBandMatrixField() + +#= +TODO: Rewrite the following derivation in LaTeX and move it into the ClimaCore +documentation. + +Notation: + +For any single-column field F, let F[idx] denote the value of F at level idx. +For any single-column BandMatrixRow field M, let + M[idx, idx′] = M[idx][idx′ - idx]. +If there are multiple columns, the following equations apply per column. + +Matrix-Vector Multiplication: + +Consider a BandMatrixRow field M and a scalar (non-BandMatrixRow) field V. +From the definition of matrix-vector multiplication, + (M ⋅ V)[idx] = ∑_{idx′} M[idx, idx′] * V[idx′]. +If V[idx] is only defined when left_idx ≤ idx ≤ right_idx, this becomes + (M ⋅ V)[idx] = ∑_{idx′ ∈ left_idx:right_idx} M[idx, idx′] * V[idx′]. 
+If M[idx, idx′] is only defined when idx + ld ≤ idx′ ≤ idx + ud, this becomes + (M ⋅ V)[idx] = + ∑_{idx′ ∈ max(left_idx, idx + ld):min(right_idx, idx + ud)} + M[idx, idx′] * V[idx′]. +Replacing the variable idx′ with the variable d = idx′ - idx gives us + (M ⋅ V)[idx] = + ∑_{d ∈ max(left_idx - idx, ld):min(right_idx - idx, ud)} + M[idx, idx + d] * V[idx + d]. +This can be rewritten using the standard indexing notation as + (M ⋅ V)[idx] = + ∑_{d ∈ max(left_idx - idx, ld):min(right_idx - idx, ud)} + M[idx][d] * V[idx + d]. +Finally, we can express this in terms of left/right boundaries and an interior: + (M ⋅ V)[idx] = + ∑_{ + d ∈ + if idx < left_idx - ld + (left_idx - idx):ud + elseif idx > right_idx - ud + ld:(right_idx - idx) + else + ld:ud + end + } M[idx][d] * V[idx + d]. + +Matrix-Matrix Multiplication: + +Consider a BandMatrixRow field M1 and another BandMatrixRow field M2. +From the definition of matrix-matrix multiplication, + (M1 ⋅ M2)[idx, idx′] = ∑_{idx′′} M1[idx, idx′′] * M2[idx′′, idx′]. +If M2[idx′′] is only defined when left_idx ≤ idx′′ ≤ right_idx, this becomes + (M1 ⋅ M2)[idx, idx′] = + ∑_{idx′′ ∈ left_idx:right_idx} M1[idx, idx′′] * M2[idx′′, idx′]. +If M1[idx, idx′′] is only defined when idx + ld1 ≤ idx′′ ≤ idx + ud1, this becomes + (M1 ⋅ M2)[idx, idx′] = + ∑_{idx′′ ∈ max(left_idx, idx + ld1):min(right_idx, idx + ud1)} + M1[idx, idx′′] * M2[idx′′, idx′]. +If M2[idx′′, idx′] is only defined when idx′′ + ld2 ≤ idx′ ≤ idx′′ + ud2, or, +equivalently, when idx′ - ud2 ≤ idx′′ ≤ idx′ - ld2, this becomes + (M1 ⋅ M2)[idx, idx′] = + ∑_{ + idx′′ ∈ + max(left_idx, idx + ld1, idx′ - ud2): + min(right_idx, idx + ud1, idx′ - ld2) + } M1[idx, idx′′] * M2[idx′′, idx′]. +Replacing the variable idx′ with the variable prod_d = idx′ - idx gives us + (M1 ⋅ M2)[idx, idx + prod_d] = + ∑_{ + idx′′ ∈ + max(left_idx, idx + ld1, idx + prod_d - ud2): + min(right_idx, idx + ud1, idx + prod_d - ld2) + } M1[idx, idx′′] * M2[idx′′, idx + prod_d]. 
+Replacing the variable idx′′ with the variable d = idx′′ - idx gives us + (M1 ⋅ M2)[idx, idx + prod_d] = + ∑_{ + d ∈ + max(left_idx - idx, ld1, prod_d - ud2): + min(right_idx - idx, ud1, prod_d - ld2) + } M1[idx, idx + d] * M2[idx + d, idx + prod_d]. +This can be rewritten using the standard indexing notation as + (M1 ⋅ M2)[idx][prod_d] = + ∑_{ + d ∈ + max(left_idx - idx, ld1, prod_d - ud2): + min(right_idx - idx, ud1, prod_d - ld2) + } M1[idx][d] * M2[idx + d][prod_d - d]. +Finally, we can express this in terms of left/right boundaries and an interior: + (M1 ⋅ M2)[idx][prod_d] = + ∑_{ + d ∈ + if idx < left_idx - ld1 + max(left_idx - idx, prod_d - ud2):min(ud1, prod_d - ld2) + elseif idx > right_idx - ud1 + max(ld1, prod_d - ud2):min(right_idx - idx, prod_d - ld2) + else + max(ld1, prod_d - ud2):min(ud1, prod_d - ld2) + end + } M1[idx][d] * M2[idx + d][prod_d - d]. + +We only need to define (M1 ⋅ M2)[idx][prod_d] when it has a nonzero value in the +interior, which will be the case when + max(ld1, prod_d - ud2) ≤ min(ud1, prod_d - ld2). +This can be rewritten as a system of four inequalities: + ld1 ≤ ud1, + ld1 ≤ prod_d - ld2, + prod_d - ud2 ≤ ud1, and + prod_d - ud2 ≤ prod_d - ld2. +By definition, ld1 ≤ ud1 and ld2 ≤ ud2, so the first and last inequality are +always true. Rearranging the remaining two inequalities gives us + ld1 + ld2 ≤ prod_d ≤ ud1 + ud2. 
+=# + +struct TopLeftMatrixCorner <: BoundaryCondition end +struct BottomRightMatrixCorner <: BoundaryCondition end + +has_boundary( + ::MultiplyColumnwiseBandMatrixField, + ::LeftBoundaryWindow{name}, +) where {name} = true +has_boundary( + ::MultiplyColumnwiseBandMatrixField, + ::RightBoundaryWindow{name}, +) where {name} = true + +get_boundary( + ::MultiplyColumnwiseBandMatrixField, + ::LeftBoundaryWindow{name}, +) where {name} = TopLeftMatrixCorner() +get_boundary( + ::MultiplyColumnwiseBandMatrixField, + ::RightBoundaryWindow{name}, +) where {name} = BottomRightMatrixCorner() + +stencil_interior_width(::MultiplyColumnwiseBandMatrixField, matrix1, arg) = + ((0, 0), outer_diagonals(eltype(matrix1))) + +function boundary_width( + ::MultiplyColumnwiseBandMatrixField, + ::TopLeftMatrixCorner, + matrix1, + arg, +) + ld1, _ = outer_diagonals(eltype(matrix1)) + return max((left_idx(axes(arg)) - ld1) - left_idx(axes(matrix1)), 0) +end +function boundary_width( + ::MultiplyColumnwiseBandMatrixField, + ::BottomRightMatrixCorner, + matrix1, + arg, +) + _, ud1 = outer_diagonals(eltype(matrix1)) + return max(right_idx(axes(matrix1)) - (right_idx(axes(arg)) - ud1), 0) +end + +function return_eltype(::MultiplyColumnwiseBandMatrixField, matrix1, arg) + eltype(matrix1) <: BandMatrixRow || + error("The first argument of ⋅ must be a band matrix field, but the \ + given argument is a field with element type $(eltype(matrix1))") + if eltype(arg) <: BandMatrixRow # matrix-matrix multiplication + matrix2 = arg + ld1, ud1 = outer_diagonals(eltype(matrix1)) + ld2, ud2 = outer_diagonals(eltype(matrix2)) + prod_ld, prod_ud = ld1 + ld2, ud1 + ud2 + prod_eltype = rmul_with_projection_return_type( + eltype(eltype(matrix1)), + eltype(eltype(matrix2)), + ) + return band_matrix_row_type(prod_ld, prod_ud, prod_eltype) + else # matrix-vector multiplication + vector = arg + return rmul_with_projection_return_type( + eltype(eltype(matrix1)), + eltype(vector), + ) + end +end + 
+return_space(::MultiplyColumnwiseBandMatrixField, space1, space2) = space1 + +# TODO: Use @propagate_inbounds here, and remove @inbounds from this function. +# As of Julia 1.8, doing this increases compilation by more than an order of +# magnitude, and it also makes type inference fail for some complicated matrix +# field broadcast expressions. Unfortunately, not using @propagate_inbounds +# makes matrix field broadcast expressions take roughly 3 times longer to +# evaluate. However, since they are sufficiently fast as is, this is an +# acceptable performance loss. +function multiply_columnwise_band_matrix_at_index( + loc, + space, + idx, + hidx, + matrix1, + arg, + boundary_ld1 = nothing, + boundary_ud1 = nothing, +) + matrix1_row = getidx(space, matrix1, loc, idx, hidx) + ld1, ud1 = outer_diagonals(eltype(matrix1)) + ld1_or_boundary_ld1 = isnothing(boundary_ld1) ? ld1 : boundary_ld1 + ud1_or_boundary_ud1 = isnothing(boundary_ud1) ? ud1 : boundary_ud1 + return_type = return_eltype(⋅, matrix1, arg) + if eltype(arg) <: BandMatrixRow # matrix-matrix multiplication + matrix2 = arg + matrix2_rows = BandMatrixRow{ld1}( + map((ld1:ud1...,)) do d + # TODO: Use @propagate_inbounds_meta instead of @inline_meta. + Base.@_inline_meta + if ( + (isnothing(boundary_ld1) || d >= boundary_ld1) && + (isnothing(boundary_ud1) || d <= boundary_ud1) + ) + return @inbounds getidx(space, matrix2, loc, idx + d, hidx) + else + return rzero(eltype(matrix2)) # This value is never used. + end + end..., + ) # The rows are precomputed to avoid recomputing them multiple times. + ld2, ud2 = outer_diagonals(eltype(matrix2)) + prod_ld, prod_ud = outer_diagonals(return_type) + zero_value = rzero(eltype(return_type)) + prod_entries = map((prod_ld:prod_ud...,)) do prod_d + # TODO: Use @propagate_inbounds_meta instead of @inline_meta. 
+ Base.@_inline_meta + min_d = max(ld1_or_boundary_ld1, prod_d - ud2) + max_d = min(ud1_or_boundary_ud1, prod_d - ld2) + # Note: If min_d:max_d is an empty range, then the current entry + # lies outside of the product matrix, so it should never be used in + # any computations. By initializing prod_entry to zero_value, we are + # implicitly setting all such entries to 0. We could alternatively + # set all such entries to NaN (in order to more easily catch user + # errors that involve accidentally using these entries), but that + # would not generalize to non-floating-point types like Int or Bool. + prod_entry = zero_value + @inbounds for d in min_d:max_d + value1 = matrix1_row[d] + value2 = matrix2_rows[d][prod_d - d] + value2_lg = Geometry.LocalGeometry(space, idx + d, hidx) + prod_entry = radd( + prod_entry, + rmul_with_projection(value1, value2, value2_lg), + ) + end # Using this for-loop is currently faster than using mapreduce. + return prod_entry + end + return BandMatrixRow{prod_ld}(prod_entries...) + else # matrix-vector multiplication + vector = arg + prod_value = rzero(return_type) + @inbounds for d in ld1_or_boundary_ld1:ud1_or_boundary_ud1 + value1 = matrix1_row[d] + value2 = getidx(space, vector, loc, idx + d, hidx) + value2_lg = Geometry.LocalGeometry(space, idx + d, hidx) + prod_value = radd( + prod_value, + rmul_with_projection(value1, value2, value2_lg), + ) + end # Using this for-loop is currently faster than using mapreduce. 
+ return prod_value + end +end + +Base.@propagate_inbounds stencil_interior( + ::MultiplyColumnwiseBandMatrixField, + loc, + space, + idx, + hidx, + matrix1, + arg, +) = multiply_columnwise_band_matrix_at_index( + loc, + space, + idx, + hidx, + matrix1, + arg, +) + +Base.@propagate_inbounds stencil_left_boundary( + ::MultiplyColumnwiseBandMatrixField, + ::TopLeftMatrixCorner, + loc, + space, + idx, + hidx, + matrix1, + arg, +) = multiply_columnwise_band_matrix_at_index( + loc, + space, + idx, + hidx, + matrix1, + arg, + left_idx(axes(arg)) - idx, + nothing, +) + +Base.@propagate_inbounds stencil_right_boundary( + ::MultiplyColumnwiseBandMatrixField, + ::BottomRightMatrixCorner, + loc, + space, + idx, + hidx, + matrix1, + arg, +) = multiply_columnwise_band_matrix_at_index( + loc, + space, + idx, + hidx, + matrix1, + arg, + nothing, + right_idx(axes(arg)) - idx, +) + +# For complex matrix field broadcast expressions involving 4 or more matrices, +# we sometimes hit a recursion limit and de-optimize. +# We know the recursion will terminate due to the fact that broadcast +# expressions are not self-referential (aside from pathological examples). +if hasfield(Method, :recursion_relation) + dont_limit = (args...) 
-> true + for m in methods(multiply_columnwise_band_matrix_at_index) + m.recursion_relation = dont_limit + end +end diff --git a/src/MatrixFields/rmul_with_projection.jl b/src/MatrixFields/rmul_with_projection.jl new file mode 100644 index 0000000000..43c6b620fc --- /dev/null +++ b/src/MatrixFields/rmul_with_projection.jl @@ -0,0 +1,118 @@ +const SingleValue = + Union{Number, Geometry.AxisTensor, Geometry.AdjointAxisTensor} + +mul_with_projection(x, y, _) = x * y +mul_with_projection(x::Geometry.AdjointAxisVector, y::Geometry.AxisTensor, lg) = + x * Geometry.project(Geometry.dual(axes(x, 2)), y, lg) +mul_with_projection(::Geometry.AdjointAxisTensor, ::Geometry.AxisTensor, _) = + error("mul_with_projection is currently only implemented for covectors, \ + and higher-order cotensors are not supported") +# We should add methods for other cotensors (e.g., AdjointAxis2Tensor) when they +# are needed (e.g., when we need to support matrices that represent the +# divergence of higher-order tensors). + +""" + rmul_with_projection(x, y, lg) + +Similar to `rmul(x, y)`, but with automatic projection of `y` when `x` contains +a covector (i.e, an `AdjointAxisVector`). For example, if `x` is a covector +along the `Covariant3Axis` (e.g., `Covariant3Vector(1)'`), then `y` (or each +element of `y`) will be projected onto the `Contravariant3Axis`. In general, `y` +(or each element of `y`) will be projected onto the dual axis of each covector +in `x`. In the future, we may extend this behavior to higher-order cotensors. 
+""" +rmul_with_projection(x, y, lg) = + rmap((x′, y′) -> mul_with_projection(x′, y′, lg), x, y) +rmul_with_projection(x::SingleValue, y, lg) = + rmap(y′ -> mul_with_projection(x, y′, lg), y) +rmul_with_projection(x, y::SingleValue, lg) = + rmap(x′ -> mul_with_projection(x′, y, lg), x) +rmul_with_projection(x::SingleValue, y::SingleValue, lg) = + mul_with_projection(x, y, lg) + +function number_times_tensor_return_type( + ::Type{T1}, + ::Type{Geometry.AxisTensor{T2, N, ATuple, SArray{STuple, T2, N, L}}}, +) where {T1, T2, N, ATuple, STuple, L} + T = promote_type(T1, T2) + return Geometry.AxisTensor{T, N, ATuple, SArray{STuple, T, N, L}} +end + +function number_times_cotensor_return_type( + ::Type{T1}, + ::Type{Geometry.AdjointAxisTensor{T2, N, ATuple, SArray{STuple, T2, N, L}}}, +) where {T1, T2, N, ATuple, STuple, L} + T = promote_type(T1, T2) + return Geometry.AdjointAxisTensor{T, N, ATuple, SArray{STuple, T, N, L}} +end + +covector_times_tensor_return_type( + ::Type{<:Geometry.AdjointAxisVector{T1}}, + ::Type{<:Geometry.AxisVector{T2}}, +) where {T1, T2} = promote_type(T1, T2) + +function covector_times_tensor_return_type( + ::Type{<:Geometry.AdjointAxisVector{T1}}, + ::Type{<:Geometry.Axis2Tensor{T2, ATuple, <:SMatrix{<:Any, S, T2}}}, +) where {T1, T2, A, S, ATuple <: Tuple{<:Any, A}} + T = promote_type(T1, T2) + return Geometry.AdjointAxisVector{T, A, SVector{S, T}} +end + +mul_with_projection_return_type( + ::Type{X}, + ::Type{Y}, +) where {X <: Number, Y <: Number} = promote_type(X, Y) +mul_with_projection_return_type( + ::Type{X}, + ::Type{Y}, +) where {X <: Number, Y <: Geometry.AxisTensor} = + number_times_tensor_return_type(X, Y) +mul_with_projection_return_type( + ::Type{X}, + ::Type{Y}, +) where {X <: Geometry.AxisTensor, Y <: Number} = + number_times_tensor_return_type(Y, X) +mul_with_projection_return_type( + ::Type{X}, + ::Type{Y}, +) where {X <: Number, Y <: Geometry.AdjointAxisTensor} = + number_times_cotensor_return_type(X, Y) 
+mul_with_projection_return_type( + ::Type{X}, + ::Type{Y}, +) where {X <: Geometry.AdjointAxisTensor, Y <: Number} = + number_times_cotensor_return_type(Y, X) +mul_with_projection_return_type( + ::Type{X}, + ::Type{Y}, +) where {X <: Geometry.AdjointAxisVector, Y <: Geometry.AxisTensor} = + covector_times_tensor_return_type(X, Y) +mul_with_projection_return_type(::Type{X}, ::Type{Y}) where {X, Y} = + error("mul_with_projection_return_type is not yet implemented for types \ + $X and $Y") +# Add methods for other combinations of types if needed. + +""" + rmul_with_projection_return_type(X, Y) + +The return type of `rmul_with_projection(x, y, lg)`, where `x` has type `X` and +`y` has type `Y`. +""" +rmul_with_projection_return_type(::Type{X}, ::Type{Y}) where {X, Y} = + rmaptype(mul_with_projection_return_type, X, Y) +rmul_with_projection_return_type( + ::Type{X}, + ::Type{Y}, +) where {X <: SingleValue, Y} = + rmaptype(Y′ -> mul_with_projection_return_type(X, Y′), Y) +rmul_with_projection_return_type( + ::Type{X}, + ::Type{Y}, +) where {X, Y <: SingleValue} = + rmaptype(X′ -> mul_with_projection_return_type(X′, Y), X) +rmul_with_projection_return_type( + ::Type{X}, + ::Type{Y}, +) where {X <: SingleValue, Y <: SingleValue} = + mul_with_projection_return_type(X, Y) diff --git a/src/RecursiveApply/RecursiveApply.jl b/src/RecursiveApply/RecursiveApply.jl index dfe1907ac8..31add2b7d5 100755 --- a/src/RecursiveApply/RecursiveApply.jl +++ b/src/RecursiveApply/RecursiveApply.jl @@ -9,6 +9,51 @@ module RecursiveApply export ⊞, ⊠, ⊟ +# These functions need to be generated for type stability (since T.parameters is +# a SimpleVector, the compiler cannot always infer its size and elements). 
+has_no_params(::Type{T}) where {T} = + if @generated + :($(isempty(T.parameters))) + else + isempty(T.parameters) + end +first_param(::Type{T}) where {T} = + if @generated + :($(first(T.parameters))) + else + first(T.parameters) + end +tail_params(::Type{T}) where {T} = + if @generated + :($(Tuple{Base.tail((T.parameters...,))...})) + else + Tuple{Base.tail((T.parameters...,))...} + end + +# This is a type-stable version of map(x -> rmap(fn, x), X) or +# map((x, y) -> rmap(fn, x, y), X, Y). +rmap_tuple(fn::F, X) where {F} = + isempty(X) ? () : (rmap(fn, first(X)), rmap_tuple(fn, Base.tail(X))...) +rmap_tuple(fn::F, X, Y) where {F} = + isempty(X) || isempty(Y) ? () : + ( + rmap(fn, first(X), first(Y)), + rmap_tuple(fn, Base.tail(X), Base.tail(Y))..., + ) + +# This is a type-stable version of map(T′ -> rfunc(fn, T′), T.parameters) or +# map((T1′, T2′) -> rfunc(fn, T1′, T2′), T1.parameters, T2.parameters), where +# rfunc can be either rmaptype or rmap_type2value. +rmap_Tuple(rfunc::R, fn::F, ::Type{T}) where {R, F, T} = + has_no_params(T) ? () : + (rfunc(fn, first_param(T)), rmap_Tuple(rfunc, fn, tail_params(T))...) +rmap_Tuple(rfunc::R, fn::F, ::Type{T1}, ::Type{T2}) where {R, F, T1, T2} = + has_no_params(T1) || has_no_params(T2) ? () : + ( + rfunc(fn, first_param(T1), first_param(T2)), + rmap_Tuple(rfunc, fn, tail_params(T1), tail_params(T2))..., + ) + """ rmap(fn, X...) @@ -16,10 +61,8 @@ Recursively apply `fn` to each element of `X` """ rmap(fn::F, X) where {F} = fn(X) rmap(fn::F, X, Y) where {F} = fn(X, Y) -rmap(fn::F, X::Tuple) where {F} = map(x -> rmap(fn, x), X) -rmap(fn, X::Tuple{}, Y::Tuple{}) = () -rmap(fn::F, X::Tuple, Y::Tuple) where {F} = - (rmap(fn, first(X), first(Y)), rmap(fn, Base.tail(X), Base.tail(Y))...) 
+rmap(fn::F, X::Tuple) where {F} = rmap_tuple(fn, X) +rmap(fn::F, X::Tuple, Y::Tuple) where {F} = rmap_tuple(fn, X, Y) rmap(fn::F, X::NamedTuple{names}) where {F, names} = NamedTuple{names}(rmap(fn, Tuple(X))) rmap(fn::F, X::NamedTuple{names}, Y::NamedTuple{names}) where {F, names} = @@ -32,17 +75,53 @@ rmax(X, Y) = rmap(max, X, Y) """ rmaptype(fn, T) + rmaptype(fn, T1, T2) -The return type of `rmap(fn, X::T)`. +Recursively apply `fn` to each type parameter of the type `T`, or to each type +parameter of the types `T1` and `T2`, where `fn` returns a type. """ rmaptype(fn::F, ::Type{T}) where {F, T} = fn(T) +rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1, T2} = fn(T1, T2) rmaptype(fn::F, ::Type{T}) where {F, T <: Tuple} = - Tuple{map(fn, tuple(T.parameters...))...} + Tuple{rmap_Tuple(rmaptype, fn, T)...} +rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1 <: Tuple, T2 <: Tuple} = + Tuple{rmap_Tuple(rmaptype, fn, T1, T2)...} +rmaptype(fn::F, ::Type{T}) where {F, names, Tup, T <: NamedTuple{names, Tup}} = + NamedTuple{names, rmaptype(fn, Tup)} rmaptype( + fn::F, + ::Type{T1}, + ::Type{T2}, +) where { + F, + names, + Tup1, + Tup2, + T1 <: NamedTuple{names, Tup1}, + T2 <: NamedTuple{names, Tup2}, +} = NamedTuple{names, rmaptype(fn, Tup1, Tup2)} + +""" + rmap_type2value(fn, T) + +Recursively apply `fn` to each type parameter of the type `T`, where `fn` +returns a value instead of a type. +""" +rmap_type2value(fn::F, ::Type{T}) where {F, T} = fn(T) +rmap_type2value(fn::F, ::Type{T}) where {F, T <: Tuple} = + rmap_Tuple(rmap_type2value, fn, T) +rmap_type2value( fn::F, ::Type{T}, -) where {F, T <: NamedTuple{names, tup}} where {names, tup} = - NamedTuple{names, rmaptype(fn, tup)} +) where {F, names, Tup, T <: NamedTuple{names, Tup}} = + NamedTuple{names}(rmap_type2value(fn, Tup)) + +""" + rzero(T) + +Recursively compute the zero value of type `T`. 
+""" +rzero(::Type{T}) where {T} = rmap_type2value(zero, T) """ rmul(X, Y) diff --git a/test/MatrixFields/matrix_field_broadcasting.jl b/test/MatrixFields/matrix_field_broadcasting.jl new file mode 100644 index 0000000000..485beb86f6 --- /dev/null +++ b/test/MatrixFields/matrix_field_broadcasting.jl @@ -0,0 +1,762 @@ +using Test +using JET +using Random: seed! +using LinearAlgebra: I, mul! +using BandedMatrices: band + +import Profile +import PProf + +import ClimaCore: + Geometry, Domains, Meshes, Topologies, Hypsography, Spaces, Fields +import ClimaCore.MatrixFields: + DiagonalMatrixRow, + BidiagonalMatrixRow, + TridiagonalMatrixRow, + QuaddiagonalMatrixRow, + field2arrays, + ⋅ +import ClimaCore.Utilities: PlusHalf, half +import ClimaComms + +# Using @benchmark from BenchmarkTools is extremely slow; it appears to keep +# triggering recompilations and allocating a lot of memory in the process. +# This macro returns the minimum time (in seconds) required to run the +# expression after it has been compiled, as well as the amount of memory (in +# bytes) allocated when running the expression. +macro benchmark(expression) + return quote + $(esc(expression)) # Compile the expression first. Use esc for hygiene. + min_time = Inf + allocs = 0 + start_time = time_ns() + while time_ns() - start_time < 100000000 # Run the benchmark for 0.1 s. + timed_result = @timed $(esc(expression)) + min_time = min(min_time, timed_result.time) + allocs = timed_result.bytes # This value should not change. + end + (min_time, allocs) + end +end + +function call_func(func!::F, result, inputs) where {F} + func!(result, inputs...) + return nothing +end + +call_array_func( + func!::F, + result_arrays, + inputs_arrays, + temp_values_arrays, +) where {F} = + foreach(func!, result_arrays, inputs_arrays..., temp_values_arrays...) 
+ +function test_func_against_array_reference(; + test_name, + result, + inputs, + temp_values, + func!::F1, + ref_func!::F2, + print_summary = true, + generate_profile = false, +) where {F1, F2} + @testset "$test_name" begin + # Fill all output fields with NaNs for testing. + result .*= NaN + for temp_value in temp_values + temp_value .*= NaN + end + + ref_result_arrays = field2arrays(result) + inputs_arrays = map(field2arrays, inputs) + temp_values_arrays = map(field2arrays, temp_values) + + time, allocs = @benchmark call_func(func!, result, inputs) + ref_time, ref_allocs = @benchmark call_array_func( + ref_func!, + ref_result_arrays, + inputs_arrays, + temp_values_arrays, + ) + + # Compute the maximum error, rounded up to a multiple of epsilon. + result_arrays = field2arrays(result) + maximum_error = + maximum(zip(result_arrays, ref_result_arrays)) do (array, ref_array) + maximum(zip(array, ref_array)) do (value, ref_value) + ceil(Int, abs(value - ref_value) / eps(ref_value)) # Int(x) throws on fractions + end + end + + if print_summary + @info "$test_name:\n\tBest Time = $time s\n\tBest Reference Time = \ + $ref_time s\n\tMaximum Error = $maximum_error * eps" + end + + if generate_profile + Profile.clear() + Profile.@profile @benchmark call_func(func!, result, inputs) + PProf.pprof() + end + + # Check for correctness. + @test maximum_error <= 3 + + # Check for performance. + @test time < ref_time + + # Check for allocations. + @test allocs == 0 + @test ref_allocs == 0 + + # Check for type instabilities. + @test_opt call_func(func!, result, inputs) + @test_opt call_array_func( + ref_func!, + ref_result_arrays, + inputs_arrays, + temp_values_arrays, + ) + + # The reference function needs to be checked for allocations and type + # instabilities to ensure that the performance comparison is fair. 
+ end +end + +function test_func_against_reference(; + test_name, + result, + inputs, + ref_inputs, + func!::F1, + ref_func!::F2, + print_summary = true, + generate_profile = false, +) where {F1, F2} + @testset "$test_name" begin + result .*= NaN + ref_result = copy(result) + + time, allocs = @benchmark call_func(func!, result, inputs) + ref_time, ref_allocs = + @benchmark call_func(ref_func!, ref_result, ref_inputs) + + if print_summary + @info "$test_name:\n\tBest Time = $time s\n\tBest Reference Time = \ + $ref_time s" + end + + if generate_profile + Profile.clear() + Profile.@profile @benchmark call_func(func!, result, inputs) + PProf.pprof() + end + + # Check for correctness. + @test result == ref_result + + # Check for performance. + @test time < ref_time + + # Check for allocations. + @test allocs == 0 + @test ref_allocs == 0 + + # Check for type instabilities. + @test_opt call_func(func!, result, inputs) + @test_opt call_func(ref_func!, ref_result, ref_inputs) + + # The reference function needs to be checked for allocations and type + # instabilities to ensure that the performance comparison is fair. + end +end + +function random_test_fields(::Type{FT}) where {FT} + velem = 20 # This should be big enough to test high-bandwidth matrices. + helem = npoly = 1 # These should be small enough for the tests to be fast. + + hdomain = Domains.SphereDomain(FT(10)) + hmesh = Meshes.EquiangularCubedSphere(hdomain, helem) + htopology = Topologies.Topology2D(ClimaComms.SingletonCommsContext(), hmesh) + quad = Spaces.Quadratures.GLL{npoly + 1}() + hspace = Spaces.SpectralElementSpace2D(htopology, quad) + vdomain = Domains.IntervalDomain( + Geometry.ZPoint(FT(0)), + Geometry.ZPoint(FT(10)); + boundary_tags = (:bottom, :top), + ) + vmesh = Meshes.IntervalMesh(vdomain, nelems = velem) + vspace = Spaces.CenterFiniteDifferenceSpace(vmesh) + sfc_coord = Fields.coordinate_field(hspace) + hypsography = Hypsography.LinearAdaption( + @. 
cosd(sfc_coord.lat) + cosd(sfc_coord.long) + 1 + ) + center_space = + Spaces.ExtrudedFiniteDifferenceSpace(hspace, vspace, hypsography) + face_space = Spaces.FaceExtrudedFiniteDifferenceSpace(center_space) + ᶜcoord = Fields.coordinate_field(center_space) + ᶠcoord = Fields.coordinate_field(face_space) + + seed!(1) # ensures reproducibility + ᶜᶜmat = map(_ -> DiagonalMatrixRow(rand(FT, 1)...), ᶜcoord) + ᶜᶠmat = map(_ -> BidiagonalMatrixRow(rand(FT, 2)...), ᶜcoord) + ᶠᶠmat = map(_ -> TridiagonalMatrixRow(rand(FT, 3)...), ᶠcoord) + ᶠᶜmat = map(_ -> QuaddiagonalMatrixRow(rand(FT, 4)...), ᶠcoord) + ᶜvec = map(_ -> rand(FT), ᶜcoord) + ᶠvec = map(_ -> rand(FT), ᶠcoord) + + return ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat, ᶜvec, ᶠvec +end + +@testset "Scalar Matrix Field Broadcasting" begin + ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat, ᶜvec, ᶠvec = random_test_fields(Float64) + + test_func_against_array_reference(; + test_name = "diagonal matrix times vector", + result = (@. ᶜᶜmat ⋅ ᶜvec), + inputs = (ᶜᶜmat, ᶜvec), + temp_values = (), + func! = (result, ᶜᶜmat, ᶜvec) -> (@. result = ᶜᶜmat ⋅ ᶜvec), + ref_func! = (_result, _ᶜᶜmat, _ᶜvec) -> mul!(_result, _ᶜᶜmat, _ᶜvec), + ) + + test_func_against_array_reference(; + test_name = "tri-diagonal matrix times vector", + result = (@. ᶠᶠmat ⋅ ᶠvec), + inputs = (ᶠᶠmat, ᶠvec), + temp_values = (), + func! = (result, ᶠᶠmat, ᶠvec) -> (@. result = ᶠᶠmat ⋅ ᶠvec), + ref_func! = (_result, _ᶠᶠmat, _ᶠvec) -> mul!(_result, _ᶠᶠmat, _ᶠvec), + ) + + test_func_against_array_reference(; + test_name = "quad-diagonal matrix times vector", + result = (@. ᶠᶜmat ⋅ ᶜvec), + inputs = (ᶠᶜmat, ᶜvec), + temp_values = (), + func! = (result, ᶠᶜmat, ᶜvec) -> (@. result = ᶠᶜmat ⋅ ᶜvec), + ref_func! = (_result, _ᶠᶜmat, _ᶜvec) -> mul!(_result, _ᶠᶜmat, _ᶜvec), + ) + + test_func_against_array_reference(; + test_name = "diagonal matrix times bi-diagonal matrix", + result = (@. ᶜᶜmat ⋅ ᶜᶠmat), + inputs = (ᶜᶜmat, ᶜᶠmat), + temp_values = (), + func! = (result, ᶜᶜmat, ᶜᶠmat) -> (@. 
result = ᶜᶜmat ⋅ ᶜᶠmat), + ref_func! = (_result, _ᶜᶜmat, _ᶜᶠmat) -> mul!(_result, _ᶜᶜmat, _ᶜᶠmat), + ) + + test_func_against_array_reference(; + test_name = "tri-diagonal matrix times tri-diagonal matrix", + result = (@. ᶠᶠmat ⋅ ᶠᶠmat), + inputs = (ᶠᶠmat,), + temp_values = (), + func! = (result, ᶠᶠmat) -> (@. result = ᶠᶠmat ⋅ ᶠᶠmat), + ref_func! = (_result, _ᶠᶠmat) -> mul!(_result, _ᶠᶠmat, _ᶠᶠmat), + ) + + test_func_against_array_reference(; + test_name = "quad-diagonal matrix times diagonal matrix", + result = (@. ᶠᶜmat ⋅ ᶜᶜmat), + inputs = (ᶠᶜmat, ᶜᶜmat), + temp_values = (), + func! = (result, ᶠᶜmat, ᶜᶜmat) -> (@. result = ᶠᶜmat ⋅ ᶜᶜmat), + ref_func! = (_result, _ᶠᶜmat, _ᶜᶜmat) -> mul!(_result, _ᶠᶜmat, _ᶜᶜmat), + ) + + test_func_against_array_reference(; + test_name = "diagonal matrix times bi-diagonal matrix times \ + tri-diagonal matrix times quad-diagonal matrix", + result = (@. ᶜᶜmat ⋅ ᶜᶠmat ⋅ ᶠᶠmat ⋅ ᶠᶜmat), + inputs = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat), + temp_values = ((@. ᶜᶜmat ⋅ ᶜᶠmat), (@. ᶜᶜmat ⋅ ᶜᶠmat ⋅ ᶠᶠmat)), + func! = (result, ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat) -> + (@. result = ᶜᶜmat ⋅ ᶜᶠmat ⋅ ᶠᶠmat ⋅ ᶠᶜmat), # left-assoc, like ref_func! + ref_func! = (_result, _ᶜᶜmat, _ᶜᶠmat, _ᶠᶠmat, _ᶠᶜmat, _temp1, _temp2) -> + begin + mul!(_temp1, _ᶜᶜmat, _ᶜᶠmat) + mul!(_temp2, _temp1, _ᶠᶠmat) + mul!(_result, _temp2, _ᶠᶜmat) + end, + ) + + test_func_against_array_reference(; + test_name = "diagonal matrix times bi-diagonal matrix times \ + tri-diagonal matrix times quad-diagonal matrix, but with \ + forced right-associativity", + result = (@. ᶜᶜmat ⋅ (ᶜᶠmat ⋅ (ᶠᶠmat ⋅ ᶠᶜmat))), + inputs = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat), + temp_values = ((@. ᶠᶠmat ⋅ ᶠᶜmat), (@. ᶜᶠmat ⋅ (ᶠᶠmat ⋅ ᶠᶜmat))), + func! = (result, ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat) -> + (@. result = ᶜᶜmat ⋅ (ᶜᶠmat ⋅ (ᶠᶠmat ⋅ ᶠᶜmat))), # right-assoc, like ref_func! + ref_func! 
= (_result, _ᶜᶜmat, _ᶜᶠmat, _ᶠᶠmat, _ᶠᶜmat, _temp1, _temp2) -> + begin + mul!(_temp1, _ᶠᶠmat, _ᶠᶜmat) + mul!(_temp2, _ᶜᶠmat, _temp1) + mul!(_result, _ᶜᶜmat, _temp2) + end, + ) + + test_func_against_array_reference(; + test_name = "diagonal matrix times bi-diagonal matrix times \ + tri-diagonal matrix times quad-diagonal matrix times \ + vector", + result = (@. ᶜᶜmat ⋅ ᶜᶠmat ⋅ ᶠᶠmat ⋅ ᶠᶜmat ⋅ ᶜvec), + inputs = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat, ᶜvec), + temp_values = ( + (@. ᶜᶜmat ⋅ ᶜᶠmat), + (@. ᶜᶜmat ⋅ ᶜᶠmat ⋅ ᶠᶠmat), + (@. ᶜᶜmat ⋅ ᶜᶠmat ⋅ ᶠᶠmat ⋅ ᶠᶜmat), + ), + func! = (result, ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat, ᶜvec) -> + (@. result = ᶜᶜmat ⋅ ᶜᶠmat ⋅ ᶠᶠmat ⋅ ᶠᶜmat ⋅ ᶜvec), + ref_func! = ( + _result, + _ᶜᶜmat, + _ᶜᶠmat, + _ᶠᶠmat, + _ᶠᶜmat, + _ᶜvec, + _temp1, + _temp2, + _temp3, + ) -> begin + mul!(_temp1, _ᶜᶜmat, _ᶜᶠmat) + mul!(_temp2, _temp1, _ᶠᶠmat) + mul!(_temp3, _temp2, _ᶠᶜmat) + mul!(_result, _temp3, _ᶜvec) + end, + ) + + test_func_against_array_reference(; + test_name = "diagonal matrix times bi-diagonal matrix times \ + tri-diagonal matrix times quad-diagonal matrix times \ + vector, but with forced right-associativity", + result = (@. ᶜᶜmat ⋅ (ᶜᶠmat ⋅ (ᶠᶠmat ⋅ (ᶠᶜmat ⋅ ᶜvec)))), + inputs = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat, ᶜvec), + temp_values = ( + (@. ᶠᶜmat ⋅ ᶜvec), + (@. ᶠᶠmat ⋅ (ᶠᶜmat ⋅ ᶜvec)), + (@. ᶜᶠmat ⋅ (ᶠᶠmat ⋅ (ᶠᶜmat ⋅ ᶜvec))), + ), + func! = (result, ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat, ᶜvec) -> + (@. result = ᶜᶜmat ⋅ (ᶜᶠmat ⋅ (ᶠᶠmat ⋅ (ᶠᶜmat ⋅ ᶜvec)))), + ref_func! = ( + _result, + _ᶜᶜmat, + _ᶜᶠmat, + _ᶠᶠmat, + _ᶠᶜmat, + _ᶜvec, + _temp1, + _temp2, + _temp3, + ) -> begin + mul!(_temp1, _ᶠᶜmat, _ᶜvec) + mul!(_temp2, _ᶠᶠmat, _temp1) + mul!(_temp3, _ᶜᶠmat, _temp2) + mul!(_result, _ᶜᶜmat, _temp3) + end, + ) + + test_func_against_array_reference(; + test_name = "linear combination of matrix products and LinearAlgebra.I", + result = (@. 2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,)), + inputs = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat), + temp_values = ( + (@. 2 * ᶠᶜmat), + (@. 
2 * ᶠᶜmat ⋅ ᶜᶜmat), + (@. 2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat), + (@. ᶠᶠmat ⋅ ᶠᶠmat), + ), + func! = (result, ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat) -> + (@. result = 2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,)), + ref_func! = ( + _result, + _ᶜᶜmat, + _ᶜᶠmat, + _ᶠᶠmat, + _ᶠᶜmat, + _temp1, + _temp2, + _temp3, + _temp4, + ) -> begin + @. _temp1 = 0 + 2 * _ᶠᶜmat # This allocates without the `0 + `. + mul!(_temp2, _temp1, _ᶜᶜmat) + mul!(_temp3, _temp2, _ᶜᶠmat) + mul!(_temp4, _ᶠᶠmat, _ᶠᶠmat) + copyto!(_result, 4I) # We can't directly use I in array broadcasts. + @. _result = _temp3 + _temp4 / 3 - _result + end, + ) + + test_func_against_array_reference(; + test_name = "another linear combination of matrix products and \ + LinearAlgebra.I", + result = (@. ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat * 2 - (ᶠᶠmat / 3) ⋅ ᶠᶠmat + (4I,)), + inputs = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat), + temp_values = ( + (@. ᶠᶜmat ⋅ ᶜᶜmat), + (@. ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat), + (@. ᶠᶠmat / 3), + (@. (ᶠᶠmat / 3) ⋅ ᶠᶠmat), + ), + func! = (result, ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat) -> (@. result = + ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat * 2 - (ᶠᶠmat / 3) ⋅ ᶠᶠmat + (4I,)), + ref_func! = ( + _result, + _ᶜᶜmat, + _ᶜᶠmat, + _ᶠᶠmat, + _ᶠᶜmat, + _temp1, + _temp2, + _temp3, + _temp4, + ) -> begin + mul!(_temp1, _ᶠᶜmat, _ᶜᶜmat) + mul!(_temp2, _temp1, _ᶜᶠmat) + @. _temp3 = 0 + _ᶠᶠmat / 3 # This allocates without the `0 + `. + mul!(_temp4, _temp3, _ᶠᶠmat) + copyto!(_result, 4I) # We can't directly use I in array broadcasts. + @. _result = _temp2 * 2 - _temp4 + _result + end, + ) + + test_func_against_array_reference(; + test_name = "matrix times linear combination", + result = (@. ᶜᶠmat ⋅ + (2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,))), + inputs = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat), + temp_values = ( + (@. 2 * ᶠᶜmat), + (@. 2 * ᶠᶜmat ⋅ ᶜᶜmat), + (@. 2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat), + (@. ᶠᶠmat ⋅ ᶠᶠmat), + (@. 2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,)), + ), + func! = (result, ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat) -> (@. 
result = + ᶜᶠmat ⋅ (2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,))), + ref_func! = ( + _result, + _ᶜᶜmat, + _ᶜᶠmat, + _ᶠᶠmat, + _ᶠᶜmat, + _temp1, + _temp2, + _temp3, + _temp4, + _temp5, + ) -> begin + @. _temp1 = 0 + 2 * _ᶠᶜmat # This allocates without the `0 + `. + mul!(_temp2, _temp1, _ᶜᶜmat) + mul!(_temp3, _temp2, _ᶜᶠmat) + mul!(_temp4, _ᶠᶠmat, _ᶠᶠmat) + copyto!(_temp5, 4I) # We can't directly use I in array broadcasts. + @. _temp5 = _temp3 + _temp4 / 3 - _temp5 + mul!(_result, _ᶜᶠmat, _temp5) + end, + ) + + test_func_against_array_reference(; + test_name = "linear combination times another linear combination", + result = (@. (2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,)) ⋅ + (ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat * 2 - (ᶠᶠmat / 3) ⋅ ᶠᶠmat + (4I,))), + inputs = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat), + temp_values = ( + (@. 2 * ᶠᶜmat), + (@. 2 * ᶠᶜmat ⋅ ᶜᶜmat), + (@. 2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat), + (@. ᶠᶠmat ⋅ ᶠᶠmat), + (@. 2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,)), + (@. ᶠᶜmat ⋅ ᶜᶜmat), + (@. ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat), + (@. ᶠᶠmat / 3), + (@. (ᶠᶠmat / 3) ⋅ ᶠᶠmat), + (@. ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat * 2 - (ᶠᶠmat / 3) ⋅ ᶠᶠmat + (4I,)), + ), + func! = (result, ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat) -> (@. result = + (2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,)) ⋅ + (ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat * 2 - (ᶠᶠmat / 3) ⋅ ᶠᶠmat + (4I,))), + ref_func! = ( + _result, + _ᶜᶜmat, + _ᶜᶠmat, + _ᶠᶠmat, + _ᶠᶜmat, + _temp1, + _temp2, + _temp3, + _temp4, + _temp5, + _temp6, + _temp7, + _temp8, + _temp9, + _temp10, + ) -> begin + @. _temp1 = 0 + 2 * _ᶠᶜmat # This allocates without the `0 + `. + mul!(_temp2, _temp1, _ᶜᶜmat) + mul!(_temp3, _temp2, _ᶜᶠmat) + mul!(_temp4, _ᶠᶠmat, _ᶠᶠmat) + copyto!(_temp5, 4I) # We can't directly use I in array broadcasts. + @. _temp5 = _temp3 + _temp4 / 3 - _temp5 + mul!(_temp6, _ᶠᶜmat, _ᶜᶜmat) + mul!(_temp7, _temp6, _ᶜᶠmat) + @. _temp8 = 0 + _ᶠᶠmat / 3 # This allocates without the `0 + `. 
+ mul!(_temp9, _temp8, _ᶠᶠmat) + copyto!(_temp10, 4I) # We can't directly use I in array broadcasts. + @. _temp10 = _temp7 * 2 - _temp9 + _temp10 + mul!(_result, _temp5, _temp10) + end, + ) + + test_func_against_array_reference(; + test_name = "matrix times matrix times linear combination times matrix \ + times another linear combination times matrix", + result = (@. ᶠᶜmat ⋅ ᶜᶠmat ⋅ + (2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,)) ⋅ + ᶠᶠmat ⋅ + (ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat * 2 - (ᶠᶠmat / 3) ⋅ ᶠᶠmat + (4I,)) ⋅ + ᶠᶠmat), + inputs = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat), + temp_values = ( + (@. ᶠᶜmat ⋅ ᶜᶠmat), + (@. 2 * ᶠᶜmat), + (@. 2 * ᶠᶜmat ⋅ ᶜᶜmat), + (@. 2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat), + (@. ᶠᶠmat ⋅ ᶠᶠmat), + (@. 2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,)), + (@. ᶠᶜmat ⋅ ᶜᶠmat ⋅ + (2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,))), + (@. ᶠᶜmat ⋅ ᶜᶠmat ⋅ + (2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,)) ⋅ + ᶠᶠmat), + (@. ᶠᶜmat ⋅ ᶜᶜmat), + (@. ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat), + (@. ᶠᶠmat / 3), + (@. (ᶠᶠmat / 3) ⋅ ᶠᶠmat), + (@. ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat * 2 - (ᶠᶠmat / 3) ⋅ ᶠᶠmat + (4I,)), + (@. ᶠᶜmat ⋅ ᶜᶠmat ⋅ + (2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,)) ⋅ + ᶠᶠmat ⋅ + (ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat * 2 - (ᶠᶠmat / 3) ⋅ ᶠᶠmat + (4I,))), + ), + func! = (result, ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat) -> (@. result = + ᶠᶜmat ⋅ ᶜᶠmat ⋅ + (2 * ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat + ᶠᶠmat ⋅ ᶠᶠmat / 3 - (4I,)) ⋅ + ᶠᶠmat ⋅ + (ᶠᶜmat ⋅ ᶜᶜmat ⋅ ᶜᶠmat * 2 - (ᶠᶠmat / 3) ⋅ ᶠᶠmat + (4I,)) ⋅ + ᶠᶠmat), + ref_func! = ( + _result, + _ᶜᶜmat, + _ᶜᶠmat, + _ᶠᶠmat, + _ᶠᶜmat, + _temp1, + _temp2, + _temp3, + _temp4, + _temp5, + _temp6, + _temp7, + _temp8, + _temp9, + _temp10, + _temp11, + _temp12, + _temp13, + _temp14, + ) -> begin + mul!(_temp1, _ᶠᶜmat, _ᶜᶠmat) + @. _temp2 = 0 + 2 * _ᶠᶜmat # This allocates without the `0 + `. + mul!(_temp3, _temp2, _ᶜᶜmat) + mul!(_temp4, _temp3, _ᶜᶠmat) + mul!(_temp5, _ᶠᶠmat, _ᶠᶠmat) + copyto!(_temp6, 4I) # We can't directly use I in array broadcasts. + @. 
_temp6 = _temp4 + _temp5 / 3 - _temp6 + mul!(_temp7, _temp1, _temp6) + mul!(_temp8, _temp7, _ᶠᶠmat) + mul!(_temp9, _ᶠᶜmat, _ᶜᶜmat) + mul!(_temp10, _temp9, _ᶜᶠmat) + @. _temp11 = 0 + _ᶠᶠmat / 3 # This allocates without the `0 + `. + mul!(_temp12, _temp11, _ᶠᶠmat) + copyto!(_temp13, 4I) # We can't directly use I in array broadcasts. + @. _temp13 = _temp10 * 2 - _temp12 + _temp13 + mul!(_temp14, _temp8, _temp13) + mul!(_result, _temp14, _ᶠᶠmat) + end, + ) + + test_func_against_array_reference(; + test_name = "nested matrix constructions and multiplications", + result = (@. BidiagonalMatrixRow(ᶜᶠmat ⋅ ᶠvec, ᶜᶜmat ⋅ ᶜvec) ⋅ + TridiagonalMatrixRow(ᶠvec, ᶠᶜmat ⋅ ᶜvec, 1) ⋅ ᶠᶠmat ⋅ + DiagonalMatrixRow(DiagonalMatrixRow(ᶠvec) ⋅ ᶠvec)), + inputs = (ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat, ᶜvec, ᶠvec), + temp_values = ( + (@. BidiagonalMatrixRow(ᶜᶠmat ⋅ ᶠvec, ᶜᶜmat ⋅ ᶜvec)), + (@. TridiagonalMatrixRow(ᶠvec, ᶠᶜmat ⋅ ᶜvec, 1)), + (@. BidiagonalMatrixRow(ᶜᶠmat ⋅ ᶠvec, ᶜᶜmat ⋅ ᶜvec) ⋅ + TridiagonalMatrixRow(ᶠvec, ᶠᶜmat ⋅ ᶜvec, 1)), + (@. BidiagonalMatrixRow(ᶜᶠmat ⋅ ᶠvec, ᶜᶜmat ⋅ ᶜvec) ⋅ + TridiagonalMatrixRow(ᶠvec, ᶠᶜmat ⋅ ᶜvec, 1) ⋅ ᶠᶠmat), + (@. DiagonalMatrixRow(ᶠvec)), + (@. DiagonalMatrixRow(DiagonalMatrixRow(ᶠvec) ⋅ ᶠvec)), + ), + func! = (result, ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat, ᶜvec, ᶠvec) -> + (@. result = + BidiagonalMatrixRow(ᶜᶠmat ⋅ ᶠvec, ᶜᶜmat ⋅ ᶜvec) ⋅ + TridiagonalMatrixRow(ᶠvec, ᶠᶜmat ⋅ ᶜvec, 1) ⋅ ᶠᶠmat ⋅ + DiagonalMatrixRow(DiagonalMatrixRow(ᶠvec) ⋅ ᶠvec)), + ref_func! 
= ( + _result, + _ᶜᶜmat, + _ᶜᶠmat, + _ᶠᶠmat, + _ᶠᶜmat, + _ᶜvec, + _ᶠvec, + _temp1, + _temp2, + _temp3, + _temp4, + _temp5, + _temp6, + ) -> begin + mul!(view(_temp1, band(0)), _ᶜᶠmat, _ᶠvec) + mul!(view(_temp1, band(1)), _ᶜᶜmat, _ᶜvec) + copyto!(view(_temp2, band(-1)), 1, _ᶠvec, 2) + mul!(view(_temp2, band(0)), _ᶠᶜmat, _ᶜvec) + fill!(view(_temp2, band(1)), 1) + mul!(_temp3, _temp1, _temp2) + mul!(_temp4, _temp3, _ᶠᶠmat) + copyto!(view(_temp5, band(0)), 1, _ᶠvec, 1) + mul!(view(_temp6, band(0)), _temp5, _ᶠvec) + mul!(_result, _temp4, _temp6) + end, + ) +end + +@testset "Non-scalar Matrix Field Broadcasting" begin + ᶜᶜmat, ᶜᶠmat, ᶠᶠmat, ᶠᶜmat, ᶜvec, ᶠvec = random_test_fields(Float64) + + ᶜlg = Fields.local_geometry_field(ᶜvec) + ᶠlg = Fields.local_geometry_field(ᶠvec) + + ᶜᶠmat2 = map(row -> map(sin, row), ᶜᶠmat) + ᶜᶠmat3 = map(row -> map(cos, row), ᶜᶠmat) + ᶠᶜmat2 = map(row -> map(sin, row), ᶠᶜmat) + ᶠᶜmat3 = map(row -> map(cos, row), ᶠᶜmat) + + ᶜᶠmat_AC1 = map(row -> map(adjoint ∘ Geometry.Covariant1Vector, row), ᶜᶠmat) + ᶜᶠmat_C12 = map( + (row1, row2) -> map(Geometry.Covariant12Vector, row1, row2), + ᶜᶠmat2, + ᶜᶠmat3, + ) + ᶠᶜmat_AC1 = map(row -> map(adjoint ∘ Geometry.Covariant1Vector, row), ᶠᶜmat) + ᶠᶜmat_C12 = map( + (row1, row2) -> map(Geometry.Covariant12Vector, row1, row2), + ᶠᶜmat2, + ᶠᶜmat3, + ) + + ᶜᶠmat_AC1_num = + map((row1, row2) -> map(tuple, row1, row2), ᶜᶠmat_AC1, ᶜᶠmat) + ᶜᶠmat_num_C12 = + map((row1, row2) -> map(tuple, row1, row2), ᶜᶠmat, ᶜᶠmat_C12) + ᶠᶜmat_C12_AC1 = + map((row1, row2) -> map(tuple, row1, row2), ᶠᶜmat_C12, ᶠᶜmat_AC1) + + test_func_against_reference(; + test_name = "matrix of covectors times matrix of vectors", + result = (@. ᶜᶠmat_AC1 ⋅ ᶠᶜmat_C12), + inputs = (ᶜᶠmat_AC1, ᶠᶜmat_C12), + ref_inputs = (ᶜᶠmat, ᶠᶜmat2, ᶠᶜmat3, ᶠlg), + func! = (result, ᶜᶠmat_AC1, ᶠᶜmat_C12) -> + (@. result = ᶜᶠmat_AC1 ⋅ ᶠᶜmat_C12), + ref_func! = (result, ᶜᶠmat, ᶠᶜmat2, ᶠᶜmat3, ᶠlg) -> (@. 
result = + ᶜᶠmat ⋅ ( + DiagonalMatrixRow(ᶠlg.gⁱʲ.components.data.:1) ⋅ ᶠᶜmat2 + + DiagonalMatrixRow(ᶠlg.gⁱʲ.components.data.:2) ⋅ ᶠᶜmat3 + )), + ) + + test_func_against_reference(; + test_name = "matrix of covectors times matrix of vectors times matrix \ + of numbers times matrix of covectors times matrix of \ + vectors", + result = (@. ᶜᶠmat_AC1 ⋅ ᶠᶜmat_C12 ⋅ ᶜᶠmat ⋅ ᶠᶜmat_AC1 ⋅ ᶜᶠmat_C12), + inputs = (ᶜᶠmat_AC1, ᶠᶜmat_C12, ᶜᶠmat, ᶠᶜmat_AC1, ᶜᶠmat_C12), + ref_inputs = (ᶜᶠmat, ᶜᶠmat2, ᶜᶠmat3, ᶠᶜmat, ᶠᶜmat2, ᶠᶜmat3, ᶜlg, ᶠlg), + func! = (result, ᶜᶠmat_AC1, ᶠᶜmat_C12, ᶜᶠmat, ᶠᶜmat_AC1, ᶜᶠmat_C12) -> + (@. result = ᶜᶠmat_AC1 ⋅ ᶠᶜmat_C12 ⋅ ᶜᶠmat ⋅ ᶠᶜmat_AC1 ⋅ ᶜᶠmat_C12), + ref_func! = ( + result, + ᶜᶠmat, + ᶜᶠmat2, + ᶜᶠmat3, + ᶠᶜmat, + ᶠᶜmat2, + ᶠᶜmat3, + ᶜlg, + ᶠlg, + ) -> (@. result = + ᶜᶠmat ⋅ ( + DiagonalMatrixRow(ᶠlg.gⁱʲ.components.data.:1) ⋅ ᶠᶜmat2 + + DiagonalMatrixRow(ᶠlg.gⁱʲ.components.data.:2) ⋅ ᶠᶜmat3 + ) ⋅ ᶜᶠmat ⋅ ᶠᶜmat ⋅ ( + DiagonalMatrixRow(ᶜlg.gⁱʲ.components.data.:1) ⋅ ᶜᶠmat2 + + DiagonalMatrixRow(ᶜlg.gⁱʲ.components.data.:2) ⋅ ᶜᶠmat3 + )), + ) + + test_func_against_reference(; + test_name = "matrix of covectors and numbers times matrix of vectors \ + and covectors times matrix of numbers and vectors times \ + vector of numbers", + result = (@. ᶜᶠmat_AC1_num ⋅ ᶠᶜmat_C12_AC1 ⋅ ᶜᶠmat_num_C12 ⋅ ᶠvec), + inputs = (ᶜᶠmat_AC1_num, ᶠᶜmat_C12_AC1, ᶜᶠmat_num_C12, ᶠvec), + ref_inputs = ( + ᶜᶠmat, + ᶜᶠmat2, + ᶜᶠmat3, + ᶠᶜmat, + ᶠᶜmat2, + ᶠᶜmat3, + ᶠvec, + ᶜlg, + ᶠlg, + ), + func! = (result, ᶜᶠmat_AC1_num, ᶠᶜmat_C12_AC1, ᶜᶠmat_num_C12, ᶠvec) -> + (@. result = ᶜᶠmat_AC1_num ⋅ ᶠᶜmat_C12_AC1 ⋅ ᶜᶠmat_num_C12 ⋅ ᶠvec), + ref_func! = ( + result, + ᶜᶠmat, + ᶜᶠmat2, + ᶜᶠmat3, + ᶠᶜmat, + ᶠᶜmat2, + ᶠᶜmat3, + ᶠvec, + ᶜlg, + ᶠlg, + ) -> (@. 
result = tuple( + ᶜᶠmat ⋅ ( + DiagonalMatrixRow(ᶠlg.gⁱʲ.components.data.:1) ⋅ ᶠᶜmat2 + + DiagonalMatrixRow(ᶠlg.gⁱʲ.components.data.:2) ⋅ ᶠᶜmat3 + ) ⋅ ᶜᶠmat ⋅ ᶠvec, + ᶜᶠmat ⋅ ᶠᶜmat ⋅ ( + DiagonalMatrixRow(ᶜlg.gⁱʲ.components.data.:1) ⋅ ᶜᶠmat2 + + DiagonalMatrixRow(ᶜlg.gⁱʲ.components.data.:2) ⋅ ᶜᶠmat3 + ) ⋅ ᶠvec, + )), + ) +end diff --git a/test/MatrixFields/rmul_with_projection.jl b/test/MatrixFields/rmul_with_projection.jl new file mode 100644 index 0000000000..b4fd8c4844 --- /dev/null +++ b/test/MatrixFields/rmul_with_projection.jl @@ -0,0 +1,127 @@ +using Test +using JET +using Random: seed! +using StaticArrays: @SMatrix + +import ClimaCore: Geometry +import ClimaCore.MatrixFields: + rmul_with_projection, rmul_with_projection_return_type + +import ClimaCore.RecursiveApply: rsub + +function test_rmul_with_projection(x, y, lg, expected_result) + result = rmul_with_projection(x, y, lg) + result_type = rmul_with_projection_return_type(typeof(x), typeof(y)) + + # Check for correctness. + @info "$(result == expected_result)\n\t$result\n\t$expected_result\n\t\ + $(rsub(result, expected_result))" + @test result == expected_result # TODO: Why doesn't this work on CI? + @test result_type == typeof(result) + + # Check for inference failures. 
+ @test_opt rmul_with_projection(x, y, lg) + @test_opt rmul_with_projection_return_type(typeof(x), typeof(y)) +end + +@testset "rmul_with_projection Unit Tests" begin + seed!(1) # ensures reproducibility + + FT = Float64 + coord = Geometry.LatLongZPoint(rand(FT), rand(FT), rand(FT)) + ∂x∂ξ = Geometry.AxisTensor( + (Geometry.LocalAxis{(1, 2, 3)}(), Geometry.CovariantAxis{(1, 2, 3)}()), + (@SMatrix rand(FT, 3, 3)), + ) + lg = Geometry.LocalGeometry(coord, rand(FT), rand(FT), ∂x∂ξ) + + number = rand(FT) + covector = Geometry.Covariant12Vector(rand(FT), rand(FT))' + vector = Geometry.Covariant123Vector(rand(FT), rand(FT), rand(FT)) + tensor = vector * vector' + projected_vector = + Geometry.project(Geometry.Contravariant12Axis(), vector, lg) + projected_tensor = + Geometry.project(Geometry.Contravariant12Axis(), tensor, lg) + + # Test all required combinations of single values. + test_rmul_with_projection(number, number, lg, number * number) + test_rmul_with_projection(number, covector, lg, number * covector) + test_rmul_with_projection(number, vector, lg, number * vector) + test_rmul_with_projection(number, tensor, lg, number * tensor) + test_rmul_with_projection(covector, number, lg, covector * number) + test_rmul_with_projection(vector, number, lg, vector * number) + test_rmul_with_projection(tensor, number, lg, tensor * number) + test_rmul_with_projection(covector, vector, lg, covector * projected_vector) + test_rmul_with_projection(covector, tensor, lg, covector * projected_tensor) + + # Test some combinations of complicated nested values. 
+ nested_type(value1, value2, value3) = + (; a = (), b = value1, c = (value2, value3, (;))) + test_rmul_with_projection( + number, + nested_type(covector, vector, tensor), + lg, + nested_type(number * covector, number * vector, number * tensor), + ) + test_rmul_with_projection( + nested_type(covector, vector, tensor), + number, + lg, + nested_type(covector * number, vector * number, tensor * number), + ) + test_rmul_with_projection( + vector, + nested_type(number, number, number), + lg, + nested_type(vector * number, vector * number, vector * number), + ) + test_rmul_with_projection( + nested_type(number, number, number), + covector, + lg, + nested_type(number * covector, number * covector, number * covector), + ) + test_rmul_with_projection( + nested_type(number, vector, number), + nested_type(covector, number, tensor), + lg, + nested_type(number * covector, vector * number, number * tensor), + ) + test_rmul_with_projection( + nested_type(covector, number, tensor), + nested_type(number, vector, number), + lg, + nested_type(covector * number, number * vector, tensor * number), + ) + test_rmul_with_projection( + covector, + nested_type(vector, number, tensor), + lg, + nested_type( + covector * projected_vector, + covector * number, + covector * projected_tensor, + ), + ) + test_rmul_with_projection( + nested_type(covector, number, covector), + vector, + lg, + nested_type( + covector * projected_vector, + number * vector, + covector * projected_vector, + ), + ) + test_rmul_with_projection( + nested_type(covector, number, covector), + nested_type(number, vector, tensor), + lg, + nested_type( + covector * number, + number * vector, + covector * projected_tensor, + ), + ) +end diff --git a/test/runtests.jl b/test/runtests.jl index 58ddeb849d..8f6b09ee1a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -75,6 +75,9 @@ if !Sys.iswindows() @safetestset "Hybrid - dss opt" begin @time include("Operators/hybrid/dss_opt.jl") end @safetestset "Hybrid - opt" begin @time 
include("Operators/hybrid/opt.jl") end + @safetestset "MatrixFields - rmul_with_projection" begin @time include("MatrixFields/rmul_with_projection.jl") end + @safetestset "MatrixFields - matrix field broadcasting" begin @time include("MatrixFields/matrix_field_broadcasting.jl") end + @safetestset "Hypsography - 2d" begin @time include("Hypsography/2d.jl") end @safetestset "Hypsography - 3d sphere" begin @time include("Hypsography/3dsphere.jl") end