diff --git a/src/StructArrays.jl b/src/StructArrays.jl index f7431992..8f1a198c 100644 --- a/src/StructArrays.jl +++ b/src/StructArrays.jl @@ -30,8 +30,8 @@ import Adapt Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s) # for GPU broadcast -import GPUArraysCore: backend -function backend(::Type{T}) where {T<:StructArray} +import GPUArraysCore +function GPUArraysCore.backend(::Type{T}) where {T<:StructArray} backs = map(backend, fieldtypes(array_types(T))) all(Base.Fix2(===, backs[1]), tail(backs)) || error("backend mismatch!") return backs[1] diff --git a/src/interface.jl b/src/interface.jl index d82aa4f6..87057ae5 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -52,3 +52,9 @@ end createinstance(::Type{T}, args...) where {T<:Tup} = T(args) createinstance(::Type{T}) where {T} = (x...) -> createinstance(T, x...) + +struct Instantiator{T} end + +Instantiator(::Type{T}) where {T} = Instantiator{T}() + +(::Instantiator{T})(args...) where {T} = createinstance(T, args...) diff --git a/src/staticarrays_support.jl b/src/staticarrays_support.jl index 09a93a5d..1f898f82 100644 --- a/src/staticarrays_support.jl +++ b/src/staticarrays_support.jl @@ -35,23 +35,32 @@ function Broadcast.instantiate(bc::Broadcasted{StructStaticArrayStyle{M}}) where bc′ = Broadcast.instantiate(replace_structarray(bc)) return convert(Broadcasted{StructStaticArrayStyle{M}}, bc′) end -# This looks costy, but compiler should be able to optimize them away +# This looks costly, but the compiler should be able to optimize them away Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing) = axes(replace_structarray(bc)) to_staticstyle(@nospecialize(x::Type)) = x to_staticstyle(::Type{StructStaticArrayStyle{N}}) where {N} = StaticArrayStyle{N} + +""" + replace_structarray(bc::Broadcasted) + +An internal function transforms the `Broadcasted` with `StructArray` into +an equivalent one without it. This is not a must if the root `BroadcastStyle` +supports `AbstractArray`. But some `BroadcastStyle` limits the input array types, +e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`. +""" function replace_structarray(bc::Broadcasted{Style}) where {Style} args = replace_structarray_args(bc.args) return Broadcasted{to_staticstyle(Style)}(bc.f, args, nothing) end function replace_structarray(A::StructArray) - f = createinstance(eltype(A)) + f = Instantiator(eltype(A)) args = Tuple(components(A)) return Broadcasted{StaticArrayStyle{ndims(A)}}(f, args, nothing) end replace_structarray(@nospecialize(A)) = A -replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(Base.tail(args))...) +replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(tail(args))...) replace_structarray_args(::Tuple{}) = () # StaticArrayStyle has no similar defined. diff --git a/test/runtests.jl b/test/runtests.jl index 04725f40..e7d1a346 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1091,6 +1091,18 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs) @test t.b.d isa Array end +# The following code defines `MyArray1/2/3` with different `BroadcastStyle`s. +# 1. `MyArray1` and `MyArray1` has `similar` defined thus use the default broadcast routine. +# 2. `MyArray3` has no `similar` defined. +# We use it to simulate `BroadcastStyle` overloading `Base.copy` rather than `Base.copyto!` +# 3. Their resolved style could be summaryized as (`-` means conflict) +# | MyArray1 | MyArray2 | MyArray3 | Array +# ------------------------------------------------------------- +# MyArray1 | MyArray1 | - | MyArray1 | MyArray1 +# MyArray2 | - | MyArray2 | - | MyArray2 +# MyArray3 | MyArray1 | - | MyArray3 | MyArray3 +# Array | MyArray1 | Array | MyArray3 | Array + for S in (1, 2, 3) MyArray = Symbol(:MyArray, S) @eval begin @@ -1129,9 +1141,9 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS @test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N}) # Make sure we can handle style with similar defined - # And we can handle most conflict - # s1 and s2 has similar defined, but s3 not - # s2 are conflict with s1 and s3. (And it's weaker than DefaultArrayStyle) + # And we can handle most conflicts + # `s1` and `s2` have similar defined, but `s3` does not + # `s2` conflicts with `s1` and `s3` and is weaker than `DefaultArrayStyle` s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2)))) s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2)))) s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2))))