Skip to content

Commit

Permalink
Add linear index support for pointwise kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Sep 9, 2024
1 parent 3bc75d1 commit 378bbcf
Show file tree
Hide file tree
Showing 14 changed files with 730 additions and 68 deletions.
13 changes: 13 additions & 0 deletions ext/cuda/data_layouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,16 @@ function Adapt.adapt_structure(
end,
)
end

import Adapt
import CUDA
function Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
bc::DataLayouts.NonExtrudedBroadcasted{Style},
) where {Style}
DataLayouts.NonExtrudedBroadcasted{Style}(
adapt_f(to, bc.f),
Adapt.adapt(to, bc.args),
Adapt.adapt(to, bc.axes),
)
end
23 changes: 23 additions & 0 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,27 @@ empty_kernel_stats() = empty_kernel_stats(ClimaComms.device())
@inline get_Nij(::IJF{S, Nij}) where {S, Nij} = Nij
@inline get_Nij(::IF{S, Nij}) where {S, Nij} = Nij

# Returns the size of the backing array.
@inline array_size(::IJKFVH{S, Nij, Nk, Nv, Nh}) where {S, Nij, Nk, Nv, Nh} = (Nij, Nij, Nk, 1, Nv, Nh)
@inline array_size(::IJFH{S, Nij, Nh}) where {S, Nij, Nh} = (Nij, Nij, 1, Nh)
@inline array_size(::IFH{S, Ni, Nh}) where {S, Ni, Nh} = (Ni, 1, Nh)
@inline array_size(::DataF{S}) where {S} = (1,)
@inline array_size(::IJF{S, Nij}) where {S, Nij} = (Nij, Nij, 1)
@inline array_size(::IF{S, Ni}) where {S, Ni} = (Ni, 1)
@inline array_size(::VF{S, Nv}) where {S, Nv} = (Nv, 1)
@inline array_size(::VIJFH{S, Nv, Nij, Nh}) where {S, Nv, Nij, Nh} = (Nv, Nij, Nij, 1, Nh)
@inline array_size(::VIFH{S, Nv, Ni, Nh}) where {S, Nv, Ni, Nh} = (Nv, Ni, 1, Nh)

@inline farray_size(data::IJKFVH{S, Nij, Nk, Nv, Nh}) where {S, Nij, Nk, Nv, Nh} = (Nij, Nij, Nk, ncomponents(data), Nv, Nh)
@inline farray_size(data::IJFH{S, Nij, Nh}) where {S, Nij, Nh} = (Nij, Nij, ncomponents(data), Nh)
@inline farray_size(data::IFH{S, Ni, Nh}) where {S, Ni, Nh} = (Ni, ncomponents(data), Nh)
@inline farray_size(data::DataF{S}) where {S} = (ncomponents(data),)
@inline farray_size(data::IJF{S, Nij}) where {S, Nij} = (Nij, Nij, ncomponents(data))
@inline farray_size(data::IF{S, Ni}) where {S, Ni} = (Ni, ncomponents(data))
@inline farray_size(data::VF{S, Nv}) where {S, Nv} = (Nv, ncomponents(data))
@inline farray_size(data::VIJFH{S, Nv, Nij, Nh}) where {S, Nv, Nij, Nh} = (Nv, Nij, Nij, ncomponents(data), Nh)
@inline farray_size(data::VIFH{S, Nv, Ni, Nh}) where {S, Nv, Ni, Nh} = (Nv, Ni, ncomponents(data), Nh)

"""
field_dim(data::AbstractData)
field_dim(::Type{<:AbstractData})
Expand Down Expand Up @@ -1216,9 +1237,11 @@ _device_dispatch(x::AbstractData) = _device_dispatch(parent(x))
_device_dispatch(x::SArray) = ToCPU()
_device_dispatch(x::MArray) = ToCPU()

include("non_extruded_broadcasted.jl")
include("copyto.jl")
include("fused_copyto.jl")
include("fill.jl")
include("mapreduce.jl")
include("has_uniform_datalayouts.jl")

end # module
1 change: 1 addition & 0 deletions src/DataLayouts/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ DataSlab2DStyle(::Type{VIJFHStyle{Nv, Nij, Nh, A}}) where {Nv, Nij, Nh, A} =
#####

#! format: off
const BroadcastedUnionData = Union{Base.Broadcast.Broadcasted{<:DataStyle}, AbstractData}
const BroadcastedUnionIJFH{S, Nij, Nh, A} = Union{Base.Broadcast.Broadcasted{IJFHStyle{Nij, Nh, A}}, IJFH{S, Nij, Nh, A}}
const BroadcastedUnionIFH{S, Ni, Nh, A} = Union{Base.Broadcast.Broadcasted{IFHStyle{Ni, Nh, A}}, IFH{S, Ni, Nh, A}}
const BroadcastedUnionIJF{S, Nij, A} = Union{Base.Broadcast.Broadcasted{IJFStyle{Nij, A}}, IJF{S, Nij, A}}
Expand Down
18 changes: 15 additions & 3 deletions src/DataLayouts/copyto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,22 @@
##### Dispatching and edge cases
#####

Base.copyto!(
dest::AbstractData,
function Base.copyto!(
dest::AbstractData{S},
bc::Union{AbstractData, Base.Broadcast.Broadcasted},
) = Base.copyto!(dest, bc, device_dispatch(dest))
) where {S}
dev = device_dispatch(dest)
if dev isa ToCPU && has_uniform_datalayouts(bc) && !(dest isa DataF)
# Specialize on linear indexing case:
bc′ = Base.Broadcast.instantiate(to_non_extruded_broadcasted(bc))
@inbounds @simd for I in 1:get_N(UniversalSize(dest))
dest[I] = convert(S, bc′[I])
end
else
Base.copyto!(dest, bc, device_dispatch(dest))
end
return dest
end

# Specialize on non-Broadcasted objects
function Base.copyto!(dest::D, src::D) where {D <: AbstractData}
Expand Down
61 changes: 7 additions & 54 deletions src/DataLayouts/fill.jl
Original file line number Diff line number Diff line change
@@ -1,60 +1,13 @@
function Base.fill!(data::IJFH, val, ::ToCPU)
(_, _, _, _, Nh) = size(data)
@inbounds for h in 1:Nh
fill!(slab(data, h), val)
function Base.fill!(dest::AbstractData, val, ::ToCPU)
@inbounds @simd for I in 1:get_N(UniversalSize(dest))
dest[I] = val
end
return data
return dest
end

function Base.fill!(data::IFH, val, ::ToCPU)
(_, _, _, _, Nh) = size(data)
@inbounds for h in 1:Nh
fill!(slab(data, h), val)
end
return data
end

function Base.fill!(data::DataF, val, ::ToCPU)
@inbounds data[] = val
return data
end

function Base.fill!(data::IJF{S, Nij}, val, ::ToCPU) where {S, Nij}
@inbounds for j in 1:Nij, i in 1:Nij
data[CartesianIndex(i, j, 1, 1, 1)] = val
end
return data
end

function Base.fill!(data::IF{S, Ni}, val, ::ToCPU) where {S, Ni}
@inbounds for i in 1:Ni
data[CartesianIndex(i, 1, 1, 1, 1)] = val
end
return data
end

function Base.fill!(data::VF, val, ::ToCPU)
Nv = nlevels(data)
@inbounds for v in 1:Nv
data[CartesianIndex(1, 1, 1, v, 1)] = val
end
return data
end

function Base.fill!(data::VIJFH, val, ::ToCPU)
(Ni, Nj, _, Nv, Nh) = size(data)
@inbounds for h in 1:Nh, v in 1:Nv
fill!(slab(data, v, h), val)
end
return data
end

function Base.fill!(data::VIFH, val, ::ToCPU)
(Ni, _, _, Nv, Nh) = size(data)
@inbounds for h in 1:Nh, v in 1:Nv
fill!(slab(data, v, h), val)
end
return data
function Base.fill!(dest::DataF, val, ::ToCPU)
@inbounds dest[] = val
return dest
end

Base.fill!(dest::AbstractData, val) =
Expand Down
60 changes: 60 additions & 0 deletions src/DataLayouts/has_uniform_datalayouts.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
@inline function first_datalayout_in_bc(args::Tuple, rargs...)
x1 = first_datalayout_in_bc(args[1], rargs...)
x1 isa AbstractData && return x1
return first_datalayout_in_bc(Base.tail(args), rargs...)
end

@inline first_datalayout_in_bc(args::Tuple{Any}, rargs...) =
first_datalayout_in_bc(args[1], rargs...)
@inline first_datalayout_in_bc(args::Tuple{}, rargs...) = nothing
@inline first_datalayout_in_bc(x) = nothing
@inline first_datalayout_in_bc(x::AbstractData) = x

@inline first_datalayout_in_bc(bc::Base.Broadcast.Broadcasted) =
first_datalayout_in_bc(bc.args)

@inline _has_uniform_datalayouts_args(truesofar, start, args::Tuple, rargs...) =
truesofar &&
_has_uniform_datalayouts(truesofar, start, args[1], rargs...) &&
_has_uniform_datalayouts_args(truesofar, start, Base.tail(args), rargs...)

@inline _has_uniform_datalayouts_args(
truesofar,
start,
args::Tuple{Any},
rargs...,
) = truesofar && _has_uniform_datalayouts(truesofar, start, args[1], rargs...)
@inline _has_uniform_datalayouts_args(truesofar, _, args::Tuple{}, rargs...) =
truesofar

@inline function _has_uniform_datalayouts(
truesofar,
start,
bc::Base.Broadcast.Broadcasted,
)
return truesofar && _has_uniform_datalayouts_args(truesofar, start, bc.args)
end
for DL in (:IJKFVH, :IJFH, :IFH, :DataF, :IJF, :IF, :VF, :VIJFH, :VIFH)
@eval begin
@inline _has_uniform_datalayouts(truesofar, ::$(DL), ::$(DL)) = true
end
end
@inline _has_uniform_datalayouts(truesofar, _, x::AbstractData) = false
@inline _has_uniform_datalayouts(truesofar, _, x) = truesofar

"""
has_uniform_datalayouts
Find the first datalayout in the broadcast expression (BCE),
and compares against every other datalayout in the BCE. Returns
- `true` if the broadcasted object has only a single kind of datalayout (e.g. VF,VF, VIJFH,VIJFH)
- `false` if the broadcasted object has multiple kinds of datalayouts (e.g. VIJFH, VIFH)
Note: a broadcasted object can have different _types_,
e.g., `VIFJH{Float64}` and `VIFJH{Tuple{Float64,Float64}}`
but not different kinds, e.g., `VIFJH{Float64}` and `VF{Float64}`.
"""
function has_uniform_datalayouts end

@inline has_uniform_datalayouts(bc::Base.Broadcast.Broadcasted) =
_has_uniform_datalayouts_args(true, first_datalayout_in_bc(bc), bc.args)

@inline has_uniform_datalayouts(bc::AbstractData) = true
160 changes: 160 additions & 0 deletions src/DataLayouts/non_extruded_broadcasted.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
#! format: off
# ============================================================ Adapted from Base.Broadcast (julia version 1.10.4)
import Base.Broadcast: BroadcastStyle
struct NonExtrudedBroadcasted{
Style <: Union{Nothing, BroadcastStyle},
Axes,
F,
Args <: Tuple,
} <: Base.AbstractBroadcasted
style::Style
f::F
args::Args
axes::Axes # the axes of the resulting object (may be bigger than implied by `args` if this is nested inside a larger `NonExtrudedBroadcasted`)

NonExtrudedBroadcasted(style::Union{Nothing, BroadcastStyle}, f::Tuple, args::Tuple) =
error() # disambiguation: tuple is not callable
function NonExtrudedBroadcasted(
style::Union{Nothing, BroadcastStyle},
f::F,
args::Tuple,
axes = nothing,
) where {F}
# using Core.Typeof rather than F preserves inferrability when f is a type
return new{typeof(style), typeof(axes), Core.Typeof(f), typeof(args)}(
style,
f,
args,
axes,
)
end
function NonExtrudedBroadcasted(f::F, args::Tuple, axes = nothing) where {F}
NonExtrudedBroadcasted(combine_styles(args...)::BroadcastStyle, f, args, axes)
end
function NonExtrudedBroadcasted{Style}(f::F, args, axes = nothing) where {Style, F}
return new{Style, typeof(axes), Core.Typeof(f), typeof(args)}(
Style()::Style,
f,
args,
axes,
)
end
function NonExtrudedBroadcasted{Style, Axes, F, Args}(
f,
args,
axes,
) where {Style, Axes, F, Args}
return new{Style, Axes, F, Args}(Style()::Style, f, args, axes)
end
end

@inline to_non_extruded_broadcasted(bc::Base.Broadcast.Broadcasted) =
NonExtrudedBroadcasted(bc.style, bc.f, to_non_extruded_broadcasted(bc.args), bc.axes)
@inline to_non_extruded_broadcasted(x) = x
NonExtrudedBroadcasted(bc::Base.Broadcast.Broadcasted) = to_non_extruded_broadcasted(bc)

@inline to_non_extruded_broadcasted(args::Tuple) = (
to_non_extruded_broadcasted(args[1]),
to_non_extruded_broadcasted(Base.tail(args))...,
)
@inline to_non_extruded_broadcasted(args::Tuple{Any}) =
(to_non_extruded_broadcasted(args[1]),)
@inline to_non_extruded_broadcasted(args::Tuple{}) = ()

@inline _checkbounds(bc, _, I) = nothing # TODO: fix this case
@inline _checkbounds(bc, ::Tuple, I) = Base.checkbounds(bc, I)
@inline function Base.getindex(
bc::NonExtrudedBroadcasted,
I::Union{Integer, CartesianIndex},
)
@boundscheck _checkbounds(bc, axes(bc), I) # is this really the only issue?
@inbounds _broadcast_getindex(bc, I)
end

# --- here, we define our own bounds checks
@inline function Base.checkbounds(bc::NonExtrudedBroadcasted, I::Integer)
# Base.checkbounds_indices(Bool, axes(bc), (I,)) || Base.throw_boundserror(bc, (I,)) # from Base
Base.checkbounds_indices(Bool, (Base.OneTo(n_dofs(bc)),), (I,)) || Base.throw_boundserror(bc, (I,))
end

import StaticArrays
to_tuple(t::Tuple) = t
to_tuple(t::NTuple{N, <: Base.OneTo}) where {N} = map(x->x.stop, t)
to_tuple(t::NTuple{N, <: StaticArrays.SOneTo}) where {N} = map(x->x.stop, t)
n_dofs(bc::NonExtrudedBroadcasted) = prod(to_tuple(axes(bc)))
# ---

Base.@propagate_inbounds _broadcast_getindex(
A::Union{Ref, AbstractArray{<:Any, 0}, Number},
I::Integer,
) = A[] # Scalar-likes can just ignore all indices
Base.@propagate_inbounds _broadcast_getindex(
::Ref{Type{T}},
I::Integer,
) where {T} = T
# Tuples are statically known to be singleton or vector-like
Base.@propagate_inbounds _broadcast_getindex(A::Tuple{Any}, I::Integer) = A[1]
Base.@propagate_inbounds _broadcast_getindex(A::Tuple, I::Integer) = A[I[1]]
# Everything else falls back to dynamically dropping broadcasted indices based upon its axes
# Base.@propagate_inbounds _broadcast_getindex(A, I) = A[newindex(A, I)]
Base.@propagate_inbounds _broadcast_getindex(A, I::Integer) = A[I]
Base.@propagate_inbounds function _broadcast_getindex(
bc::NonExtrudedBroadcasted{<:Any, <:Any, <:Any, <:Any},
I::Integer,
)
args = _getindex(bc.args, I)
return _broadcast_getindex_evalf(bc.f, args...)
end
@inline _broadcast_getindex_evalf(f::Tf, args::Vararg{Any, N}) where {Tf, N} =
f(args...) # not propagate_inbounds
Base.@propagate_inbounds _getindex(args::Tuple, I) =
(_broadcast_getindex(args[1], I), _getindex(Base.tail(args), I)...)
Base.@propagate_inbounds _getindex(args::Tuple{Any}, I) =
(_broadcast_getindex(args[1], I),)
Base.@propagate_inbounds _getindex(args::Tuple{}, I) = ()

@inline Base.axes(bc::NonExtrudedBroadcasted) = _axes(bc, bc.axes)
_axes(::NonExtrudedBroadcasted, axes::Tuple) = axes
@inline _axes(bc::NonExtrudedBroadcasted, ::Nothing) = Base.Broadcast.combine_axes(bc.args...)
_axes(bc::NonExtrudedBroadcasted{<:Base.Broadcast.AbstractArrayStyle{0}}, ::Nothing) = ()
@inline Base.axes(bc::NonExtrudedBroadcasted{<:Any, <:NTuple{N}}, d::Integer) where {N} =
d <= N ? axes(bc)[d] : OneTo(1)
Base.IndexStyle(::Type{<:NonExtrudedBroadcasted{<:Any, <:Tuple{Any}}}) = IndexLinear()
@inline _axes(::NonExtrudedBroadcasted, axes) = axes
@inline Base.eltype(bc::NonExtrudedBroadcasted) = Base.Broadcast.combine_axes(bc.args...)


# ============================================================

#! format: on
# Datalayouts
@propagate_inbounds function linear_getindex(
data::AbstractData{S},
I::Integer,
) where {S}
s_array = farray_size(data)
ss = StaticSize(s_array, field_dim(data))
@inbounds get_struct_linear(parent(data), S, Val(field_dim(data)), I, ss)
end
@propagate_inbounds function linear_setindex!(
data::AbstractData{S},
val,
I::Integer,
) where {S}
s_array = farray_size(data)
ss = StaticSize(s_array, field_dim(data))
@inbounds set_struct_linear!(
parent(data),
convert(S, val),
Val(field_dim(data)),
I,
ss,
)
end

for DL in (:IJKFVH, :IJFH, :IFH, :IJF, :IF, :VF, :VIJFH, :VIFH) # Skip DataF, since we want that to MethodError.
@eval @propagate_inbounds Base.getindex(data::$(DL), I::Integer) =
linear_getindex(data, I)
@eval @propagate_inbounds Base.setindex!(data::$(DL), val, I::Integer) =
linear_setindex!(data, val, I)
end
Loading

0 comments on commit 378bbcf

Please sign in to comment.