From b1f992c05c91c9b4960a4d165b22a9b20366e91f Mon Sep 17 00:00:00 2001 From: Dennis Yatunin Date: Fri, 16 Jun 2023 14:20:04 -0700 Subject: [PATCH] Formatting and changes from #1334 --- src/MatrixFields/MatrixFields.jl | 20 +-- src/MatrixFields/matrix_multiplication.jl | 131 ++++++++++-------- src/RecursiveApply/RecursiveApply.jl | 112 +++++++-------- .../MatrixFields/matrix_field_broadcasting.jl | 21 +-- 4 files changed, 126 insertions(+), 158 deletions(-) diff --git a/src/MatrixFields/MatrixFields.jl b/src/MatrixFields/MatrixFields.jl index 05316a3c0a..49ad0bed2e 100644 --- a/src/MatrixFields/MatrixFields.jl +++ b/src/MatrixFields/MatrixFields.jl @@ -8,25 +8,7 @@ import ..RecursiveApply: rmap, rmaptype, rzero, radd, rsub, rmul, rdiv import ..Geometry import ..Spaces import ..Fields -import ..Operators: - FiniteDifferenceOperator, - AbstractBoundaryCondition, - LeftBoundaryWindow, - RightBoundaryWindow, - has_boundary, - get_boundary, - stencil_interior_width, - left_interior_idx, - right_interior_idx, - left_idx, - right_idx, - return_eltype, - return_space, - stencil_interior, - stencil_left_boundary, - stencil_right_boundary, - reconstruct_placeholder_space, - getidx +import ..Operators export ⋅ export DiagonalMatrixRow, diff --git a/src/MatrixFields/matrix_multiplication.jl b/src/MatrixFields/matrix_multiplication.jl index 0f5edf1527..f28a1c191c 100644 --- a/src/MatrixFields/matrix_multiplication.jl +++ b/src/MatrixFields/matrix_multiplication.jl @@ -6,7 +6,7 @@ An operator that multiplies a columnwise band matrix field (a field of i.e., matrix-vector or matrix-matrix multiplication. The `⋅` symbol is an alias for `MultiplyColumnwiseBandMatrixField()`. """ -struct MultiplyColumnwiseBandMatrixField <: FiniteDifferenceOperator end +struct MultiplyColumnwiseBandMatrixField <: Operators.FiniteDifferenceOperator end const ⋅ = MultiplyColumnwiseBandMatrixField() #= @@ -119,47 +119,50 @@ always true. Rearranging the remaining two inequalities gives us ld1 + ld2 ≤ prod_d ≤ ud1 + ud2. =# -struct TopLeftMatrixCorner <: AbstractBoundaryCondition end -struct BottomRightMatrixCorner <: AbstractBoundaryCondition end +struct TopLeftMatrixCorner <: Operators.AbstractBoundaryCondition end +struct BottomRightMatrixCorner <: Operators.AbstractBoundaryCondition end -has_boundary( +Operators.has_boundary( ::MultiplyColumnwiseBandMatrixField, - ::LeftBoundaryWindow{name}, + ::Operators.LeftBoundaryWindow{name}, ) where {name} = true -has_boundary( +Operators.has_boundary( ::MultiplyColumnwiseBandMatrixField, - ::RightBoundaryWindow{name}, + ::Operators.RightBoundaryWindow{name}, ) where {name} = true -get_boundary( +Operators.get_boundary( ::MultiplyColumnwiseBandMatrixField, - ::LeftBoundaryWindow{name}, + ::Operators.LeftBoundaryWindow{name}, ) where {name} = TopLeftMatrixCorner() -get_boundary( +Operators.get_boundary( ::MultiplyColumnwiseBandMatrixField, - ::RightBoundaryWindow{name}, + ::Operators.RightBoundaryWindow{name}, ) where {name} = BottomRightMatrixCorner() -stencil_interior_width(::MultiplyColumnwiseBandMatrixField, matrix1, arg) = - ((0, 0), outer_diagonals(eltype(matrix1))) +Operators.stencil_interior_width( + ::MultiplyColumnwiseBandMatrixField, + matrix1, + arg, +) = ((0, 0), outer_diagonals(eltype(matrix1))) -_left_interior_idx(space, lbw::Integer) = left_idx(space) - lbw +_left_interior_idx(space, lbw::Integer) = Operators.left_idx(space) - lbw _left_interior_idx( space::Union{ - Spaces.FaceFiniteDifferenceSpace, - Spaces.FaceExtrudedFiniteDifferenceSpace, + Spaces.CenterFiniteDifferenceSpace, + Spaces.CenterExtrudedFiniteDifferenceSpace, }, lbw::PlusHalf, -) = left_idx(space) - lbw + half +) = Operators.left_idx(space) - lbw - half _left_interior_idx( space::Union{ - Spaces.CenterFiniteDifferenceSpace, - Spaces.CenterExtrudedFiniteDifferenceSpace, + Spaces.FaceFiniteDifferenceSpace, + Spaces.FaceExtrudedFiniteDifferenceSpace, }, lbw::PlusHalf, -) = left_idx(space) - lbw - half +) = Operators.left_idx(space) - lbw + half -left_interior_idx( +Operators.left_interior_idx( space::Spaces.AbstractSpace, ::MultiplyColumnwiseBandMatrixField, ::TopLeftMatrixCorner, @@ -167,23 +170,23 @@ left_interior_idx( arg, ) = _left_interior_idx(space, outer_diagonals(eltype(matrix1))[1]) -_right_interior_idx(space, rbw::Integer) = right_idx(space) - rbw +_right_interior_idx(space, rbw::Integer) = Operators.right_idx(space) - rbw _right_interior_idx( space::Union{ - Spaces.FaceFiniteDifferenceSpace, - Spaces.FaceExtrudedFiniteDifferenceSpace, + Spaces.CenterFiniteDifferenceSpace, + Spaces.CenterExtrudedFiniteDifferenceSpace, }, rbw::PlusHalf, -) = right_idx(space) - rbw - half +) = Operators.right_idx(space) - rbw + half _right_interior_idx( space::Union{ - Spaces.CenterFiniteDifferenceSpace, - Spaces.CenterExtrudedFiniteDifferenceSpace, + Spaces.FaceFiniteDifferenceSpace, + Spaces.FaceExtrudedFiniteDifferenceSpace, }, rbw::PlusHalf, -) = right_idx(space) - rbw + half +) = Operators.right_idx(space) - rbw - half -right_interior_idx( +Operators.right_interior_idx( space::Spaces.AbstractSpace, ::MultiplyColumnwiseBandMatrixField, ::BottomRightMatrixCorner, @@ -191,7 +194,11 @@ right_interior_idx( arg, ) = _right_interior_idx(space, outer_diagonals(eltype(matrix1))[2]) -function return_eltype(::MultiplyColumnwiseBandMatrixField, matrix1, arg) +function Operators.return_eltype( + ::MultiplyColumnwiseBandMatrixField, + matrix1, + arg, +) eltype(matrix1) <: BandMatrixRow || error("The first argument of ⋅ must be a band matrix field, but the \ given argument is a field with element type $(eltype(matrix1))") @@ -200,11 +207,11 @@ function return_eltype(::MultiplyColumnwiseBandMatrixField, matrix1, arg) ld1, ud1 = outer_diagonals(eltype(matrix1)) ld2, ud2 = outer_diagonals(eltype(matrix2)) prod_ld, prod_ud = ld1 + ld2, ud1 + ud2 - prod_eltype = rmul_with_projection_return_type( + prod_entry_type = rmul_with_projection_return_type( eltype(eltype(matrix1)), eltype(eltype(matrix2)), ) - return band_matrix_row_type(prod_ld, prod_ud, prod_eltype) + return band_matrix_row_type(prod_ld, prod_ud, prod_entry_type) else # matrix-vector multiplication vector = arg return rmul_with_projection_return_type( @@ -214,7 +221,8 @@ function return_eltype(::MultiplyColumnwiseBandMatrixField, matrix1, arg) end end -return_space(::MultiplyColumnwiseBandMatrixField, space1, space2) = space1 +Operators.return_space(::MultiplyColumnwiseBandMatrixField, space1, space2) = + space1 # TODO: Use @propagate_inbounds here, and remove @inbounds from this function. # As of Julia 1.8, doing this increases compilation by more than an order of @@ -233,30 +241,29 @@ function multiply_columnwise_band_matrix_at_index( boundary_ld1 = nothing, boundary_ud1 = nothing, ) - matrix1_row = getidx(space, matrix1, loc, idx, hidx) + matrix1_row = Operators.getidx(space, matrix1, loc, idx, hidx) ld1, ud1 = outer_diagonals(eltype(matrix1)) ld1_or_boundary_ld1 = isnothing(boundary_ld1) ? ld1 : boundary_ld1 ud1_or_boundary_ud1 = isnothing(boundary_ud1) ? ud1 : boundary_ud1 - return_type = return_eltype(⋅, matrix1, arg) + prod_type = Operators.return_eltype(⋅, matrix1, arg) if eltype(arg) <: BandMatrixRow # matrix-matrix multiplication matrix2 = arg - matrix2_rows = BandMatrixRow{ld1}( - map((ld1:ud1...,)) do d - # TODO: Use @propagate_inbounds_meta instead of @inline_meta. - Base.@_inline_meta - if ( - (isnothing(boundary_ld1) || d >= boundary_ld1) && - (isnothing(boundary_ud1) || d <= boundary_ud1) - ) - return @inbounds getidx(space, matrix2, loc, idx + d, hidx) - else - return rzero(eltype(matrix2)) # This value is never used. - end - end..., - ) # The rows are precomputed to avoid recomputing them multiple times. + matrix2_rows = map((ld1:ud1...,)) do d + # TODO: Use @propagate_inbounds_meta instead of @inline_meta. + Base.@_inline_meta + if ( + (isnothing(boundary_ld1) || d >= boundary_ld1) && + (isnothing(boundary_ud1) || d <= boundary_ud1) + ) + @inbounds Operators.getidx(space, matrix2, loc, idx + d, hidx) + else + rzero(eltype(matrix2)) # This value is never used. + end + end # The rows are precomputed to avoid recomputing them multiple times. + matrix2_rows_wrapper = BandMatrixRow{ld1}(matrix2_rows...) ld2, ud2 = outer_diagonals(eltype(matrix2)) - prod_ld, prod_ud = outer_diagonals(return_type) - zero_value = rzero(eltype(return_type)) + prod_ld, prod_ud = outer_diagonals(prod_type) + zero_value = rzero(eltype(prod_type)) prod_entries = map((prod_ld:prod_ud...,)) do prod_d # TODO: Use @propagate_inbounds_meta instead of @inline_meta. Base.@_inline_meta @@ -272,22 +279,22 @@ function multiply_columnwise_band_matrix_at_index( prod_entry = zero_value @inbounds for d in min_d:max_d value1 = matrix1_row[d] - value2 = matrix2_rows[d][prod_d - d] + value2 = matrix2_rows_wrapper[d][prod_d - d] value2_lg = Geometry.LocalGeometry(space, idx + d, hidx) prod_entry = radd( prod_entry, rmul_with_projection(value1, value2, value2_lg), ) end # Using this for-loop is currently faster than using mapreduce. - return prod_entry + prod_entry end return BandMatrixRow{prod_ld}(prod_entries...) else # matrix-vector multiplication vector = arg - prod_value = rzero(return_type) + prod_value = rzero(prod_type) @inbounds for d in ld1_or_boundary_ld1:ud1_or_boundary_ud1 value1 = matrix1_row[d] - value2 = getidx(space, vector, loc, idx + d, hidx) + value2 = Operators.getidx(space, vector, loc, idx + d, hidx) value2_lg = Geometry.LocalGeometry(space, idx + d, hidx) prod_value = radd( prod_value, @@ -298,7 +305,7 @@ function multiply_columnwise_band_matrix_at_index( end end -Base.@propagate_inbounds stencil_interior( +Base.@propagate_inbounds Operators.stencil_interior( ::MultiplyColumnwiseBandMatrixField, loc, space, @@ -315,7 +322,7 @@ Base.@propagate_inbounds stencil_interior( arg, ) -Base.@propagate_inbounds stencil_left_boundary( +Base.@propagate_inbounds Operators.stencil_left_boundary( ::MultiplyColumnwiseBandMatrixField, ::TopLeftMatrixCorner, loc, @@ -331,11 +338,13 @@ Base.@propagate_inbounds stencil_left_boundary( hidx, matrix1, arg, - left_idx(reconstruct_placeholder_space(axes(arg), space)) - idx, + Operators.left_idx( + Operators.reconstruct_placeholder_space(axes(arg), space), + ) - idx, nothing, ) -Base.@propagate_inbounds stencil_right_boundary( +Base.@propagate_inbounds Operators.stencil_right_boundary( ::MultiplyColumnwiseBandMatrixField, ::BottomRightMatrixCorner, loc, @@ -352,7 +361,9 @@ Base.@propagate_inbounds stencil_right_boundary( matrix1, arg, nothing, - right_idx(reconstruct_placeholder_space(axes(arg), space)) - idx, + Operators.right_idx( + Operators.reconstruct_placeholder_space(axes(arg), space), + ) - idx, ) # For complex matrix field broadcast expressions involving 4 or more matrices, diff --git a/src/RecursiveApply/RecursiveApply.jl b/src/RecursiveApply/RecursiveApply.jl index 31add2b7d5..2309cfa0ce 100755 --- a/src/RecursiveApply/RecursiveApply.jl +++ b/src/RecursiveApply/RecursiveApply.jl @@ -11,48 +11,31 @@ export ⊞, ⊠, ⊟ # These functions need to be generated for type stability (since T.parameters is # a SimpleVector, the compiler cannot always infer its size and elements). -has_no_params(::Type{T}) where {T} = - if @generated - :($(isempty(T.parameters))) - else - isempty(T.parameters) - end -first_param(::Type{T}) where {T} = - if @generated - :($(first(T.parameters))) - else - first(T.parameters) - end -tail_params(::Type{T}) where {T} = - if @generated - :($(Tuple{Base.tail((T.parameters...,))...})) - else - Tuple{Base.tail((T.parameters...,))...} - end - -# This is a type-stable version of map(x -> rmap(fn, x), X) or -# map((x, y) -> rmap(fn, x, y), X, Y). -rmap_tuple(fn::F, X) where {F} = - isempty(X) ? () : (rmap(fn, first(X)), rmap_tuple(fn, Base.tail(X))...) -rmap_tuple(fn::F, X, Y) where {F} = - isempty(X) || isempty(Y) ? () : - ( - rmap(fn, first(X), first(Y)), - rmap_tuple(fn, Base.tail(X), Base.tail(Y))..., - ) - -# This is a type-stable version of map(T′ -> rfunc(fn, T′), T.parameters) or -# map((T1′, T2′) -> rfunc(fn, T1′, T2′), T1.parameters, T2.parameters), where -# rfunc can be either rmaptype or rmap_type2value. -rmap_Tuple(rfunc::R, fn::F, ::Type{T}) where {R, F, T} = - has_no_params(T) ? () : - (rfunc(fn, first_param(T)), rmap_Tuple(rfunc, fn, tail_params(T))...) -rmap_Tuple(rfunc::R, fn::F, ::Type{T1}, ::Type{T2}) where {R, F, T1, T2} = - has_no_params(T1) || has_no_params(T2) ? () : - ( - rfunc(fn, first_param(T1), first_param(T2)), - rmap_Tuple(rfunc, fn, tail_params(T1), tail_params(T2))..., - ) +@generated first_param(::Type{T}) where {T} = :($(first(T.parameters))) +@generated tail_params(::Type{T}) where {T} = + :($(Tuple{Base.tail((T.parameters...,))...})) + +# Applying `rmaptype` returns `Tuple{...}` for tuple +# types, which cannot follow the recursion pattern as +# it cannot be splatted, so we add a separate method, +# `rmaptype_Tuple`, for the part of the recursion. +rmaptype_Tuple(fn::F, ::Type{Tuple{}}) where {F} = () +rmaptype_Tuple(fn::F, ::Type{T}) where {F, E, T <: Tuple{E}} = + (rmaptype(fn, first_param(T)),) +rmaptype_Tuple(fn::F, ::Type{T}) where {F, T <: Tuple} = + (rmaptype(fn, first_param(T)), rmaptype_Tuple(fn, tail_params(T))...) + +rmaptype_Tuple(_, ::Type{Tuple{}}, ::Type{Tuple{}}) = () +rmaptype_Tuple(_, ::Type{Tuple{}}, ::Type{T}) where {T <: Tuple} = () +rmaptype_Tuple(_, ::Type{T}, ::Type{Tuple{}}) where {T <: Tuple} = () +rmaptype_Tuple( + fn::F, + ::Type{T1}, + ::Type{T2}, +) where {F, T1 <: Tuple, T2 <: Tuple} = ( + rmaptype(fn, first_param(T1), first_param(T2)), + rmaptype_Tuple(fn, tail_params(T1), tail_params(T2))..., +) """ rmap(fn, X...) @@ -60,11 +43,18 @@ rmap_Tuple(rfunc::R, fn::F, ::Type{T1}, ::Type{T2}) where {R, F, T1, T2} = Recursively apply `fn` to each element of `X` """ rmap(fn::F, X) where {F} = fn(X) -rmap(fn::F, X, Y) where {F} = fn(X, Y) -rmap(fn::F, X::Tuple) where {F} = rmap_tuple(fn, X) -rmap(fn::F, X::Tuple, Y::Tuple) where {F} = rmap_tuple(fn, X, Y) +rmap(fn::F, X::Tuple{}) where {F} = () +rmap(fn::F, X::Tuple) where {F} = + (rmap(fn, first(X)), rmap(fn, Base.tail(X))...) rmap(fn::F, X::NamedTuple{names}) where {F, names} = NamedTuple{names}(rmap(fn, Tuple(X))) + +rmap(fn::F, X, Y) where {F} = fn(X, Y) +rmap(fn::F, X::Tuple{}, Y::Tuple{}) where {F} = () +rmap(fn::F, X::Tuple{}, Y::Tuple) where {F} = () +rmap(fn::F, X::Tuple, Y::Tuple{}) where {F} = () +rmap(fn::F, X::Tuple, Y::Tuple) where {F} = + (rmap(fn, first(X), first(Y)), rmap(fn, Base.tail(X), Base.tail(Y))...) rmap(fn::F, X::NamedTuple{names}, Y::NamedTuple{names}) where {F, names} = NamedTuple{names}(rmap(fn, Tuple(X), Tuple(Y))) @@ -81,13 +71,14 @@ Recursively apply `fn` to each type parameter of the type `T`, or to each type parameter of the types `T1` and `T2`, where `fn` returns a type. """ rmaptype(fn::F, ::Type{T}) where {F, T} = fn(T) -rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1, T2} = fn(T1, T2) rmaptype(fn::F, ::Type{T}) where {F, T <: Tuple} = - Tuple{rmap_Tuple(rmaptype, fn, T)...} -rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1 <: Tuple, T2 <: Tuple} = - Tuple{rmap_Tuple(rmaptype, fn, T1, T2)...} + Tuple{rmaptype_Tuple(fn, T)...} rmaptype(fn::F, ::Type{T}) where {F, names, Tup, T <: NamedTuple{names, Tup}} = NamedTuple{names, rmaptype(fn, Tup)} + +rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1, T2} = fn(T1, T2) +rmaptype(fn::F, ::Type{T1}, ::Type{T2}) where {F, T1 <: Tuple, T2 <: Tuple} = + Tuple{rmaptype_Tuple(fn, T1, T2)...} rmaptype( fn::F, ::Type{T1}, @@ -101,27 +92,18 @@ rmaptype( T2 <: NamedTuple{names, Tup2}, } = NamedTuple{names, rmaptype(fn, Tup1, Tup2)} -""" - rmap_type2value(fn, T) - -Recursively apply `fn` to each type parameter of the type `T`, where `fn` -returns a value instead of a type. -""" -rmap_type2value(fn::F, ::Type{T}) where {F, T} = fn(T) -rmap_type2value(fn::F, ::Type{T}) where {F, T <: Tuple} = - rmap_Tuple(rmap_type2value, fn, T) -rmap_type2value( - fn::F, - ::Type{T}, -) where {F, names, Tup, T <: NamedTuple{names, Tup}} = - NamedTuple{names}(rmap_type2value(fn, Tup)) - """ rzero(T) Recursively compute the zero value of type `T`. """ -rzero(::Type{T}) where {T} = rmap_type2value(zero, T) +rzero(::Type{T}) where {T} = zero(T) +rzero(::Type{Tuple{}}) = () +rzero(::Type{T}) where {E, T <: Tuple{E}} = (rzero(E),) +rzero(::Type{T}) where {T <: Tuple} = + (rzero(first_param(T)), rzero(tail_params(T))...) +rzero(::Type{Tup}) where {names, T, Tup <: NamedTuple{names, T}} = + NamedTuple{names}(rzero(T)) """ rmul(X, Y) diff --git a/test/MatrixFields/matrix_field_broadcasting.jl b/test/MatrixFields/matrix_field_broadcasting.jl index f1572a7ed7..198a87af9e 100644 --- a/test/MatrixFields/matrix_field_broadcasting.jl +++ b/test/MatrixFields/matrix_field_broadcasting.jl @@ -4,16 +4,9 @@ using Random: seed! using LinearAlgebra: I, mul! using BandedMatrices: band +using ClimaCore.MatrixFields import ClimaCore: Geometry, Domains, Meshes, Topologies, Hypsography, Spaces, Fields -import ClimaCore.MatrixFields: - DiagonalMatrixRow, - BidiagonalMatrixRow, - TridiagonalMatrixRow, - QuaddiagonalMatrixRow, - field2arrays, - ⋅ -import ClimaCore.Utilities: PlusHalf, half import ClimaComms # Using @benchmark from BenchmarkTools is extremely slow; it appears to keep @@ -56,7 +49,7 @@ function test_func_against_array_reference(; temp_values, func!::F1, ref_func!::F2, - print_summary = false, + print_summary = true, ) where {F1, F2} @testset "$test_name" begin # Fill all output fields with NaNs for testing. @@ -65,9 +58,9 @@ function test_func_against_array_reference(; temp_value .*= NaN end - ref_result_arrays = field2arrays(result) - inputs_arrays = map(field2arrays, inputs) - temp_values_arrays = map(field2arrays, temp_values) + ref_result_arrays = MatrixFields.field2arrays(result) + inputs_arrays = map(MatrixFields.field2arrays, inputs) + temp_values_arrays = map(MatrixFields.field2arrays, temp_values) time, allocs = @benchmark call_func(func!, result, inputs) ref_time, ref_allocs = @benchmark call_array_func( @@ -78,7 +71,7 @@ function test_func_against_array_reference(; ) # Compute the maximum error as an integer multiple of machine epsilon. - result_arrays = field2arrays(result) + result_arrays = MatrixFields.field2arrays(result) maximum_error = maximum(zip(result_arrays, ref_result_arrays)) do (array, ref_array) maximum(zip(array, ref_array)) do (value, ref_value) @@ -122,7 +115,7 @@ function test_func_against_reference(; ref_inputs, func!::F1, ref_func!::F2, - print_summary = false, + print_summary = true, ) where {F1, F2} @testset "$test_name" begin result .*= NaN