Skip to content

Commit

Permalink
Add new MatrixFields module, along with unit tests and performance tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisYatunin committed Jun 15, 2023
1 parent 8e77c96 commit f15ba20
Show file tree
Hide file tree
Showing 14 changed files with 1,739 additions and 12 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.10.39"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
Expand Down
8 changes: 7 additions & 1 deletion docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ git-tree-sha1 = "dbf84058d0a8cbbadee18d25cf606934b22d7c66"
uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
version = "0.4.2"

[[deps.BandedMatrices]]
deps = ["ArrayLayouts", "FillArrays", "LinearAlgebra", "PrecompileTools", "SparseArrays"]
git-tree-sha1 = "9ad46355045491b12eab409dee73e9de46293aa2"
uuid = "aae01518-5342-5314-be14-df237901396f"
version = "0.17.28"

[[deps.Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

Expand Down Expand Up @@ -216,7 +222,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.4.2"

[[deps.ClimaCore]]
deps = ["Adapt", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DiffEqBase", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "Rotations", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DiffEqBase", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "Rotations", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"]
path = ".."
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
version = "0.10.39"
Expand Down
1 change: 1 addition & 0 deletions src/ClimaCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ include("Topologies/Topologies.jl")
include("Spaces/Spaces.jl")
include("Fields/Fields.jl")
include("Operators/Operators.jl")
include("MatrixFields/MatrixFields.jl")
include("Hypsography/Hypsography.jl")
include("Limiters/Limiters.jl")
include("InputOutput/InputOutput.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/Fields/mapreduce.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Base.map(fn, field::Field) = Base.broadcast(fn, field)
Base.map(fn, fields::Field...) = Base.broadcast(fn, fields...)

"""
Fields.local_sum(v::Field)
Expand Down
23 changes: 21 additions & 2 deletions src/Geometry/axistensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,10 +269,30 @@ Base.propertynames(x::AxisVector) = symbols(axes(x, 1))
end
end

const AdjointAxisTensor{T, N, A, S} = Adjoint{T, AxisTensor{T, N, A, S}}

Base.show(io::IO, a::AdjointAxisTensor{T, N, A, S}) where {T, N, A, S} =
print(io, "adjoint($(a'))")

components(a::AdjointAxisTensor) = components(parent(a))'

Base.zero(a::AdjointAxisTensor) = zero(a')'
Base.zero(::Type{AdjointAxisTensor{T, N, A, S}}) where {T, N, A, S} =
zero(AxisTensor{T, N, A, S})'

@inline +(a::AdjointAxisTensor) = (+a')'
@inline -(a::AdjointAxisTensor) = (-a')'
@inline +(a::AdjointAxisTensor, b::AdjointAxisTensor) = (a' + b')'
@inline -(a::AdjointAxisTensor, b::AdjointAxisTensor) = (a' - b')'
@inline *(a::Number, b::AdjointAxisTensor) = (a * b')'
@inline *(a::AdjointAxisTensor, b::Number) = (a' * b)'
@inline /(a::AdjointAxisTensor, b::Number) = (a' / b)'
@inline \(a::Number, b::AdjointAxisTensor) = (a' \ b)'

@inline (==)(a::AdjointAxisTensor, b::AdjointAxisTensor) = a' == b'

const AdjointAxisVector{T, A1, S} = Adjoint{T, AxisVector{T, A1, S}}

components(va::AdjointAxisVector) = components(parent(va))'
Base.@propagate_inbounds Base.getindex(va::AdjointAxisVector, i::Int) =
getindex(components(va), i)
Base.@propagate_inbounds Base.getindex(va::AdjointAxisVector, i::Int, j::Int) =
Expand All @@ -286,7 +306,6 @@ Axis2Tensor(
) = AxisTensor(axes, components)

const AdjointAxis2Tensor{T, A, S} = Adjoint{T, Axis2Tensor{T, A, S}}
components(va::AdjointAxis2Tensor) = components(parent(va))'

const Axis2TensorOrAdj{T, A, S} =
Union{Axis2Tensor{T, A, S}, AdjointAxis2Tensor{T, A, S}}
Expand Down
35 changes: 35 additions & 0 deletions src/MatrixFields/MatrixFields.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
module MatrixFields

import LinearAlgebra: UniformScaling, Adjoint
import StaticArrays: SArray, SMatrix, SVector
import BandedMatrices: BandedMatrix, band
import ..Utilities: PlusHalf, half
import ..RecursiveApply: rmap, rmaptype, rzero, radd, rsub, rmul, rdiv
import ..Geometry
import ..Spaces
import ..Fields
import ..Operators:
FiniteDifferenceOperator,
BoundaryCondition,
LeftBoundaryWindow,
RightBoundaryWindow,
has_boundary,
get_boundary,
stencil_interior_width,
boundary_width,
return_eltype,
return_space,
stencil_interior,
stencil_left_boundary,
stencil_right_boundary,
left_idx,
right_idx,
getidx,
getidx_args

include("band_matrix_row.jl")
include("rmul_with_projection.jl")
include("matrix_multiplication.jl")
include("matrix_field_utils.jl")

end
153 changes: 153 additions & 0 deletions src/MatrixFields/band_matrix_row.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""
BandMatrixRow{ld}(entries...)
Stores the nonzero entries in a row of a band matrix, starting with the lowest
diagonal, which has index `ld`. Supported operations include accessing the entry
on the diagonal with index `d` by calling `row[d]`, taking linear combinations
with other band matrix rows, and checking for equality with other band matrix
rows. There are several aliases defined for commonly-used subtypes of
`BandMatrixRow` (with `T` denoting the type of the row's entries):
- `DiagonalMatrixRow{T}`
- `BidiagonalMatrixRow{T}`
- `TridiagonalMatrixRow{T}`
- `QuaddiagonalMatrixRow{T}`
- `PentadiagonalMatrixRow{T}`
"""
struct BandMatrixRow{ld, bw, T} # bw is the bandwidth (the number of diagonals)
entries::NTuple{bw, T}
end
BandMatrixRow{ld}(entries::Vararg{Any, bw}) where {ld, bw} =
BandMatrixRow{ld, bw}(entries...)
function BandMatrixRow{ld, bw}(entries::Vararg{Any, bw}) where {ld, bw}
promoted_entries = promote(entries...)
return BandMatrixRow{ld, bw, eltype(promoted_entries)}(promoted_entries)
end

const DiagonalMatrixRow{T} = BandMatrixRow{0, 1, T}
const BidiagonalMatrixRow{T} = BandMatrixRow{-1 + half, 2, T}
const TridiagonalMatrixRow{T} = BandMatrixRow{-1, 3, T}
const QuaddiagonalMatrixRow{T} = BandMatrixRow{-2 + half, 4, T}
const PentadiagonalMatrixRow{T} = BandMatrixRow{-2, 5, T}

function Base.show(
io::IO,
::Type{BMR},
) where {ld, bw, T, BMR <: BandMatrixRow{ld, bw, T}}
string = if BMR <: DiagonalMatrixRow
"DiagonalMatrixRow{$T}"
elseif BMR <: BidiagonalMatrixRow
"BidiagonalMatrixRow{$T}"
elseif BMR <: TridiagonalMatrixRow
"TridiagonalMatrixRow{$T}"
elseif BMR <: QuaddiagonalMatrixRow
"QuaddiagonalMatrixRow{$T}"
elseif BMR <: PentadiagonalMatrixRow
"PentadiagonalMatrixRow{$T}"
else
"BandMatrixRow{$ld, $bw, $T}"
end
print(io, string)
end

"""
outer_diagonals(::Type{<:BandMatrixRow})
Gets the indices of the lower and upper diagonals, `ld` and `ud`, of the given
subtype of `BandMatrixRow`.
"""
outer_diagonals(::Type{<:BandMatrixRow{ld, bw}}) where {ld, bw} =
(ld, ld + bw - 1)

"""
band_matrix_row_type(ld, ud, T)
A shorthand for getting the subtype of `BandMatrixRow` that has entries of type
`T` on the diagonals with indices in the range `ld:ud`.
"""
band_matrix_row_type(ld, ud, T) = BandMatrixRow{ld, ud - ld + 1, T}

Base.eltype(::Type{BandMatrixRow{ld, bw, T}}) where {ld, bw, T} = T

Base.zero(::Type{BandMatrixRow{ld, bw, T}}) where {ld, bw, T} =
BandMatrixRow{ld}(ntuple(_ -> rzero(T), Val(bw))...)

Base.@propagate_inbounds Base.getindex(row::BandMatrixRow{ld}, d) where {ld} =
row.entries[d - ld + 1]

function Base.promote_rule(
::Type{BMR1},
::Type{BMR2},
) where {BMR1 <: BandMatrixRow, BMR2 <: BandMatrixRow}
ld1, ud1 = outer_diagonals(BMR1)
ld2, ud2 = outer_diagonals(BMR2)
typeof(ld1) == typeof(ld2) || error(
"Cannot promote the $(ld1 isa PlusHalf ? "non-" : "")square matrix \
row type $BMR1 and the $(ld2 isa PlusHalf ? "non-" : "")square \
matrix row type $BMR2 to a common type",
)
T = promote_type(eltype(BMR1), eltype(BMR2))
return band_matrix_row_type(min(ld1, ld2), max(ud1, ud2), T)
end

Base.promote_rule(
::Type{BMR},
::Type{US},
) where {BMR <: BandMatrixRow, US <: UniformScaling} =
promote_rule(BMR, DiagonalMatrixRow{eltype(US)})

function Base.convert(
::Type{BMR},
row::BandMatrixRow,
) where {BMR <: BandMatrixRow}
old_ld, old_ud = outer_diagonals(typeof(row))
new_ld, new_ud = outer_diagonals(BMR)
typeof(old_ld) == typeof(new_ld) ||
error("Cannot convert a $(old_ld isa PlusHalf ? "non-" : "")square \
matrix row of type $(typeof(row)) to the \
$(new_ld isa PlusHalf ? "non-" : "")square matrix row type $BMR")
new_ld <= old_ld && new_ud >= old_ud ||
error("Cannot convert a $(typeof(row)) to a $BMR, since that would \
require dropping potentially non-zero row entries")
first_zeros = ntuple(_ -> rzero(eltype(BMR)), Val(old_ld - new_ld))
entries = map(entry -> convert(eltype(BMR), entry), row.entries)
last_zeros = ntuple(_ -> rzero(eltype(BMR)), Val(new_ud - old_ud))
return BandMatrixRow{new_ld}(first_zeros..., entries..., last_zeros...)
end

Base.convert(::Type{BMR}, row::UniformScaling) where {BMR <: BandMatrixRow} =
convert(BMR, DiagonalMatrixRow(row.λ))

Base.:(==)(row1::BMR, row2::BMR) where {BMR <: BandMatrixRow} =
row1.entries == row2.entries
Base.:(==)(row1::BandMatrixRow, row2::BandMatrixRow) =
==(promote(row1, row2)...)
Base.:(==)(row1::BandMatrixRow, row2::UniformScaling) =
==(promote(row1, row2)...)
Base.:(==)(row1::UniformScaling, row2::BandMatrixRow) =
==(promote(row1, row2)...)

Base.map(f::F, rows::BandMatrixRow{ld}...) where {F, ld} =
BandMatrixRow{ld}(map(f, map(row -> row.entries, rows)...)...)

# Define all necessary operations for computing linear combinations. Use
# methods from RecursiveApply to handle nested types.
Base.:+(row::BandMatrixRow) = map(radd, row)
Base.:+(row1::BandMatrixRow, row2::BandMatrixRow) =
map(radd, promote(row1, row2)...)
Base.:+(row1::BandMatrixRow, row2::UniformScaling) =
map(radd, promote(row1, row2)...)
Base.:+(row1::UniformScaling, row2::BandMatrixRow) =
map(radd, promote(row1, row2)...)
Base.:-(row::BandMatrixRow) = map(rsub, row)
Base.:-(row1::BandMatrixRow, row2::BandMatrixRow) =
map(rsub, promote(row1, row2)...)
Base.:-(row1::BandMatrixRow, row2::UniformScaling) =
map(rsub, promote(row1, row2)...)
Base.:-(row1::UniformScaling, row2::BandMatrixRow) =
map(rsub, promote(row1, row2)...)
Base.:*(row::BandMatrixRow, value::Number) =
map(entry -> rmul(entry, value), row)
Base.:*(value::Number, row::BandMatrixRow) =
map(entry -> rmul(value, entry), row)
Base.:/(row::BandMatrixRow, value::Number) =
map(entry -> rdiv(entry, value), row)
85 changes: 85 additions & 0 deletions src/MatrixFields/matrix_field_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
column_field2array(field)
Converts a field defined on a `FiniteDifferenceSpace` into a `Vector` or a
`BandedMatrix`, depending on whether or not the elements of the field are
`BandMatrixRow`s. This involves copying the data stored in the field.
"""
function column_field2array(field)
space = axes(field)
space isa Spaces.FiniteDifferenceSpace ||
error("column_field2array requires a field on a FiniteDifferenceSpace")
n_rows = Spaces.nlevels(space)
if eltype(field) <: BandMatrixRow # field represents a matrix
field_ld, field_ud = outer_diagonals(eltype(field))

# Find the amount by which the field's diagonal indices get shifted when
# it is converted into a matrix, as well as the diagonal index of the
# value that ends up in the bottom-right corner of the matrix.
matrix_d_minus_field_d, bottom_corner_matrix_d =
if field_ld isa PlusHalf
if axes(field).staggering isa Spaces.CellCenter
half, 1 # field is a face-to-center matrix
else
-half, -1 # field is a center-to-face matrix
end
else
0, 0 # field is either a center-to-center or face-to-face matrix
end

matrix_ld = field_ld + matrix_d_minus_field_d
matrix_ud = field_ud + matrix_d_minus_field_d
matrix_ld <= 0 && matrix_ud >= 0 ||
error("BandedMatrices.jl does not yet support matrices that have \
diagonals with indices in the range $matrix_ld:$matrix_ud")

n_cols = n_rows + bottom_corner_matrix_d
matrix = BandedMatrix{eltype(eltype(field))}(
undef,
(n_rows, n_cols),
(-matrix_ld, matrix_ud),
)
for (index_of_field_entry, matrix_d) in enumerate(matrix_ld:matrix_ud)
# Find the rows for which field_diagonal[row] is inside the matrix.
# Note: The matrix index (1, 1) corresponds to the diagonal index 0,
# and the matrix index (n_rows, n_cols) corresponds to the diagonal
# index bottom_corner_matrix_d.
first_row = matrix_d < 0 ? 1 - matrix_d : 1
last_row =
matrix_d < bottom_corner_matrix_d ? n_rows : n_cols - matrix_d

# Copy the value in each row from field_diagonal to matrix_diagonal.
field_diagonal = field.entries.:($index_of_field_entry)
matrix_diagonal = view(matrix, band(matrix_d))
for (index_along_diagonal, row) in enumerate(first_row:last_row)
matrix_diagonal[index_along_diagonal] =
Fields.field_values(field_diagonal)[row]
end
end
return matrix
else # field represents a vector
return map(i -> Fields.field_values(field)[i], 1:n_rows)
end
end

"""
field2arrays(field)
Converts a field defined on a `FiniteDifferenceSpace` or on an
`ExtrudedFiniteDifferenceSpace` into a tuple of arrays, each of which
corresponds to a column of the field. This is done by calling
`column_field2array` on each of the field's columns.
"""
function field2arrays(field)
space = axes(field)
column_indices = if space isa Spaces.FiniteDifferenceSpace
(((1, 1), 1),)
elseif space isa Spaces.ExtrudedFiniteDifferenceSpace
(Spaces.all_nodes(Spaces.horizontal_space(space))...,)
else
error("Invalid space type: $(typeof(space).name.wrapper)")
end
return map(column_indices) do ((i, j), h)
return column_field2array(Spaces.column(field, i, j, h))
end
end
Loading

0 comments on commit f15ba20

Please sign in to comment.