Skip to content

Commit

Permalink
Add MultiBroadcastFusion implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed May 3, 2024
1 parent c038ea9 commit 2ff71eb
Show file tree
Hide file tree
Showing 7 changed files with 557 additions and 1 deletion.
86 changes: 86 additions & 0 deletions ext/cuda/data_layouts.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@

import ClimaCore.DataLayouts: AbstractData
import ClimaCore.DataLayouts: FusedMultiBroadcast
import ClimaCore.DataLayouts: IJKFVH, IJFH, VIJFH, VIFH, IFH, IJF, IF, VF, DataF
import ClimaCore.DataLayouts: IJFHStyle, VIJFHStyle, VFStyle, DataFStyle
import ClimaCore.DataLayouts: promote_parent_array_type
import ClimaCore.DataLayouts: parent_array_type
import ClimaCore.DataLayouts: device_from_array_type, isascalar
import ClimaCore.DataLayouts: fused_copyto!
import Adapt
import CUDA

device_from_array_type(::Type{<:CUDA.CuArray}) = ClimaComms.CUDADevice()

parent_array_type(::Type{<:CUDA.CuArray{T, N, B} where {N}}) where {T, B} =
CUDA.CuArray{T, N, B} where {N}

Expand Down Expand Up @@ -180,3 +186,83 @@ function Base.fill!(dest::DataF{S, A}, val) where {S, A <: CUDA.CuArray}
)
return dest
end

Base.@propagate_inbounds function rcopyto_at!(
pair::Pair{<:AbstractData, <:Any},
I,
v,
)
dest, bc = pair.first, pair.second
if v <= size(dest, 4)
bcI = isascalar(bc) ? bc[] : bc[I]
dest[I] = bcI
end
return nothing
end
Base.@propagate_inbounds function rcopyto_at!(pairs::Tuple, I, v)
rcopyto_at!(first(pairs), I, v)
rcopyto_at!(Base.tail(pairs), I, v)
end
Base.@propagate_inbounds rcopyto_at!(pairs::Tuple{<:Any}, I, v) =
rcopyto_at!(first(pairs), I, v)
@inline rcopyto_at!(pairs::Tuple{}, I, v) = nothing

function knl_fused_copyto!(fmbc::FusedMultiBroadcast)

@inbounds begin
i = CUDA.threadIdx().x
j = CUDA.threadIdx().y

h = CUDA.blockIdx().x
v = CUDA.blockDim().z * (CUDA.blockIdx().y - 1) + CUDA.threadIdx().z
(; pairs) = fmbc
I = CartesianIndex((i, j, 1, v, h))
rcopyto_at!(pairs, I, v)
end
return nothing
end

function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::VIJFH{S, Nij},
::ClimaComms.CUDADevice,
) where {S, Nij}
_, _, _, Nv, Nh = size(dest1)
if Nv > 0 && Nh > 0
Nv_per_block = min(Nv, fld(256, Nij * Nij))
Nv_blocks = cld(Nv, Nv_per_block)
args = (fmbc,)
auto_launch!(
knl_fused_copyto!,
args,
dest1;
threads_s = (Nij, Nij, Nv_per_block),
blocks_s = (Nh, Nv_blocks),
)
end
return nothing
end

adapt_f(to, f::F) where {F} = Adapt.adapt(to, f)
adapt_f(to, ::Type{F}) where {F} = (x...) -> F(x...)

function Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
fmbc::FusedMultiBroadcast,
)
FusedMultiBroadcast(
map(fmbc.pairs) do pair
dest = pair.first
bc = pair.second
Pair(
Adapt.adapt(to, dest),
Base.Broadcast.Broadcasted(
bc.style,
adapt_f(to, bc.f),
Adapt.adapt(to, bc.args),
Adapt.adapt(to, bc.axes),
),
)
end,
)
end
8 changes: 8 additions & 0 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ module DataLayouts
import Base: Base, @propagate_inbounds
import StaticArrays: SOneTo, MArray, SArray
import ClimaComms
import MultiBroadcastFusion as MBF
import Adapt

import ..slab, ..slab_args, ..column, ..column_args, ..level
Expand Down Expand Up @@ -1451,4 +1452,11 @@ Adapt.adapt_structure(to, data::VF{S}) where {S} =
Adapt.adapt_structure(to, data::DataF{S}) where {S} =
DataF{S}(Adapt.adapt(to, parent(data)))

# TODO: Should the DataLayout be device-aware? So that we can
# determine if we're multi-threaded or not?
# This is only currently used in FusedMultiBroadcast kernels
device_from_array_type(::Type{<:AbstractArray}) = ClimaComms.CPUSingleThreaded()
ClimaComms.device(data::AbstractData) =
device_from_array_type(typeof(parent(data)))

end # module
101 changes: 101 additions & 0 deletions src/DataLayouts/broadcast.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
import MultiBroadcastFusion as MBF
import MultiBroadcastFusion: fused_direct

# Make a MultiBroadcastFusion type, `FusedMultiBroadcast`, and macro, `@fused`:
# via https://github.com/CliMA/MultiBroadcastFusion.jl
MBF.@make_type FusedMultiBroadcast
MBF.@make_fused fused_direct FusedMultiBroadcast fused_direct

# Broadcasting of AbstractData objects
# https://docs.julialang.org/en/v1/manual/interfaces/#Broadcast-Styles

Expand Down Expand Up @@ -587,3 +595,96 @@ function Base.copyto!(
) where {S, Nij, A}
return _serial_copyto!(dest, bc)
end

# ============= FusedMultiBroadcast

isascalar(
bc::Base.Broadcast.Broadcasted{Style},
) where {
Style <:
Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}},
} = true
isascalar(bc) = false


# Fused multi-broadcast entry point for DataLayouts
function Base.copyto!(
fmbc::FusedMultiBroadcast{T},
) where {N, T <: NTuple{N, Pair{<:AbstractData, <:Any}}}
dest1 = first(fmbc.pairs).first
# check_fused_broadcast_axes(fmbc) # we should already have checked the axes
fused_copyto!(fmbc, dest1, ClimaComms.device(dest1))
end

function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::VIJFH{S1, Nij},
::ClimaComms.AbstractCPUDevice,
) where {S1, Nij}
_, _, _, Nv, Nh = size(dest1)
for (dest, bc) in fmbc.pairs
# Base.copyto!(dest, bc) # we can just fall back like this
@inbounds for h in 1:Nh, j in 1:Nij, i in 1:Nij, v in 1:Nv
I = CartesianIndex(i, j, 1, v, h)
bcI = isascalar(bc) ? bc[] : bc[I]
dest[I] = convert(eltype(dest), bcI)
end
end
return nothing
end

function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::VIFH{S, Ni, A},
::ClimaComms.AbstractCPUDevice,
) where {S, Ni, A}
# copy contiguous columns
_, _, _, Nv, Nh = size(dest1)
for (dest, bc) in fmbc.pairs
@inbounds for h in 1:Nh, i in 1:Ni, v in 1:Nv
I = CartesianIndex(i, 1, 1, v, h)
bcI = isascalar(bc) ? bc[] : bc[I]
dest[I] = convert(eltype(dest), bcI)
end
end
return nothing
end

function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::VF{S1, A},
::ClimaComms.AbstractCPUDevice,
) where {S1, A}
_, _, _, Nv, _ = size(dest1)
for (dest, bc) in fmbc.pairs
@inbounds for v in 1:Nv
I = CartesianIndex(1, 1, 1, v, 1)
dest[I] = convert(eltype(dest), bc[I])
end
end
return nothing
end

# we've already diagonalized dest, so we only need to make
# sure that all the broadcast axes are compatible.
# Logic here is similar to Base.Broadcast.instantiate
@inline function _check_fused_broadcast_axes(bc1, bc2)
axes = Base.Broadcast.combine_axes(bc1.args..., bc2.args...)
if !(axes isa Nothing)
Base.Broadcast.check_broadcast_axes(axes, bc1.args...)
Base.Broadcast.check_broadcast_axes(axes, bc2.args...)
end
end

@inline check_fused_broadcast_axes(fmbc::FusedMultiBroadcast) =
check_fused_broadcast_axes(
map(x -> x.second, fmbc.pairs),
first(fmbc.pairs).second,
)
@inline check_fused_broadcast_axes(bcs::Tuple{<:Any}, bc1) =
_check_fused_broadcast_axes(first(bcs), bc1)
@inline check_fused_broadcast_axes(bcs::Tuple{}, bc1) = nothing
@inline function check_fused_broadcast_axes(bcs::Tuple, bc1)
_check_fused_broadcast_axes(first(bcs), bc1)
check_fused_broadcast_axes(Base.tail(bcs), bc1)
end
9 changes: 8 additions & 1 deletion src/Fields/Fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@ module Fields
import ClimaComms
import MultiBroadcastFusion as MBF
import ..slab, ..slab_args, ..column, ..column_args, ..level
import ..DataLayouts: DataLayouts, AbstractData, DataStyle
import ..DataLayouts:
DataLayouts,
AbstractData,
DataStyle,
FusedMultiBroadcast,
@fused_direct,
isascalar,
check_fused_broadcast_axes
import ..Domains
import ..Topologies
import ..Quadratures
Expand Down
39 changes: 39 additions & 0 deletions src/Fields/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,45 @@ end
return dest
end

# Fused multi-broadcast entry point for Fields
function Base.copyto!(
fmbc::FusedMultiBroadcast{T},
) where {N, T <: NTuple{N, Pair{<:Field, <:Any}}}
fmb_data = FusedMultiBroadcast(
map(fmbc.pairs) do pair
bc = Base.Broadcast.instantiate(todata(pair.second))
bc′ = if isascalar(bc)
Base.Broadcast.instantiate(
Base.Broadcast.Broadcasted(bc.style, bc.f, bc.args, ()),
)
else
bc
end
Pair(field_values(pair.first), bc′)
end,
)
check_mismatched_spaces(fmbc)
check_fused_broadcast_axes(fmbc)
Base.copyto!(fmb_data) # forward to DataLayouts
end

@inline check_mismatched_spaces(fmbc::FusedMultiBroadcast) =
check_mismatched_spaces(
map(x -> axes(x.first), fmbc.pairs),
axes(first(fmbc.pairs).first),
)
@inline check_mismatched_spaces(axs::Tuple{<:Any}, ax1) =
_check_mismatched_spaces(first(axs), ax1)
@inline check_mismatched_spaces(axs::Tuple{}, ax1) = nothing
@inline function check_mismatched_spaces(axs::Tuple, ax1)
_check_mismatched_spaces(first(axs), ax1)
check_mismatched_spaces(Base.tail(axs), ax1)
end

_check_mismatched_spaces(::T, ::T) where {T <: AbstractSpace} = nothing
_check_mismatched_spaces(space1, space2) =
error("FusedMultiBroadcast spaces are not the same.")

@noinline function error_mismatched_spaces(space1::Type, space2::Type)
error("Broacasted spaces are not the same.")
end
Expand Down
5 changes: 5 additions & 0 deletions test/Fields/field.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#=
julia --check-bounds=yes --project=test
julia --project=test
using Revise; include(joinpath("test", "Fields", "field.jl"))
=#
Expand Down Expand Up @@ -915,3 +916,7 @@ end
end
nothing
end

include("field_multi_broadcast_fusion.jl")

nothing
Loading

0 comments on commit 2ff71eb

Please sign in to comment.