Skip to content

Commit

Permalink
Convert to StaticArrayStyle
Browse files Browse the repository at this point in the history
We first call broadcast from `StaticArrays` then split the output.
This should has no extra runtime overhead. But some type info might missing because the eltype change. I think there's no better ways as we don't want to depend on the full  `StaticArrays`.

We don't overloading `Size` and `similar_type` at present.
as they are only used for `broadcast`.
With this, we can move much less code to `StaticArraysCore`.

The only downside is that SizedArray would be allocated twice. That's not idea, but we can't do any better if we don't depend on StaticArray or copy a lot of code from there.
  • Loading branch information
N5N3 committed Aug 24, 2022
1 parent b99776e commit 5db3da5
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 3 deletions.
4 changes: 3 additions & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,6 @@ function createinstance(::Type{T}, args...) where {T}
isconcretetype(T) ? bypass_constructor(T, args) : T(args...)
end

createinstance(::Type{T}, args...) where {T<:Tup} = T(args)
createinstance(::Type{T}, args...) where {T<:Tup} = T(args)

createinstance(::Type{T}) where {T} = (x...) -> createinstance(T, x...)
49 changes: 48 additions & 1 deletion src/staticarrays_support.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import StaticArraysCore: StaticArray, FieldArray, tuple_prod
using StaticArraysCore: StaticArray, FieldArray, tuple_prod

"""
StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
Expand Down Expand Up @@ -27,3 +27,50 @@ StructArrays.component(s::StaticArray, i) = getindex(s, i)
end
StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i)
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)

# Broadcast overload
using StaticArraysCore: StaticArrayStyle, similar_type
StructStaticArrayStyle{N} = StructArrayStyle{StaticArrayStyle{N}, N}
function Broadcast.instantiate(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M}
bc′ = Broadcast.instantiate(replace_structarray(bc))
return convert(Broadcasted{StructStaticArrayStyle{M}}, bc′)
end
# This looks costy, but compiler should be able to optimize them away
Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing) = axes(replace_structarray(bc))

to_staticstyle(@nospecialize(x::Type)) = x
to_staticstyle(::Type{StructStaticArrayStyle{N}}) where {N} = StaticArrayStyle{N}
function replace_structarray(bc::Broadcasted{Style}) where {Style}
args = replace_structarray_args(bc.args)
return Broadcasted{to_staticstyle(Style)}(bc.f, args, nothing)
end
function replace_structarray(A::StructArray)
f = createinstance(eltype(A))
args = Tuple(components(A))
return Broadcasted{StaticArrayStyle{ndims(A)}}(f, args, nothing)
end
replace_structarray(@nospecialize(A)) = A

replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(Base.tail(args))...)
replace_structarray_args(::Tuple{}) = ()

# StaticArrayStyle has no similar defined.
# Overload `Base.copy` instead.
@inline function Base.copy(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M}
sa = copy(convert(Broadcasted{StaticArrayStyle{M}}, bc))
ET = eltype(sa)
isnonemptystructtype(ET) || return sa
elements = Tuple(sa)
arrs = ntuple(Val(fieldcount(ET))) do i
similar_type(sa, fieldtype(ET, i))(_getfields(elements, i))
end
return StructArray{ET}(arrs)
end

@inline function _getfields(x::Tuple, i::Int)
if @generated
return Expr(:tuple, (:(getfield(x[$j], i)) for j in 1:fieldcount(x))...)
else
return map(Base.Fix2(getfield, i), x)
end
end
14 changes: 13 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1160,12 +1160,24 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
end
testset = Any[StructArray([1;2+im]),
1:2,
(1,2),
(1,2),
StructArray(@SArray [1 1+2im]),
(@SArray [1 2])
]
for aa in testset, bb in testset, cc in testset
_test(aa, bb, cc)
end
end

@testset "StructStaticArray" begin
bclog(s) = log.(s)
test_allocated(f, s) = @test (@allocated f(s)) == 0
a = @SMatrix [float(i) for i in 1:10, j in 1:10]
b = @SMatrix [0. for i in 1:10, j in 1:10]
s = StructArray{ComplexF64}((a , b))
@test (@inferred bclog(s)) isa typeof(s)
test_allocated(bclog, s)
end
end

@testset "map" begin
Expand Down

0 comments on commit 5db3da5

Please sign in to comment.