diff --git a/src/staticarrays_support.jl b/src/staticarrays_support.jl index 0c1357db..4012c54c 100644 --- a/src/staticarrays_support.jl +++ b/src/staticarrays_support.jl @@ -1,4 +1,6 @@ -import StaticArrays: StaticArray, FieldArray, tuple_prod, StaticArrayStyle +using StaticArrays: StaticArrays, StaticArray, FieldArray, tuple_prod, StaticArrayStyle +import StaticArrays: Size +import Base.Broadcast: instantiate """ StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T} @@ -28,10 +30,24 @@ end StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i) StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(createinstance, Tuple{Type{<:Any}, Vararg}, T, args...) +@static if isdefined(StaticArrays, :static_combine_axes) # StaticArrayStyle has no similar defined. -# Convert to `DefaultArrayStyle` to return a sized (Struct)Array. -# TODO: return a StaticArray? -function Base.copy(bc::Broadcasted{StructArrayStyle{StaticArrayStyle{N},N}}) where {N} - bc′ = convert(Broadcasted{StructArrayStyle{Broadcast.DefaultArrayStyle{N},N}}, bc) +# Convert to `StaticArrayStyle` to return a StaticArray instead. +StructStaticArrayStyle{N} = StructArrayStyle{StaticArrayStyle{N}, N} +@inline function Base.copy(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M} + bc′ = convert(Broadcasted{StaticArrayStyle{M}}, bc) return copy(bc′) end +function instantiate(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M} + bc′ = instantiate(convert(Broadcasted{StaticArrayStyle{M}}, bc)) + return convert(Broadcasted{StructStaticArrayStyle{M}}, bc′) +end +function Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing) + return StaticArrays.static_combine_axes(bc.args...) +end +Size(::Type{SA}) where {SA<:StructArray} = Size(fieldtype(fieldtype(SA, 1), 1)) +StaticArrays.isstatic(::SA) where {SA<:StructArray} = cst(SA) isa StaticArrayStyle +function StaticArrays.similar_type(::Type{SA}, ::Type{T}, s::Size{S}) where {SA<:StructArray, T, S} + return StaticArrays.similar_type(fieldtype(fieldtype(SA, 1), 1), T, s) +end +end diff --git a/test/runtests.jl b/test/runtests.jl index 13ad213e..04165ce3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -977,26 +977,35 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS @test @inferred(broadcast(el -> el.a, v)) == ["s1", "s2"] # ambiguity check (can we do this better?) - function _test(a, b, c) + function _test(a, b, c, T = StructArray) if a isa StructArray || b isa StructArray || c isa StructArray d = @inferred a .+ b .- c @test d == collect(a) .+ collect(b) .- collect(c) - @test d isa StructArray + @test d isa T end end - testset = (StructArray([1;2+im]), + testset = Any[StructArray([1;2+im]), StructArray([1 2+im]), 1:2, (1,2), - (@SArray [1 2]), - StructArray(@SArray [1 1+2im])) + (@SArray [1 2])] for aa in testset, bb in testset, cc in testset _test(aa, bb, cc) end + if isdefined(StaticArrays, :static_combine_axes) + testset = Any[StructArray(@SArray [1 1+2im]), (1,2), StructArray(@SArray [1;1+2im])] + for aa in testset, bb in testset, cc in testset + _test(aa, bb, cc, StaticArray) + end + end +end - a = @SArray randn(3,3); - b = StructArray{ComplexF64}((a,a)) - @test a[:,1] .+ b isa StructArray && (a[:,1] .+ b).re isa SizedMatrix +function struct_static_allocated_test() + s = StructArray{ComplexF64}((SVector(1., 2., 3.), SVector(0., 0., 0.))) + return broadcast(log, s) +end +if isdefined(StaticArrays, :static_combine_axes) + @test (@allocated struct_static_allocated_test()) === 0 end @testset "staticarrays" begin