Skip to content

Commit

Permalink
Add MultiBroadcastFusion implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Apr 12, 2024
1 parent 1d5bce6 commit d080ff1
Show file tree
Hide file tree
Showing 7 changed files with 543 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ module DataLayouts
import Base: Base, @propagate_inbounds
import StaticArrays: SOneTo, MArray, SArray
import ClimaComms
import MultiBroadcastFusion as MBF

import ..slab, ..slab_args, ..column, ..column_args, ..level
export slab, column, level, IJFH, IJF, IFH, IF, VF, VIJFH, VIFH, DataF
Expand Down
99 changes: 99 additions & 0 deletions src/DataLayouts/broadcast.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import MultiBroadcastFusion as MBF

# Make a MultiBroadcastFusion type, `FusedMultiBroadcast`, and macro, `@fused`:
# via https://github.com/CliMA/MultiBroadcastFusion.jl
MBF.@make_fused FusedMultiBroadcast fused

# Broadcasting of AbstractData objects
# https://docs.julialang.org/en/v1/manual/interfaces/#Broadcast-Styles

Expand Down Expand Up @@ -587,3 +593,96 @@ function Base.copyto!(
) where {S, Nij, A}
return _serial_copyto!(dest, bc)
end

# ============= FusedMultiBroadcast

isascalar(
bc::Base.Broadcast.Broadcasted{Style},
) where {
Style <:
Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}},
} = true
isascalar(bc) = false

# Fused multi-broadcast entry point for DataLayouts
function Base.copyto!(
fmbc::FusedMultiBroadcast{T},
) where {N, T <: NTuple{N, Pair{<:AbstractData, <:Any}}}
dest1 = first(fmbc.pairs).first
# check_fused_broadcast_axes(fmbc) # we should already have checked the axes
if parent_array_type(typeof(parent(dest1))) <: CUDA.CuArray
fused_copyto_cuda!(fmbc, dest1)
else
fused_copyto_cpu!(fmbc, dest1)
end
end

function fused_copyto_cpu!(
fmbc::FusedMultiBroadcast,
dest1::VIJFH{S1, Nij},
) where {S1, Nij}
_, _, _, Nv, Nh = size(dest1)
for (dest, bc) in fmbc.pairs
# Base.copyto!(dest, bc) # we can just fall back like this
@inbounds for h in 1:Nh, j in 1:Nij, i in 1:Nij, v in 1:Nv
I = CartesianIndex(i, j, 1, v, h)
bcI = isascalar(bc) ? bc[] : bc[I]
dest[I] = convert(eltype(dest), bcI)
end
end
return nothing
end

function fused_copyto_cpu!(
fmbc::FusedMultiBroadcast,
dest1::VIFH{S, Ni, A},
) where {S, Ni, A}
# copy contiguous columns
_, _, _, Nv, Nh = size(dest1)
for (dest, bc) in fmbc.pairs
@inbounds for h in 1:Nh, i in 1:Ni, v in 1:Nv
I = CartesianIndex(i, 1, 1, v, h)
bcI = isascalar(bc) ? bc[] : bc[I]
dest[I] = convert(eltype(dest), bcI)
end
end
return nothing
end

function fused_copyto_cpu!(
fmbc::FusedMultiBroadcast,
dest1::VF{S1, A},
) where {S1, A}
_, _, _, Nv, _ = size(dest1)
for (dest, bc) in fmbc.pairs
@inbounds for v in 1:Nv
I = CartesianIndex(1, 1, 1, v, 1)
dest[I] = convert(eltype(dest), bc[I])
end
end
return nothing
end

# we've already diagonalized dest, so we only need to make
# sure that all the broadcast axes are compatible.
# Logic here is similar to Base.Broadcast.instantiate
@inline function _check_fused_broadcast_axes(bc1, bc2)
axes = Base.Broadcast.combine_axes(bc1.args..., bc2.args...)
if !(axes isa Nothing)
Base.Broadcast.check_broadcast_axes(axes, bc1.args...)
Base.Broadcast.check_broadcast_axes(axes, bc2.args...)
end
end

@inline check_fused_broadcast_axes(fmbc::FusedMultiBroadcast) =
check_fused_broadcast_axes(
map(x -> x.second, fmbc.pairs),
first(fmbc.pairs).second,
)
@inline check_fused_broadcast_axes(bcs::Tuple{<:Any}, bc1) =
_check_fused_broadcast_axes(first(bcs), bc1)
@inline check_fused_broadcast_axes(bcs::Tuple{}, bc1) = nothing
@inline function check_fused_broadcast_axes(bcs::Tuple, bc1)
_check_fused_broadcast_axes(first(bcs), bc1)
check_fused_broadcast_axes(Base.tail(bcs), bc1)
end
73 changes: 73 additions & 0 deletions src/DataLayouts/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,76 @@ function Base.fill!(dest::DataF{S, A}, val) where {S, A <: CUDA.CuArray}
)
return dest
end

Base.@propagate_inbounds function rcopyto_at!(
pair::Pair{<:AbstractData, <:Any},
I,
v,
)
dest, bc = pair.first, pair.second
if v <= size(dest, 4)
bcI = isascalar(bc) ? bc[] : bc[I]
dest[I] = bcI
end
return nothing
end
Base.@propagate_inbounds function rcopyto_at!(pairs::Tuple, I, v)
rcopyto_at!(first(pairs), I, v)
rcopyto_at!(Base.tail(pairs), I, v)
end
Base.@propagate_inbounds rcopyto_at!(pairs::Tuple{<:Any}, I, v) =
rcopyto_at!(first(pairs), I, v)
@inline rcopyto_at!(pairs::Tuple{}, I, v) = nothing

function knl_fused_copyto!(fmbc::FusedMultiBroadcast)

@inbounds begin
i = CUDA.threadIdx().x
j = CUDA.threadIdx().y

h = CUDA.blockIdx().x
v = CUDA.blockDim().z * (CUDA.blockIdx().y - 1) + CUDA.threadIdx().z
(; pairs) = fmbc
I = CartesianIndex((i, j, 1, v, h))
rcopyto_at!(pairs, I, v)
end
return nothing
end

function fused_copyto_cuda!(
fmbc::FusedMultiBroadcast,
dest1::VIJFH{S, Nij},
) where {S, Nij}
_, _, _, Nv, Nh = size(dest1)
if Nv > 0 && Nh > 0
Nv_per_block = min(Nv, fld(256, Nij * Nij))
Nv_blocks = cld(Nv, Nv_per_block)
CUDA.@cuda always_inline = true threads = (Nij, Nij, Nv_per_block) blocks =
(Nh, Nv_blocks) knl_fused_copyto!(fmbc)
end
return nothing
end

adapt_f(to, f::F) where {F} = Adapt.adapt(to, f)
adapt_f(to, ::Type{F}) where {F} = (x...) -> F(x...)

function Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
fmbc::FusedMultiBroadcast,
)
FusedMultiBroadcast(
map(fmbc.pairs) do pair
dest = pair.first
bc = pair.second
Pair(
Adapt.adapt(to, dest),
Base.Broadcast.Broadcasted(
bc.style,
adapt_f(to, bc.f),
Adapt.adapt(to, bc.args),
Adapt.adapt(to, bc.axes),
),
)
end,
)
end
9 changes: 8 additions & 1 deletion src/Fields/Fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@ module Fields
import ClimaComms
import MultiBroadcastFusion as MBF
import ..slab, ..slab_args, ..column, ..column_args, ..level
import ..DataLayouts: DataLayouts, AbstractData, DataStyle
import ..DataLayouts:
DataLayouts,
AbstractData,
DataStyle,
FusedMultiBroadcast,
@fused,
isascalar,
check_fused_broadcast_axes
import ..Domains
import ..Topologies
import ..Quadratures
Expand Down
39 changes: 39 additions & 0 deletions src/Fields/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,45 @@ end
return dest
end

# Fused multi-broadcast entry point for Fields
function Base.copyto!(
fmbc::FusedMultiBroadcast{T},
) where {N, T <: NTuple{N, Pair{<:Field, <:Any}}}
fmb_data = FusedMultiBroadcast(
map(fmbc.pairs) do pair
bc = Base.Broadcast.instantiate(todata(pair.second))
bc′ = if isascalar(bc)
Base.Broadcast.instantiate(
Base.Broadcast.Broadcasted(bc.style, bc.f, bc.args, ()),
)
else
bc
end
Pair(field_values(pair.first), bc′)
end,
)
check_mismatched_spaces(fmbc)
check_fused_broadcast_axes(fmbc)
Base.copyto!(fmb_data) # forward to DataLayouts
end

@inline check_mismatched_spaces(fmbc::FusedMultiBroadcast) =
check_mismatched_spaces(
map(x -> axes(x.first), fmbc.pairs),
axes(first(fmbc.pairs).first),
)
@inline check_mismatched_spaces(axs::Tuple{<:Any}, ax1) =
_check_mismatched_spaces(first(axs), ax1)
@inline check_mismatched_spaces(axs::Tuple{}, ax1) = nothing
@inline function check_mismatched_spaces(axs::Tuple, ax1)
_check_mismatched_spaces(first(axs), ax1)
check_mismatched_spaces(Base.tail(axs), ax1)
end

_check_mismatched_spaces(::T, ::T) where {T <: AbstractSpace} = nothing
_check_mismatched_spaces(space1, space2) =
error("FusedMultiBroadcast spaces are not the same.")

@noinline function error_mismatched_spaces(space1::Type, space2::Type)
error("Broacasted spaces are not the same.")
end
Expand Down
17 changes: 13 additions & 4 deletions test/Fields/field.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#=
julia --check-bounds=yes --project=test
julia --project=test
using Revise; include(joinpath("test", "Fields", "field.jl"))
=#
using Test
using JET
using BenchmarkTools

using ClimaComms
using OrderedCollections
Expand All @@ -22,16 +24,19 @@ import ClimaCore:
Geometry,
Quadratures

import ClimaCore.Fields: @fused
using LinearAlgebra: norm
using Statistics: mean
using ForwardDiff
using CUDA
using CUDA: @allowscalar

include(
joinpath(pkgdir(ClimaCore), "test", "TestUtilities", "TestUtilities.jl"),
)
import .TestUtilities as TU
util_file =
joinpath(pkgdir(ClimaCore), "test", "TestUtilities", "TestUtilities.jl")
if !(@isdefined(TU))
include(util_file)
import .TestUtilities as TU
end

function spectral_space_2D(; n1 = 1, n2 = 1, Nij = 4)
domain = Domains.RectangleDomain(
Expand Down Expand Up @@ -915,3 +920,7 @@ end
end
nothing
end

include("field_multi_broadcast_fusion.jl")

nothing
Loading

0 comments on commit d080ff1

Please sign in to comment.