diff --git a/src/structarray.jl b/src/structarray.jl index 28abdafe..c1055331 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -443,13 +443,31 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T end # broadcast -import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle +import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted struct StructArrayStyle{S,N} <: AbstractArrayStyle{N} end -# If `S` also track input's dimensionality, we'd better also update it. -StructArrayStyle{S,M}(::Val{N}) where {M,S<:AbstractArrayStyle{M},N} = - StructArrayStyle{typeof(S(Val(N))),N}() -StructArrayStyle{S,M}(::Val{N}) where {M,S,N} = StructArrayStyle{S,N}() + +# Here we define the dimension tracking behaviour of StructArrayStyle +function StructArrayStyle{S,M}(::Val{N}) where {S,M,N} + if S <: AbstractArrayStyle{M} + return StructArrayStyle{typeof(S(Val(N))),N}() + end + return StructArrayStyle{S,N}() +end + +_dimmax(a::Integer, b::Integer) = max(a, b) +_dimmax(::Type{Any}, ::Integer) = Any +_dimmax(::Integer ,::Type{Any}) = Any + +# StructArrayStyle is a wrapped style. +# Here we try our best to resolve style conflict. +function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S,N}) where {S,N,M} + S′ = Broadcast.result_style(S(), b) + if S′ isa StructArrayStyle # avoid double wrap + return typeof(S′)(Val(_dimmax(N,M))) + end + StructArrayStyle{typeof(S′),_dimmax(N,M)}() +end @inline combine_style_types(::Type{A}, args...) where A<:AbstractArray = combine_style_types(BroadcastStyle(A), args...) @@ -461,8 +479,19 @@ Base.@pure cst(::Type{SA}) where SA = combine_style_types(array_types(SA).parame BroadcastStyle(::Type{SA}) where SA<:StructArray = StructArrayStyle{typeof(cst(SA)),ndims(SA)}() -Base.similar(bc::Broadcasted{<:StructArrayStyle{S}}, ::Type{ElType}) where {S<:DefaultArrayStyle,N,ElType} = - isstructtype(ElType) ? similar(StructArray{ElType}, axes(bc)) : similar(Array{ElType}, axes(bc)) +# Here we use `similar` defined for `S` to build the dest Array. +function Base.similar(bc::Broadcasted{<:StructArrayStyle{S}}, ::Type{ElType}) where {S,ElType} + bc′ = convert(Broadcasted{S}, bc) + if isstructtype(ElType) + return buildfromschema(T -> similar(bc′, T), ElType) + end + return similar(bc′, ElType) +end + +# Unwrapper the style to recover the behaviour defined by style. +function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:StructArrayStyle{S}}) where {S} + return copyto!(dest, convert(Broadcasted{S}, bc)) +end # for aliasing analysis during broadcast Base.dataids(u::StructArray) = mapreduce(Base.dataids, (a, b) -> (a..., b...), values(components(u)), init=()) diff --git a/test/runtests.jl b/test/runtests.jl index 20412dc5..a4f19ce4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -898,17 +898,25 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs) @test t.b.d isa Array end -struct MyArray{T,N} <: AbstractArray{T,N} - A::Array{T,N} -end -MyArray{T}(::UndefInitializer, sz::Dims) where T = MyArray(Array{T}(undef, sz)) -Base.IndexStyle(::Type{<:MyArray}) = IndexLinear() -Base.getindex(A::MyArray, i::Int) = A.A[i] -Base.setindex!(A::MyArray, val, i::Int) = A.A[i] = val -Base.size(A::MyArray) = Base.size(A.A) -Base.BroadcastStyle(::Type{<:MyArray}) = Broadcast.ArrayStyle{MyArray}() -Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{ElType}) where ElType = - MyArray{ElType}(undef, size(bc)) +for S in (1, 2, 3) + MyArray = Symbol(:MyArray, S) + @eval begin + struct $MyArray{T,N} <: AbstractArray{T,N} + A::Array{T,N} + end + $MyArray{T}(::UndefInitializer, sz::Dims) where T = $MyArray(Array{T}(undef, sz)) + Base.IndexStyle(::Type{<:$MyArray}) = IndexLinear() + Base.getindex(A::$MyArray, i::Int) = A.A[i] + Base.setindex!(A::$MyArray, val, i::Int) = A.A[i] = val + Base.size(A::$MyArray) = Base.size(A.A) + Base.BroadcastStyle(::Type{<:$MyArray}) = Broadcast.ArrayStyle{$MyArray}() + end +end +Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray1}}, ::Type{ElType}) where ElType = + MyArray1{ElType}(undef, size(bc)) +Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray2}}, ::Type{ElType}) where ElType = + MyArray2{ElType}(undef, size(bc)) +Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyArray3}) = Broadcast.ArrayStyle{MyArray1}() @testset "broadcast" begin s = StructArray{ComplexF64}((rand(2,2), rand(2,2))) @@ -926,24 +934,59 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El # used inside of broadcast but we also test it here explicitly @test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N}) - s = StructArray{ComplexF64}((MyArray(rand(2)), MyArray(rand(2)))) - @test_throws MethodError s .+ s + # Make sure we can handle style with similar defined + # s1 and s2 has similar defined, but s3 not + # s2 are conflict with s1 and s3. + s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2)))) + s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2)))) + s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2)))) + + function _test_similar(a, b) + flag = false + try + c = StructArray{ComplexF64}((a.re .+ b.re, a.im .+ b.im)) + flag = true + catch + end + if flag + @test typeof(@inferred(a .+ b)) == typeof(c) + else + @test_throws MethodError a .+ b + end + end + for s in (s1,s2,s3), s′ in (s1,s2,s3) + _test_similar(s, s′) + end # test for dimensionality track + s = s1 @test Base.broadcasted(+, s, s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}} - @test Base.broadcasted(+, s, [1,2]) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}} + @test Base.broadcasted(+, s, [1,2]) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}} @test Base.broadcasted(+, s, [1;;2]) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{2}} @test Base.broadcasted(+, [1;;;2], s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}} - - a = StructArray([1;2+im]) - b = StructArray([1;;2+im]) - @test a .+ b == a .+ collect(b) == collect(a) .+ b == collect(a) .+ collect(b) + @test Base.broadcasted(+, s, MyArray1(rand(2))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}} # issue #185 A = StructArray(randn(ComplexF64, 3, 3)) B = randn(ComplexF64, 3, 3) c = StructArray(randn(ComplexF64, 3)) @test (A .= B .* c) === A + + # ambiguity check (can we do this better?) + function _test(a, b) + if a isa StructArray || b isa StructArray + d = @inferred a .+ b + @test d == collect(a) .+ collect(b) + @test d isa StructArray + end + end + testset = StructArray([1;2+im]), StructArray([1 2+im]), 1:2, (1,2), (@SArray [1 2]) + for aa in testset, bb in testset + _test(aa, bb) + end + a = StructArray([1;2+im]) + b = StructArray([1 2+im]) + @test @inferred(a .+ b .+ a .* a' .+ (1,2) .+ (1:2) .- b') isa StructArray end @testset "staticarrays" begin