Skip to content

Commit

Permalink
Try to resolve style conflict and extend similar
Browse files Browse the repository at this point in the history
  • Loading branch information
N5N3 committed Feb 10, 2022
1 parent 9b9d8b2 commit c7a4dc8
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 25 deletions.
43 changes: 36 additions & 7 deletions src/structarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -443,13 +443,31 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T
end

# broadcast
import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted

struct StructArrayStyle{S,N} <: AbstractArrayStyle{N} end
# If `S` also track input's dimensionality, we'd better also update it.
StructArrayStyle{S,M}(::Val{N}) where {M,S<:AbstractArrayStyle{M},N} =
StructArrayStyle{typeof(S(Val(N))),N}()
StructArrayStyle{S,M}(::Val{N}) where {M,S,N} = StructArrayStyle{S,N}()

# Here we define the dimension tracking behaviour of StructArrayStyle
function StructArrayStyle{S,M}(::Val{N}) where {S,M,N}
if S <: AbstractArrayStyle{M}
return StructArrayStyle{typeof(S(Val(N))),N}()
end
return StructArrayStyle{S,N}()
end

_dimmax(a::Integer, b::Integer) = max(a, b)
_dimmax(::Type{Any}, ::Integer) = Any
_dimmax(::Integer ,::Type{Any}) = Any

# StructArrayStyle is a wrapped style.
# Here we try our best to resolve style conflict.
function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S,N}) where {S,N,M}
S′ = Broadcast.result_style(S(), b)
if S′ isa StructArrayStyle # avoid double wrap
return typeof(S′)(Val(_dimmax(N,M)))
end
StructArrayStyle{typeof(S′),_dimmax(N,M)}()
end

@inline combine_style_types(::Type{A}, args...) where A<:AbstractArray =
combine_style_types(BroadcastStyle(A), args...)
Expand All @@ -461,8 +479,19 @@ Base.@pure cst(::Type{SA}) where SA = combine_style_types(array_types(SA).parame

BroadcastStyle(::Type{SA}) where SA<:StructArray = StructArrayStyle{typeof(cst(SA)),ndims(SA)}()

Base.similar(bc::Broadcasted{<:StructArrayStyle{S}}, ::Type{ElType}) where {S<:DefaultArrayStyle,N,ElType} =
isstructtype(ElType) ? similar(StructArray{ElType}, axes(bc)) : similar(Array{ElType}, axes(bc))
# Here we use `similar` defined for `S` to build the dest Array.
function Base.similar(bc::Broadcasted{<:StructArrayStyle{S}}, ::Type{ElType}) where {S,ElType}
bc′ = convert(Broadcasted{S}, bc)
if isstructtype(ElType)
return buildfromschema(T -> similar(bc′, T), ElType)
end
return similar(bc′, ElType)
end

# Unwrapper the style to recover the behaviour defined by style.
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:StructArrayStyle{S}}) where {S}
return copyto!(dest, convert(Broadcasted{S}, bc))
end

# for aliasing analysis during broadcast
Base.dataids(u::StructArray) = mapreduce(Base.dataids, (a, b) -> (a..., b...), values(components(u)), init=())
79 changes: 61 additions & 18 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -898,17 +898,25 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs)
@test t.b.d isa Array
end

struct MyArray{T,N} <: AbstractArray{T,N}
A::Array{T,N}
end
MyArray{T}(::UndefInitializer, sz::Dims) where T = MyArray(Array{T}(undef, sz))
Base.IndexStyle(::Type{<:MyArray}) = IndexLinear()
Base.getindex(A::MyArray, i::Int) = A.A[i]
Base.setindex!(A::MyArray, val, i::Int) = A.A[i] = val
Base.size(A::MyArray) = Base.size(A.A)
Base.BroadcastStyle(::Type{<:MyArray}) = Broadcast.ArrayStyle{MyArray}()
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{ElType}) where ElType =
MyArray{ElType}(undef, size(bc))
for S in (1, 2, 3)
MyArray = Symbol(:MyArray, S)
@eval begin
struct $MyArray{T,N} <: AbstractArray{T,N}
A::Array{T,N}
end
$MyArray{T}(::UndefInitializer, sz::Dims) where T = $MyArray(Array{T}(undef, sz))
Base.IndexStyle(::Type{<:$MyArray}) = IndexLinear()
Base.getindex(A::$MyArray, i::Int) = A.A[i]
Base.setindex!(A::$MyArray, val, i::Int) = A.A[i] = val
Base.size(A::$MyArray) = Base.size(A.A)
Base.BroadcastStyle(::Type{<:$MyArray}) = Broadcast.ArrayStyle{$MyArray}()
end
end
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray1}}, ::Type{ElType}) where ElType =
MyArray1{ElType}(undef, size(bc))
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray2}}, ::Type{ElType}) where ElType =
MyArray2{ElType}(undef, size(bc))
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyArray3}) = Broadcast.ArrayStyle{MyArray1}()

@testset "broadcast" begin
s = StructArray{ComplexF64}((rand(2,2), rand(2,2)))
Expand All @@ -926,24 +934,59 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El
# used inside of broadcast but we also test it here explicitly
@test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N})

s = StructArray{ComplexF64}((MyArray(rand(2)), MyArray(rand(2))))
@test_throws MethodError s .+ s
# Make sure we can handle style with similar defined
# s1 and s2 has similar defined, but s3 not
# s2 are conflict with s1 and s3.
s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2))))
s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2))))
s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2))))

function _test_similar(a, b)
flag = false
try
c = StructArray{ComplexF64}((a.re .+ b.re, a.im .+ b.im))
flag = true
catch
end
if flag
@test typeof(@inferred(a .+ b)) == typeof(c)
else
@test_throws MethodError a .+ b
end
end
for s in (s1,s2,s3), s′ in (s1,s2,s3)
_test_similar(s, s′)
end

# test for dimensionality track
s = s1
@test Base.broadcasted(+, s, s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
@test Base.broadcasted(+, s, [1,2]) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
@test Base.broadcasted(+, s, [1,2]) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
@test Base.broadcasted(+, s, [1;;2]) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{2}}
@test Base.broadcasted(+, [1;;;2], s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}}

a = StructArray([1;2+im])
b = StructArray([1;;2+im])
@test a .+ b == a .+ collect(b) == collect(a) .+ b == collect(a) .+ collect(b)
@test Base.broadcasted(+, s, MyArray1(rand(2))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}}

# issue #185
A = StructArray(randn(ComplexF64, 3, 3))
B = randn(ComplexF64, 3, 3)
c = StructArray(randn(ComplexF64, 3))
@test (A .= B .* c) === A

# ambiguity check (can we do this better?)
function _test(a, b)
if a isa StructArray || b isa StructArray
d = @inferred a .+ b
@test d == collect(a) .+ collect(b)
@test d isa StructArray
end
end
testset = StructArray([1;2+im]), StructArray([1 2+im]), 1:2, (1,2), (@SArray [1 2])
for aa in testset, bb in testset
_test(aa, bb)
end
a = StructArray([1;2+im])
b = StructArray([1 2+im])
@test @inferred(a .+ b .+ a .* a' .+ (1,2) .+ (1:2) .- b') isa StructArray
end

@testset "staticarrays" begin
Expand Down

0 comments on commit c7a4dc8

Please sign in to comment.