Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize StructArray's broadcast. #215

Merged
merged 9 commits into from
Nov 30, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,22 @@ version = "0.6.13"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
Adapt = "1, 2, 3"
DataAPI = "1"
StaticArraysCore = "1.1"
StaticArrays = "1.5.4"
StaticArraysCore = "1.3"
StaticArrays = "1.5.6"
GPUArraysCore = "0.1.2"
Tables = "1"
julia = "1.6"

[extras]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand All @@ -26,4 +29,4 @@ TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"

[targets]
test = ["Test", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter"]
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter"]
10 changes: 10 additions & 0 deletions src/StructArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,14 @@ end
import Adapt
Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s)

# for GPU broadcast
import GPUArraysCore
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
backends = map_params(GPUArraysCore.backend, array_types(T))
backend, others = backends[1], tail(backends)
isconsistent = mapfoldl(isequal(backend), &, others; init=true)
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend"))
return backend
end

end # module
8 changes: 7 additions & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,10 @@ function createinstance(::Type{T}, args...) where {T}
isconcretetype(T) ? bypass_constructor(T, args) : T(args...)
end

createinstance(::Type{T}, args...) where {T<:Tup} = T(args)
createinstance(::Type{T}, args...) where {T<:Tup} = T(args)

struct Instantiator{T} end

Instantiator(::Type{T}) where {T} = Instantiator{T}()

(::Instantiator{T})(args...) where {T} = createinstance(T, args...)
66 changes: 65 additions & 1 deletion src/staticarrays_support.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import StaticArraysCore: StaticArray, FieldArray, tuple_prod
using StaticArraysCore: StaticArray, FieldArray, tuple_prod

"""
StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
Expand Down Expand Up @@ -27,3 +27,67 @@ StructArrays.component(s::StaticArray, i) = getindex(s, i)
end
StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i)
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)

# Broadcast overload
using StaticArraysCore: StaticArrayStyle, similar_type
StructStaticArrayStyle{N} = StructArrayStyle{StaticArrayStyle{N}, N}
function Broadcast.instantiate(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M}
bc′ = Broadcast.instantiate(replace_structarray(bc))
return convert(Broadcasted{StructStaticArrayStyle{M}}, bc′)
end
# This looks costly, but the compiler should be able to optimize them away
Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing) = axes(replace_structarray(bc))

to_staticstyle(@nospecialize(x::Type)) = x
to_staticstyle(::Type{StructStaticArrayStyle{N}}) where {N} = StaticArrayStyle{N}

"""
replace_structarray(bc::Broadcasted)

An internal function transforms the `Broadcasted` with `StructArray` into
an equivalent one without it. This is not a must if the root `BroadcastStyle`
supports `AbstractArray`. But some `BroadcastStyle` limits the input array types,
e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`.
"""
function replace_structarray(bc::Broadcasted{Style}) where {Style}
args = replace_structarray_args(bc.args)
return Broadcasted{to_staticstyle(Style)}(bc.f, args, nothing)
end
function replace_structarray(A::StructArray)
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
f = Instantiator(eltype(A))
args = Tuple(components(A))
return Broadcasted{StaticArrayStyle{ndims(A)}}(f, args, nothing)
end
replace_structarray(@nospecialize(A)) = A

replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(tail(args))...)
replace_structarray_args(::Tuple{}) = ()

# StaticArrayStyle has no similar defined.
# Overload `Base.copy` instead.
@inline function Base.copy(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M}
sa = copy(convert(Broadcasted{StaticArrayStyle{M}}, bc))
ET = eltype(sa)
isnonemptystructtype(ET) || return sa
elements = Tuple(sa)
@static if VERSION >= v"1.7"
arrs = ntuple(Val(fieldcount(ET))) do i
similar_type(sa, fieldtype(ET, i))(_getfields(elements, i))
end
else
_fieldtype(::Type{T}) where {T} = i -> fieldtype(T, i)
__fieldtype = _fieldtype(ET)
arrs = ntuple(Val(fieldcount(ET))) do i
similar_type(sa, __fieldtype(i))(_getfields(elements, i))
end
end
return StructArray{ET}(arrs)
end

@inline function _getfields(x::Tuple, i::Int)
if @generated
return Expr(:tuple, (:(getfield(x[$j], i)) for j in 1:fieldcount(x))...)
else
return map(Base.Fix2(getfield, i), x)
end
end
28 changes: 24 additions & 4 deletions src/structarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T
end

# broadcast
import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown

struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end

Expand All @@ -496,19 +496,39 @@ function StructArrayStyle{S, M}(::Val{N}) where {S, M, N}
return StructArrayStyle{T, N}()
end

# StructArrayStyle is a wrapped style.
# Here we try our best to resolve style conflict.
function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S, N}) where {S, N, M}
N′ = M === Any || N === Any ? Any : max(M, N)
S′ = Broadcast.result_style(S(), b)
return S′ isa StructArrayStyle ? typeof(S′)(Val{N′}()) : StructArrayStyle{typeof(S′), N′}()
end
BroadcastStyle(::StructArrayStyle, ::DefaultArrayStyle) = Unknown()
N5N3 marked this conversation as resolved.
Show resolved Hide resolved

@inline combine_style_types(::Type{A}, args...) where {A<:AbstractArray} =
combine_style_types(BroadcastStyle(A), args...)
@inline combine_style_types(s::BroadcastStyle, ::Type{A}, args...) where {A<:AbstractArray} =
combine_style_types(Broadcast.result_style(s, BroadcastStyle(A)), args...)
combine_style_types(::StructArrayStyle{S}) where {S} = S() # avoid nested StructArrayStyle
combine_style_types(s::BroadcastStyle) = s

Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).parameters...)

BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{typeof(cst(SA)), ndims(SA)}()

function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S<:Union{DefaultArrayStyle,StructArrayStyle}, N, ElType}
ContainerType = isnonemptystructtype(ElType) ? StructArray{ElType} : Array{ElType}
return similar(ContainerType, axes(bc))
# Here we use `similar` defined for `S` to build the dest Array.
function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S, N, ElType}
bc′ = convert(Broadcasted{S}, bc)
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
return isnonemptystructtype(ElType) ? buildfromschema(T -> similar(bc′, T), ElType) : similar(bc′, ElType)
end

# Unwrapper to recover the behaviour defined by parent style.
@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
return copyto!(dest, convert(Broadcasted{S}, bc))
end

@inline function Broadcast.materialize!(::StructArrayStyle{S}, dest, bc::Broadcasted) where {S}
return Broadcast.materialize!(S(), dest, bc)
end

# for aliasing analysis during broadcast
Expand Down
119 changes: 102 additions & 17 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import Tables, PooledArrays, WeakRefStrings
using TypedTables: Table
using DataAPI: refarray, refvalue
using Adapt: adapt, Adapt
using JLArrays
using Test

using Documenter: doctest
Expand Down Expand Up @@ -1100,17 +1101,39 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs)
@test t.b.d isa Array
end

struct MyArray{T,N} <: AbstractArray{T,N}
A::Array{T,N}
# The following code defines `MyArray1/2/3` with different `BroadcastStyle`s.
# 1. `MyArray1` and `MyArray1` have `similar` defined.
# We use them to simulate `BroadcastStyle` overloading `Base.copyto!`.
# 2. `MyArray3` has no `similar` defined.
# We use it to simulate `BroadcastStyle` overloading `Base.copy`.
# 3. Their resolved style could be summaryized as (`-` means conflict)
# | MyArray1 | MyArray2 | MyArray3 | Array
# -------------------------------------------------------------
# MyArray1 | MyArray1 | - | MyArray1 | MyArray1
# MyArray2 | - | MyArray2 | - | MyArray2
# MyArray3 | MyArray1 | - | MyArray3 | MyArray3
# Array | MyArray1 | Array | MyArray3 | Array

for S in (1, 2, 3)
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
MyArray = Symbol(:MyArray, S)
@eval begin
struct $MyArray{T,N} <: AbstractArray{T,N}
A::Array{T,N}
end
$MyArray{T}(::UndefInitializer, sz::Dims) where T = $MyArray(Array{T}(undef, sz))
Base.IndexStyle(::Type{<:$MyArray}) = IndexLinear()
Base.getindex(A::$MyArray, i::Int) = A.A[i]
Base.setindex!(A::$MyArray, val, i::Int) = A.A[i] = val
Base.size(A::$MyArray) = Base.size(A.A)
Base.BroadcastStyle(::Type{<:$MyArray}) = Broadcast.ArrayStyle{$MyArray}()
end
end
MyArray{T}(::UndefInitializer, sz::Dims) where T = MyArray(Array{T}(undef, sz))
Base.IndexStyle(::Type{<:MyArray}) = IndexLinear()
Base.getindex(A::MyArray, i::Int) = A.A[i]
Base.setindex!(A::MyArray, val, i::Int) = A.A[i] = val
Base.size(A::MyArray) = Base.size(A.A)
Base.BroadcastStyle(::Type{<:MyArray}) = Broadcast.ArrayStyle{MyArray}()
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{ElType}) where ElType =
MyArray{ElType}(undef, size(bc))
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray1}}, ::Type{ElType}) where ElType =
MyArray1{ElType}(undef, size(bc))
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray2}}, ::Type{ElType}) where ElType =
MyArray2{ElType}(undef, size(bc))
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyArray3}) = Broadcast.ArrayStyle{MyArray1}()
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayStyle) = S

@testset "broadcast" begin
s = StructArray{ComplexF64}((rand(2,2), rand(2,2)))
Expand All @@ -1128,19 +1151,34 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El
# used inside of broadcast but we also test it here explicitly
@test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N})

s = StructArray{ComplexF64}((MyArray(rand(2)), MyArray(rand(2))))
@test_throws MethodError s .+ s
# Make sure we can handle style with similar defined
# And we can handle most conflicts
# `s1` and `s2` have similar defined, but `s3` does not
# `s2` conflicts with `s1` and `s3` and is weaker than `DefaultArrayStyle`
s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2))))
s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2))))
s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2))))
s4 = StructArray{ComplexF64}((rand(2), rand(2)))

function _test_similar(a, b, c)
try
d = StructArray{ComplexF64}((a.re .+ b.re .- c.re, a.im .+ b.im .- c.im))
@test typeof(a .+ b .- c) == typeof(d)
catch
@test_throws MethodError a .+ b .- c
end
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This had escaped me before, but I'm wondering: could it be possible to be explicit here on which it is (correct result or method error) based on the input types?

Ideally one would want to explicitly test what one is getting, so I would suggest to remove the helper function _test_similar and just write something like

if s2 in (s, s′, s″) && (s1 in (s, s′, s″) || s3 in (s, s′, s″))
    # test method error
else
    # test correct result
end

in the loop body.

(I'm not sure whether that's the correct criterion.)

end
for s in (s1,s2,s3,s4), s′ in (s1,s2,s3,s4), s″ in (s1,s2,s3,s4)
_test_similar(s, s′, s″)
end

# test for dimensionality track
s = s1
@test Base.broadcasted(+, s, s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
@test Base.broadcasted(+, s, 1:2) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
@test Base.broadcasted(+, s, reshape(1:2,1,2)) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{2}}
@test Base.broadcasted(+, reshape(1:2,1,1,2), s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}}

a = StructArray([1;2+im])
b = StructArray([1;;2+im])
@test a .+ b == a .+ collect(b) == collect(a) .+ b == collect(a) .+ collect(b)
@test a .+ Any[1] isa StructArray
@test Base.broadcasted(+, s, MyArray1(rand(2))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}}

# issue #185
A = StructArray(randn(ComplexF64, 3, 3))
Expand All @@ -1155,6 +1193,53 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El

@test identity.(StructArray(x=StructArray(a=1:3)))::StructArray == [(x=(a=1,),), (x=(a=2,),), (x=(a=3,),)]
@test (x -> x.x.a).(StructArray(x=StructArray(a=1:3))) == [1, 2, 3]
@test identity.(StructArray(x=StructArray(x=StructArray(a=1:3))))::StructArray == [(x=(x=(a=1,),),), (x=(x=(a=2,),),), (x=(x=(a=3,),),)]
@test (x -> x.x.x.a).(StructArray(x=StructArray(x=StructArray(a=1:3)))) == [1, 2, 3]

@testset "ambiguity check" begin
function _test(a, b, c)
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
if a isa StructArray || b isa StructArray || c isa StructArray
d = @inferred a .+ b .- c
@test d == collect(a) .+ collect(b) .- collect(c)
@test d isa StructArray
end
end
testset = Any[StructArray([1;2+im]),
1:2,
(1,2),
StructArray(@SArray [1 1+2im]),
(@SArray [1 2])
]
for aa in testset, bb in testset, cc in testset
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
_test(aa, bb, cc)
end
end

@testset "StructStaticArray" begin
bclog(s) = log.(s)
test_allocated(f, s) = @test (@allocated f(s)) == 0
a = @SMatrix [float(i) for i in 1:10, j in 1:10]
b = @SMatrix [0. for i in 1:10, j in 1:10]
s = StructArray{ComplexF64}((a , b))
@test (@inferred bclog(s)) isa typeof(s)
test_allocated(bclog, s)
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
@test abs.(s) .+ ((1,) .+ (1,2,3,4,5,6,7,8,9,10)) isa SMatrix
bc = Base.broadcasted(+, s, s);
bc = Base.broadcasted(+, bc, bc, s);
@test @inferred(Broadcast.axes(bc)) === axes(s)
end

@testset "StructJLArray" begin
bcabs(a) = abs.(a)
bcmul2(a) = 2 .* a
a = StructArray(randn(ComplexF32, 10, 10))
sa = jl(a)
backend = StructArrays.GPUArraysCore.backend
@test @inferred(backend(sa)) === backend(sa.re) === backend(sa.im)
@test collect(@inferred(bcabs(sa))) == bcabs(a)
@test @inferred(bcmul2(sa)) isa StructArray
@test (sa .+= 1) isa StructArray
end
end

@testset "map" begin
Expand Down