Skip to content

Commit

Permalink
Formatting and changes from #1334
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisYatunin committed Jun 16, 2023
1 parent a4c796a commit b1f992c
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 158 deletions.
20 changes: 1 addition & 19 deletions src/MatrixFields/MatrixFields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,7 @@ import ..RecursiveApply: rmap, rmaptype, rzero, radd, rsub, rmul, rdiv
import ..Geometry
import ..Spaces
import ..Fields
import ..Operators:
FiniteDifferenceOperator,
AbstractBoundaryCondition,
LeftBoundaryWindow,
RightBoundaryWindow,
has_boundary,
get_boundary,
stencil_interior_width,
left_interior_idx,
right_interior_idx,
left_idx,
right_idx,
return_eltype,
return_space,
stencil_interior,
stencil_left_boundary,
stencil_right_boundary,
reconstruct_placeholder_space,
getidx
import ..Operators

export
export DiagonalMatrixRow,
Expand Down
131 changes: 71 additions & 60 deletions src/MatrixFields/matrix_multiplication.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ An operator that multiplies a columnwise band matrix field (a field of
i.e., matrix-vector or matrix-matrix multiplication. The `⋅` symbol is an alias
for `MultiplyColumnwiseBandMatrixField()`.
"""
struct MultiplyColumnwiseBandMatrixField <: FiniteDifferenceOperator end
struct MultiplyColumnwiseBandMatrixField <: Operators.FiniteDifferenceOperator end
const = MultiplyColumnwiseBandMatrixField()

#=
Expand Down Expand Up @@ -119,79 +119,86 @@ always true. Rearranging the remaining two inequalities gives us
ld1 + ld2 ≤ prod_d ≤ ud1 + ud2.
=#

struct TopLeftMatrixCorner <: AbstractBoundaryCondition end
struct BottomRightMatrixCorner <: AbstractBoundaryCondition end
struct TopLeftMatrixCorner <: Operators.AbstractBoundaryCondition end
struct BottomRightMatrixCorner <: Operators.AbstractBoundaryCondition end

has_boundary(
Operators.has_boundary(
::MultiplyColumnwiseBandMatrixField,
::LeftBoundaryWindow{name},
::Operators.LeftBoundaryWindow{name},
) where {name} = true
has_boundary(
Operators.has_boundary(
::MultiplyColumnwiseBandMatrixField,
::RightBoundaryWindow{name},
::Operators.RightBoundaryWindow{name},
) where {name} = true

get_boundary(
Operators.get_boundary(
::MultiplyColumnwiseBandMatrixField,
::LeftBoundaryWindow{name},
::Operators.LeftBoundaryWindow{name},
) where {name} = TopLeftMatrixCorner()
get_boundary(
Operators.get_boundary(
::MultiplyColumnwiseBandMatrixField,
::RightBoundaryWindow{name},
::Operators.RightBoundaryWindow{name},
) where {name} = BottomRightMatrixCorner()

stencil_interior_width(::MultiplyColumnwiseBandMatrixField, matrix1, arg) =
((0, 0), outer_diagonals(eltype(matrix1)))
Operators.stencil_interior_width(
::MultiplyColumnwiseBandMatrixField,
matrix1,
arg,
) = ((0, 0), outer_diagonals(eltype(matrix1)))

_left_interior_idx(space, lbw::Integer) = left_idx(space) - lbw
_left_interior_idx(space, lbw::Integer) = Operators.left_idx(space) - lbw
_left_interior_idx(
space::Union{
Spaces.FaceFiniteDifferenceSpace,
Spaces.FaceExtrudedFiniteDifferenceSpace,
Spaces.CenterFiniteDifferenceSpace,
Spaces.CenterExtrudedFiniteDifferenceSpace,
},
lbw::PlusHalf,
) = left_idx(space) - lbw + half
) = Operators.left_idx(space) - lbw - half
_left_interior_idx(
space::Union{
Spaces.CenterFiniteDifferenceSpace,
Spaces.CenterExtrudedFiniteDifferenceSpace,
Spaces.FaceFiniteDifferenceSpace,
Spaces.FaceExtrudedFiniteDifferenceSpace,
},
lbw::PlusHalf,
) = left_idx(space) - lbw - half
) = Operators.left_idx(space) - lbw + half

left_interior_idx(
Operators.left_interior_idx(
space::Spaces.AbstractSpace,
::MultiplyColumnwiseBandMatrixField,
::TopLeftMatrixCorner,
matrix1,
arg,
) = _left_interior_idx(space, outer_diagonals(eltype(matrix1))[1])

_right_interior_idx(space, rbw::Integer) = right_idx(space) - rbw
_right_interior_idx(space, rbw::Integer) = Operators.right_idx(space) - rbw
_right_interior_idx(
space::Union{
Spaces.FaceFiniteDifferenceSpace,
Spaces.FaceExtrudedFiniteDifferenceSpace,
Spaces.CenterFiniteDifferenceSpace,
Spaces.CenterExtrudedFiniteDifferenceSpace,
},
rbw::PlusHalf,
) = right_idx(space) - rbw - half
) = Operators.right_idx(space) - rbw + half
_right_interior_idx(
space::Union{
Spaces.CenterFiniteDifferenceSpace,
Spaces.CenterExtrudedFiniteDifferenceSpace,
Spaces.FaceFiniteDifferenceSpace,
Spaces.FaceExtrudedFiniteDifferenceSpace,
},
rbw::PlusHalf,
) = right_idx(space) - rbw + half
) = Operators.right_idx(space) - rbw - half

right_interior_idx(
Operators.right_interior_idx(
space::Spaces.AbstractSpace,
::MultiplyColumnwiseBandMatrixField,
::BottomRightMatrixCorner,
matrix1,
arg,
) = _right_interior_idx(space, outer_diagonals(eltype(matrix1))[2])

function return_eltype(::MultiplyColumnwiseBandMatrixField, matrix1, arg)
function Operators.return_eltype(
::MultiplyColumnwiseBandMatrixField,
matrix1,
arg,
)
eltype(matrix1) <: BandMatrixRow ||
error("The first argument of ⋅ must be a band matrix field, but the \
given argument is a field with element type $(eltype(matrix1))")
Expand All @@ -200,11 +207,11 @@ function return_eltype(::MultiplyColumnwiseBandMatrixField, matrix1, arg)
ld1, ud1 = outer_diagonals(eltype(matrix1))
ld2, ud2 = outer_diagonals(eltype(matrix2))
prod_ld, prod_ud = ld1 + ld2, ud1 + ud2
prod_eltype = rmul_with_projection_return_type(
prod_entry_type = rmul_with_projection_return_type(
eltype(eltype(matrix1)),
eltype(eltype(matrix2)),
)
return band_matrix_row_type(prod_ld, prod_ud, prod_eltype)
return band_matrix_row_type(prod_ld, prod_ud, prod_entry_type)
else # matrix-vector multiplication
vector = arg
return rmul_with_projection_return_type(
Expand All @@ -214,7 +221,8 @@ function return_eltype(::MultiplyColumnwiseBandMatrixField, matrix1, arg)
end
end

return_space(::MultiplyColumnwiseBandMatrixField, space1, space2) = space1
Operators.return_space(::MultiplyColumnwiseBandMatrixField, space1, space2) =
space1

# TODO: Use @propagate_inbounds here, and remove @inbounds from this function.
# As of Julia 1.8, doing this increases compilation by more than an order of
Expand All @@ -233,30 +241,29 @@ function multiply_columnwise_band_matrix_at_index(
boundary_ld1 = nothing,
boundary_ud1 = nothing,
)
matrix1_row = getidx(space, matrix1, loc, idx, hidx)
matrix1_row = Operators.getidx(space, matrix1, loc, idx, hidx)
ld1, ud1 = outer_diagonals(eltype(matrix1))
ld1_or_boundary_ld1 = isnothing(boundary_ld1) ? ld1 : boundary_ld1
ud1_or_boundary_ud1 = isnothing(boundary_ud1) ? ud1 : boundary_ud1
return_type = return_eltype(, matrix1, arg)
prod_type = Operators.return_eltype(, matrix1, arg)
if eltype(arg) <: BandMatrixRow # matrix-matrix multiplication
matrix2 = arg
matrix2_rows = BandMatrixRow{ld1}(
map((ld1:ud1...,)) do d
# TODO: Use @propagate_inbounds_meta instead of @inline_meta.
Base.@_inline_meta
if (
(isnothing(boundary_ld1) || d >= boundary_ld1) &&
(isnothing(boundary_ud1) || d <= boundary_ud1)
)
return @inbounds getidx(space, matrix2, loc, idx + d, hidx)
else
return rzero(eltype(matrix2)) # This value is never used.
end
end...,
) # The rows are precomputed to avoid recomputing them multiple times.
matrix2_rows = map((ld1:ud1...,)) do d
# TODO: Use @propagate_inbounds_meta instead of @inline_meta.
Base.@_inline_meta
if (
(isnothing(boundary_ld1) || d >= boundary_ld1) &&
(isnothing(boundary_ud1) || d <= boundary_ud1)
)
@inbounds Operators.getidx(space, matrix2, loc, idx + d, hidx)
else
rzero(eltype(matrix2)) # This value is never used.
end
end # The rows are precomputed to avoid recomputing them multiple times.
matrix2_rows_wrapper = BandMatrixRow{ld1}(matrix2_rows...)
ld2, ud2 = outer_diagonals(eltype(matrix2))
prod_ld, prod_ud = outer_diagonals(return_type)
zero_value = rzero(eltype(return_type))
prod_ld, prod_ud = outer_diagonals(prod_type)
zero_value = rzero(eltype(prod_type))
prod_entries = map((prod_ld:prod_ud...,)) do prod_d
# TODO: Use @propagate_inbounds_meta instead of @inline_meta.
Base.@_inline_meta
Expand All @@ -272,22 +279,22 @@ function multiply_columnwise_band_matrix_at_index(
prod_entry = zero_value
@inbounds for d in min_d:max_d
value1 = matrix1_row[d]
value2 = matrix2_rows[d][prod_d - d]
value2 = matrix2_rows_wrapper[d][prod_d - d]
value2_lg = Geometry.LocalGeometry(space, idx + d, hidx)
prod_entry = radd(
prod_entry,
rmul_with_projection(value1, value2, value2_lg),
)
end # Using this for-loop is currently faster than using mapreduce.
return prod_entry
prod_entry
end
return BandMatrixRow{prod_ld}(prod_entries...)
else # matrix-vector multiplication
vector = arg
prod_value = rzero(return_type)
prod_value = rzero(prod_type)
@inbounds for d in ld1_or_boundary_ld1:ud1_or_boundary_ud1
value1 = matrix1_row[d]
value2 = getidx(space, vector, loc, idx + d, hidx)
value2 = Operators.getidx(space, vector, loc, idx + d, hidx)
value2_lg = Geometry.LocalGeometry(space, idx + d, hidx)
prod_value = radd(
prod_value,
Expand All @@ -298,7 +305,7 @@ function multiply_columnwise_band_matrix_at_index(
end
end

Base.@propagate_inbounds stencil_interior(
Base.@propagate_inbounds Operators.stencil_interior(
::MultiplyColumnwiseBandMatrixField,
loc,
space,
Expand All @@ -315,7 +322,7 @@ Base.@propagate_inbounds stencil_interior(
arg,
)

Base.@propagate_inbounds stencil_left_boundary(
Base.@propagate_inbounds Operators.stencil_left_boundary(
::MultiplyColumnwiseBandMatrixField,
::TopLeftMatrixCorner,
loc,
Expand All @@ -331,11 +338,13 @@ Base.@propagate_inbounds stencil_left_boundary(
hidx,
matrix1,
arg,
left_idx(reconstruct_placeholder_space(axes(arg), space)) - idx,
Operators.left_idx(
Operators.reconstruct_placeholder_space(axes(arg), space),
) - idx,
nothing,
)

Base.@propagate_inbounds stencil_right_boundary(
Base.@propagate_inbounds Operators.stencil_right_boundary(
::MultiplyColumnwiseBandMatrixField,
::BottomRightMatrixCorner,
loc,
Expand All @@ -352,7 +361,9 @@ Base.@propagate_inbounds stencil_right_boundary(
matrix1,
arg,
nothing,
right_idx(reconstruct_placeholder_space(axes(arg), space)) - idx,
Operators.right_idx(
Operators.reconstruct_placeholder_space(axes(arg), space),
) - idx,
)

# For complex matrix field broadcast expressions involving 4 or more matrices,
Expand Down
Loading

0 comments on commit b1f992c

Please sign in to comment.