Skip to content

Commit

Permalink
Add MultiBroadcastFusion implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Mar 12, 2024
1 parent 8b73a3d commit 56809bf
Show file tree
Hide file tree
Showing 7 changed files with 434 additions and 1 deletion.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ IntervalSets = "0.5, 0.6, 0.7"
Krylov = "0.9"
KrylovKit = "0.6"
LinearAlgebra = "1"
MultiBroadcastFusion = "0.1"
PkgVersion = "0.1, 0.2, 0.3"
RecursiveArrayTools = "2, 3"
RootSolvers = "0.3, 0.4"
Expand Down
1 change: 1 addition & 0 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ module DataLayouts
import Base: Base, @propagate_inbounds
import StaticArrays: SOneTo, MArray, SArray
import ClimaComms
import MultiBroadcastFusion as MBF

import ..slab, ..slab_args, ..column, ..column_args, ..level
export slab, column, level, IJFH, IJF, IFH, IF, VF, VIJFH, VIFH, DataF
Expand Down
102 changes: 102 additions & 0 deletions src/DataLayouts/broadcast.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import MultiBroadcastFusion as MBF

# Make a MultiBroadcastFusion type, `FusedMultiBroadcast`, and macro, `@fused`:
# via https://github.com/CliMA/MultiBroadcastFusion.jl
MBF.@make_fused FusedMultiBroadcast fused

# Broadcasting of AbstractData objects
# https://docs.julialang.org/en/v1/manual/interfaces/#Broadcast-Styles

Expand Down Expand Up @@ -587,3 +593,99 @@ function Base.copyto!(
) where {S, Nij, A}
return _serial_copyto!(dest, bc)
end

# ============= FusedMultiBroadcast

isascalar(
bc::Base.Broadcast.Broadcasted{Style},
) where {
Style <:
Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}},
} = true
isascalar(bc) = false
function get_bc(bc, I)
if isascalar(bc)
# bc′ = Base.Broadcast.instantiate(
# Base.Broadcast.Broadcasted(bc.style, bc.f, bc.args, ()),
# )
# bcI = bc′[]
bcI = bc[]
else
bcI = bc[I]
end
return bcI
end

function Base.copyto!(
fmbc::FusedMultiBroadcast{T},
) where {N, T <: NTuple{N, Pair{<:AbstractData, <:Any}}}
dest1 = first(fmbc.pairs).first
# check_fused_broadcast_axes(fmbc) # we should already have checked the axes
Base.copyto!(fmbc, dest1)
end

function Base.copyto!(
fmbc::FusedMultiBroadcast,
dest1::VIJFH{S1, Nij, A},
) where {S1, Nij, A}
_, _, _, Nv, Nh = size(dest1)
for (dest, bc) in fmbc.pairs
# Base.copyto!(dest, bc) # we can just fall back like this
@inbounds for h in 1:Nh, j in 1:Nij, i in 1:Nij, v in 1:Nv
I = CartesianIndex(i, j, 1, v, h)
bcI = get_bc(bc, I)
dest[I] = convert(eltype(dest), bcI)
end
end
return nothing
end

function Base.copyto!(
fmbc::FusedMultiBroadcast,
dest1::VIFH{S, Ni, A},
) where {S, Ni, A}
# copy contiguous columns
_, _, _, Nv, Nh = size(dest1)
for (dest, bc) in fmbc.pairs
@inbounds for h in 1:Nh, i in 1:Nij, v in 1:Nv
I = CartesianIndex(i, 1, v, h)
dest[I] = convert(eltype(dest), bc[I])
end
end
return nothing
end

function Base.copyto!(fmbc::FusedMultiBroadcast, dest1::VF{S1, A}) where {S1, A}
_, _, _, Nv, _ = size(dest1)
for (dest, bc) in fmbc.pairs
@inbounds for v in 1:Nv
I = CartesianIndex(1, 1, 1, v, 1)
dest[I] = convert(eltype(dest), bc[I])
end
end
return nothing
end

# we've already diagonalized dest, so we only need to make
# sure that all the broadcast axes are compatible.
# Logic here is similar to Base.Broadcast.instantiate
@inline function _check_fused_broadcast_axes(bc1, bc2)
axes = Base.Broadcast.combine_axes(bc1.args..., bc2.args...)
if !(axes isa Nothing)
Base.Broadcast.check_broadcast_axes(axes, bc1.args...)
Base.Broadcast.check_broadcast_axes(axes, bc2.args...)
end
end

@inline check_fused_broadcast_axes(fmbc::FusedMultiBroadcast) =
check_fused_broadcast_axes(
map(x -> x.second, fmbc.pairs),
first(fmbc.pairs).second,
)
@inline check_fused_broadcast_axes(bcs::Tuple{<:Any}, bc1) =
_check_fused_broadcast_axes(first(bcs), bc1)
@inline check_fused_broadcast_axes(bcs::Tuple{}, bc1) = nothing
@inline function check_fused_broadcast_axes(bcs::Tuple, bc1)
_check_fused_broadcast_axes(first(bcs), bc1)
check_fused_broadcast_axes(Base.tail(bcs), bc1)
end
73 changes: 73 additions & 0 deletions src/DataLayouts/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,76 @@ function Base.fill!(dest::DataF{S, A}, val) where {S, A <: CUDA.CuArray}
CUDA.@cuda threads = (1, 1) blocks = (1, 1) knl_fill!(dest, val)
return dest
end

Base.@propagate_inbounds function rcopyto_at!(
pair::Pair{<:AbstractData, <:Any},
I,
v,
)
dest, bc = pair.first, pair.second
if v <= size(dest, 4)
bcI = get_bc(bc, I)
dest[I] = bcI
end
return nothing
end
Base.@propagate_inbounds function rcopyto_at!(pairs::Tuple, I, v)
rcopyto_at!(first(pairs), I, v)
rcopyto_at!(Base.tail(pairs), I, v)
end
Base.@propagate_inbounds rcopyto_at!(pairs::Tuple{<:Any}, I, v) =
rcopyto_at!(first(pairs), I, v)
@inline rcopyto_at!(pairs::Tuple{}, I, v) = nothing

function knl_fused_copyto!(fmbc::FusedMultiBroadcast)

@inbounds begin
i = CUDA.threadIdx().x
j = CUDA.threadIdx().y

h = CUDA.blockIdx().x
v = CUDA.blockDim().z * (CUDA.blockIdx().y - 1) + CUDA.threadIdx().z
(; pairs) = fmbc
I = CartesianIndex((i, j, 1, v, h))
rcopyto_at!(pairs, I, v)
end
return nothing
end

function Base.copyto!(
fmbc::FusedMultiBroadcast,
dest1::VIJFH{S, Nij, A},
) where {S, Nij, A <: CUDA.CuArray}
_, _, _, Nv, Nh = size(dest1)
if Nv > 0 && Nh > 0
Nv_per_block = min(Nv, fld(256, Nij * Nij))
Nv_blocks = cld(Nv, Nv_per_block)
CUDA.@cuda always_inline = true threads = (Nij, Nij, Nv_per_block) blocks =
(Nh, Nv_blocks) knl_fused_copyto!(fmbc)
end
return nothing
end

adapt_f(to, f::F) where {F} = Adapt.adapt(to, f)
adapt_f(to, ::Type{F}) where {F} = (x...) -> F(x...)

function Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
fmbc::FusedMultiBroadcast,
)
FusedMultiBroadcast(
map(fmbc.pairs) do pair
dest = pair.first
bc = pair.second
Pair(
Adapt.adapt(to, dest),
Base.Broadcast.Broadcasted(
bc.style,
adapt_f(to, bc.f),
Adapt.adapt(to, bc.args),
Adapt.adapt(to, bc.axes),
),
)
end,
)
end
9 changes: 8 additions & 1 deletion src/Fields/Fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@ module Fields

import ClimaComms
import ..slab, ..slab_args, ..column, ..column_args, ..level
import ..DataLayouts: DataLayouts, AbstractData, DataStyle
import ..DataLayouts:
DataLayouts,
AbstractData,
DataStyle,
FusedMultiBroadcast,
@fused,
isascalar,
check_fused_broadcast_axes
import ..Domains
import ..Topologies
import ..Quadratures
Expand Down
40 changes: 40 additions & 0 deletions src/Fields/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,46 @@ end
return dest
end

# Fused multi-broadcast
function Base.copyto!(
fmbc::FusedMultiBroadcast{T},
) where {N, T <: NTuple{N, Pair{<:Field, <:Any}}}
fmb_data = FusedMultiBroadcast(
map(fmbc.pairs) do pair
bc = Base.Broadcast.instantiate(todata(pair.second))
bc′ = if isascalar(bc)
Base.Broadcast.instantiate(
Base.Broadcast.Broadcasted(bc.style, bc.f, bc.args, ()),
)
else
bc
end
Pair(field_values(pair.first), bc′)
end,
)
check_mismatched_spaces(fmbc)
check_fused_broadcast_axes(fmbc)
dest1 = first(fmb_data.pairs).first
Base.copyto!(fmb_data, dest1)
end

@inline check_mismatched_spaces(fmbc::FusedMultiBroadcast) =
check_mismatched_spaces(
map(x -> axes(x.first), fmbc.pairs),
axes(first(fmbc.pairs).first),
)
@inline check_mismatched_spaces(axs::Tuple{<:Any}, ax1) =
_check_mismatched_spaces(first(axs), ax1)
@inline check_mismatched_spaces(axs::Tuple{}, ax1) = nothing
@inline function check_mismatched_spaces(axs::Tuple, ax1)
_check_mismatched_spaces(first(axs), ax1)
check_mismatched_spaces(Base.tail(axs), ax1)
end

_check_mismatched_spaces(::T, ::T) where {T <: AbstractSpace} = nothing
_check_mismatched_spaces(space1, space2) =
error("FusedMultiBroadcast spaces are not the same.")

@noinline function error_mismatched_spaces(space1::Type, space2::Type)
error("Broacasted spaces are not the same.")
end
Expand Down
Loading

0 comments on commit 56809bf

Please sign in to comment.