Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize StructArray's broadcast. #215

Merged
merged 9 commits into from
Nov 30, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,29 @@ version = "0.6.13"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
Adapt = "1, 2, 3"
DataAPI = "1"
StaticArraysCore = "1.1"
StaticArrays = "1.5.4"
GPUArraysCore = "0.1.2"
StaticArrays = "1.5.6"
StaticArraysCore = "1.3"
Tables = "1"
julia = "1.6"

[extras]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"

[targets]
test = ["Test", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter"]
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "Random"]
10 changes: 10 additions & 0 deletions src/StructArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,14 @@ end
import Adapt
Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s)

# for GPU broadcast
import GPUArraysCore
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
backends = map_params(GPUArraysCore.backend, array_types(T))
backend, others = backends[1], tail(backends)
isconsistent = mapfoldl(isequal(backend), &, others; init=true)
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend"))
return backend
end

end # module
8 changes: 7 additions & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,10 @@ function createinstance(::Type{T}, args...) where {T}
isconcretetype(T) ? bypass_constructor(T, args) : T(args...)
end

createinstance(::Type{T}, args...) where {T<:Tup} = T(args)
createinstance(::Type{T}, args...) where {T<:Tup} = T(args)

struct Instantiator{T} end

Instantiator(::Type{T}) where {T} = Instantiator{T}()

(::Instantiator{T})(args...) where {T} = createinstance(T, args...)
66 changes: 65 additions & 1 deletion src/staticarrays_support.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import StaticArraysCore: StaticArray, FieldArray, tuple_prod
using StaticArraysCore: StaticArray, FieldArray, tuple_prod

"""
StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
Expand Down Expand Up @@ -27,3 +27,67 @@ StructArrays.component(s::StaticArray, i) = getindex(s, i)
end
StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i)
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)

# Broadcast overload
using StaticArraysCore: StaticArrayStyle, similar_type
StructStaticArrayStyle{N} = StructArrayStyle{StaticArrayStyle{N}, N}
function Broadcast.instantiate(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M}
bc′ = Broadcast.instantiate(replace_structarray(bc))
return convert(Broadcasted{StructStaticArrayStyle{M}}, bc′)
end
# This looks costly, but the compiler should be able to optimize them away
Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing) = axes(replace_structarray(bc))

to_staticstyle(@nospecialize(x::Type)) = x
to_staticstyle(::Type{StructStaticArrayStyle{N}}) where {N} = StaticArrayStyle{N}

"""
replace_structarray(bc::Broadcasted)

An internal function transforms the `Broadcasted` with `StructArray` into
an equivalent one without it. This is not a must if the root `BroadcastStyle`
supports `AbstractArray`. But some `BroadcastStyle` limits the input array types,
e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`.
"""
function replace_structarray(bc::Broadcasted{Style}) where {Style}
args = replace_structarray_args(bc.args)
return Broadcasted{to_staticstyle(Style)}(bc.f, args, nothing)
end
function replace_structarray(A::StructArray)
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
f = Instantiator(eltype(A))
args = Tuple(components(A))
return Broadcasted{StaticArrayStyle{ndims(A)}}(f, args, nothing)
end
replace_structarray(@nospecialize(A)) = A

replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(tail(args))...)
replace_structarray_args(::Tuple{}) = ()

# StaticArrayStyle has no similar defined.
# Overload `Base.copy` instead.
@inline function Base.copy(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M}
sa = copy(convert(Broadcasted{StaticArrayStyle{M}}, bc))
ET = eltype(sa)
isnonemptystructtype(ET) || return sa
elements = Tuple(sa)
@static if VERSION >= v"1.7"
arrs = ntuple(Val(fieldcount(ET))) do i
similar_type(sa, fieldtype(ET, i))(_getfields(elements, i))
end
else
_fieldtype(::Type{T}) where {T} = i -> fieldtype(T, i)
__fieldtype = _fieldtype(ET)
arrs = ntuple(Val(fieldcount(ET))) do i
similar_type(sa, __fieldtype(i))(_getfields(elements, i))
end
end
return StructArray{ET}(arrs)
end

@inline function _getfields(x::Tuple, i::Int)
if @generated
return Expr(:tuple, (:(getfield(x[$j], i)) for j in 1:fieldcount(x))...)
else
return map(Base.Fix2(getfield, i), x)
end
end
28 changes: 24 additions & 4 deletions src/structarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T
end

# broadcast
import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown

struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end

Expand All @@ -496,19 +496,39 @@ function StructArrayStyle{S, M}(::Val{N}) where {S, M, N}
return StructArrayStyle{T, N}()
end

# StructArrayStyle is a wrapped style.
# Here we try our best to resolve style conflict.
function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S, N}) where {S, N, M}
N′ = M === Any || N === Any ? Any : max(M, N)
S′ = Broadcast.result_style(S(), b)
return S′ isa StructArrayStyle ? typeof(S′)(Val{N′}()) : StructArrayStyle{typeof(S′), N′}()
end
BroadcastStyle(::StructArrayStyle, ::DefaultArrayStyle) = Unknown()
N5N3 marked this conversation as resolved.
Show resolved Hide resolved

@inline combine_style_types(::Type{A}, args...) where {A<:AbstractArray} =
combine_style_types(BroadcastStyle(A), args...)
@inline combine_style_types(s::BroadcastStyle, ::Type{A}, args...) where {A<:AbstractArray} =
combine_style_types(Broadcast.result_style(s, BroadcastStyle(A)), args...)
combine_style_types(::StructArrayStyle{S}) where {S} = S() # avoid nested StructArrayStyle
combine_style_types(s::BroadcastStyle) = s

Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).parameters...)

BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{typeof(cst(SA)), ndims(SA)}()

function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S<:Union{DefaultArrayStyle,StructArrayStyle}, N, ElType}
ContainerType = isnonemptystructtype(ElType) ? StructArray{ElType} : Array{ElType}
return similar(ContainerType, axes(bc))
# Here we use `similar` defined for `S` to build the dest Array.
function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S, N, ElType}
bc′ = convert(Broadcasted{S}, bc)
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
return isnonemptystructtype(ElType) ? buildfromschema(T -> similar(bc′, T), ElType) : similar(bc′, ElType)
end

# Unwrapper to recover the behaviour defined by parent style.
@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
return copyto!(dest, convert(Broadcasted{S}, bc))
end

@inline function Broadcast.materialize!(::StructArrayStyle{S}, dest, bc::Broadcasted) where {S}
return Broadcast.materialize!(S(), dest, bc)
end

# for aliasing analysis during broadcast
Expand Down
133 changes: 116 additions & 17 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import Tables, PooledArrays, WeakRefStrings
using TypedTables: Table
using DataAPI: refarray, refvalue
using Adapt: adapt, Adapt
using JLArrays
using Random
using Test

using Documenter: doctest
Expand Down Expand Up @@ -1100,17 +1102,39 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs)
@test t.b.d isa Array
end

struct MyArray{T,N} <: AbstractArray{T,N}
A::Array{T,N}
# The following code defines `MyArray1/2/3` with different `BroadcastStyle`s.
# 1. `MyArray1` and `MyArray1` have `similar` defined.
# We use them to simulate `BroadcastStyle` overloading `Base.copyto!`.
# 2. `MyArray3` has no `similar` defined.
# We use it to simulate `BroadcastStyle` overloading `Base.copy`.
# 3. Their resolved style could be summaryized as (`-` means conflict)
# | MyArray1 | MyArray2 | MyArray3 | Array
# -------------------------------------------------------------
# MyArray1 | MyArray1 | - | MyArray1 | MyArray1
# MyArray2 | - | MyArray2 | - | MyArray2
# MyArray3 | MyArray1 | - | MyArray3 | MyArray3
# Array | MyArray1 | Array | MyArray3 | Array

for S in (1, 2, 3)
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
MyArray = Symbol(:MyArray, S)
@eval begin
struct $MyArray{T,N} <: AbstractArray{T,N}
A::Array{T,N}
end
$MyArray{T}(::UndefInitializer, sz::Dims) where T = $MyArray(Array{T}(undef, sz))
Base.IndexStyle(::Type{<:$MyArray}) = IndexLinear()
Base.getindex(A::$MyArray, i::Int) = A.A[i]
Base.setindex!(A::$MyArray, val, i::Int) = A.A[i] = val
Base.size(A::$MyArray) = Base.size(A.A)
Base.BroadcastStyle(::Type{<:$MyArray}) = Broadcast.ArrayStyle{$MyArray}()
end
end
MyArray{T}(::UndefInitializer, sz::Dims) where T = MyArray(Array{T}(undef, sz))
Base.IndexStyle(::Type{<:MyArray}) = IndexLinear()
Base.getindex(A::MyArray, i::Int) = A.A[i]
Base.setindex!(A::MyArray, val, i::Int) = A.A[i] = val
Base.size(A::MyArray) = Base.size(A.A)
Base.BroadcastStyle(::Type{<:MyArray}) = Broadcast.ArrayStyle{MyArray}()
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{ElType}) where ElType =
MyArray{ElType}(undef, size(bc))
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray1}}, ::Type{ElType}) where ElType =
MyArray1{ElType}(undef, size(bc))
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray2}}, ::Type{ElType}) where ElType =
MyArray2{ElType}(undef, size(bc))
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyArray3}) = Broadcast.ArrayStyle{MyArray1}()
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayStyle) = S

@testset "broadcast" begin
s = StructArray{ComplexF64}((rand(2,2), rand(2,2)))
Expand All @@ -1128,19 +1152,44 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El
# used inside of broadcast but we also test it here explicitly
@test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N})

s = StructArray{ComplexF64}((MyArray(rand(2)), MyArray(rand(2))))
@test_throws MethodError s .+ s

@testset "style conflict check" begin
using StructArrays: StructArrayStyle
# Make sure we can handle style with similar defined
# And we can handle most conflicts
# `s1` and `s2` have similar defined, but `s3` does not
# `s2` conflicts with `s1` and `s3` and is weaker than `DefaultArrayStyle`
s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2))))
s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2))))
s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2))))
s4 = StructArray{ComplexF64}((rand(2), rand(2)))
test_set = Any[s1, s2, s3, s4]
tested_style = Any[]
dotaddadd((a, b, c),) = @. a + b + c
for is in Iterators.product(randperm(4), randperm(4), randperm(4))
as = map(i -> test_set[i], is)
ares = map(a->a.re, as)
aims = map(a->a.im, as)
style = Broadcast.combine_styles(ares...)
if !(style in tested_style)
push!(tested_style, style)
if style isa Broadcast.ArrayStyle{MyArray3}
@test_throws MethodError dotaddadd(as)
else
d = StructArray{ComplexF64}((dotaddadd(ares), dotaddadd(aims)))
@test @inferred(dotaddadd(as))::typeof(d) == d
end
end
end
@test length(tested_style) == 5
end
# test for dimensionality track
s = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2))))
@test Base.broadcasted(+, s, s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
@test Base.broadcasted(+, s, 1:2) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
@test Base.broadcasted(+, s, reshape(1:2,1,2)) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{2}}
@test Base.broadcasted(+, reshape(1:2,1,1,2), s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}}

a = StructArray([1;2+im])
b = StructArray([1;;2+im])
@test a .+ b == a .+ collect(b) == collect(a) .+ b == collect(a) .+ collect(b)
@test a .+ Any[1] isa StructArray
@test Base.broadcasted(+, s, MyArray1(rand(2))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}}

# issue #185
A = StructArray(randn(ComplexF64, 3, 3))
Expand All @@ -1155,6 +1204,56 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El

@test identity.(StructArray(x=StructArray(a=1:3)))::StructArray == [(x=(a=1,),), (x=(a=2,),), (x=(a=3,),)]
@test (x -> x.x.a).(StructArray(x=StructArray(a=1:3))) == [1, 2, 3]
@test identity.(StructArray(x=StructArray(x=StructArray(a=1:3))))::StructArray == [(x=(x=(a=1,),),), (x=(x=(a=2,),),), (x=(x=(a=3,),),)]
@test (x -> x.x.x.a).(StructArray(x=StructArray(x=StructArray(a=1:3)))) == [1, 2, 3]

@testset "ambiguity check" begin
test_set = Any[StructArray([1;2+im]),
1:2,
(1,2),
StructArray(@SArray [1;1+2im]),
(@SArray [1 2]),
1]
tested_style = StructArrayStyle[]
dotaddsub((a, b, c),) = @. a + b - c
for is in Iterators.product(randperm(6), randperm(6), randperm(6))
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
as = map(i -> test_set[i], is)
if any(a -> a isa StructArray, as)
style = Broadcast.combine_styles(as...)
if !(style in tested_style)
push!(tested_style, style)
@test @inferred(dotaddsub(as))::StructArray == dotaddsub(map(collect, as))
end
end
end
@test length(tested_style) == 4
end

@testset "StructStaticArray" begin
bclog(s) = log.(s)
test_allocated(f, s) = @test (@allocated f(s)) == 0
a = @SMatrix [float(i) for i in 1:10, j in 1:10]
b = @SMatrix [0. for i in 1:10, j in 1:10]
s = StructArray{ComplexF64}((a , b))
@test (@inferred bclog(s)) isa typeof(s)
test_allocated(bclog, s)
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
@test abs.(s) .+ ((1,) .+ (1,2,3,4,5,6,7,8,9,10)) isa SMatrix
bc = Base.broadcasted(+, s, s);
bc = Base.broadcasted(+, bc, bc, s);
@test @inferred(Broadcast.axes(bc)) === axes(s)
end

@testset "StructJLArray" begin
bcabs(a) = abs.(a)
bcmul2(a) = 2 .* a
a = StructArray(randn(ComplexF32, 10, 10))
sa = jl(a)
backend = StructArrays.GPUArraysCore.backend
@test @inferred(backend(sa)) === backend(sa.re) === backend(sa.im)
@test collect(@inferred(bcabs(sa))) == bcabs(a)
@test @inferred(bcmul2(sa)) isa StructArray
@test (sa .+= 1) isa StructArray
end
end

@testset "map" begin
Expand Down