diff --git a/src/interface.jl b/src/interface.jl index 461e2d49..d82aa4f6 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -49,4 +49,6 @@ function createinstance(::Type{T}, args...) where {T} isconcretetype(T) ? bypass_constructor(T, args) : T(args...) end -createinstance(::Type{T}, args...) where {T<:Tup} = T(args) \ No newline at end of file +createinstance(::Type{T}, args...) where {T<:Tup} = T(args) + +createinstance(::Type{T}) where {T} = (x...) -> createinstance(T, x...) diff --git a/src/staticarrays_support.jl b/src/staticarrays_support.jl index 5f6fafd2..a796b578 100644 --- a/src/staticarrays_support.jl +++ b/src/staticarrays_support.jl @@ -29,16 +29,30 @@ StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{ StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(createinstance, Tuple{Type{<:Any}, Vararg}, T, args...) # Broadcast overload -using StaticArraysCore: StaticArrayStyle -import StaticArraysCore: Size, is_staticarray_like, similar_type +using StaticArraysCore: StaticArrayStyle, similar_type StructStaticArrayStyle{N} = StructArrayStyle{StaticArrayStyle{N}, N} function Broadcast.instantiate(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M} - bc′ = Broadcast.instantiate(convert(Broadcasted{StaticArrayStyle{M}}, bc)) + bc′ = Broadcast.instantiate(replace_structarray(bc)) return convert(Broadcasted{StructStaticArrayStyle{M}}, bc′) end -function Broadcast._axes(bc::Broadcasted{StructStaticArrayStyle{M}}, ::Nothing) where {M} - return Broadcast._axes(convert(Broadcasted{StaticArrayStyle{M}}, bc), nothing) +# This looks costy, but compiler should be able to optimize them away +Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing) = axes(replace_structarray(bc)) + +to_staticstyle(@nospecialize(x::Type)) = x +to_staticstyle(::Type{StructStaticArrayStyle{N}}) where {N} = StaticArrayStyle{N} +function replace_structarray(bc::Broadcasted{Style}) where {Style} + args = replace_structarray_args(bc.args) + return Broadcasted{to_staticstyle(Style)}(bc.f, args, nothing) +end +function replace_structarray(A::StructArray) + f = createinstance(eltype(A)) + args = Tuple(components(A)) + return Broadcasted{StaticArrayStyle{ndims(A)}}(f, args, nothing) end +replace_structarray(@nospecialize(A)) = A + +replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(Base.tail(args))...) +replace_structarray_args(::Tuple{}) = () # StaticArrayStyle has no similar defined. # Overload `Base.copy` instead. @@ -48,7 +62,7 @@ end isnonemptystructtype(ET) || return sa elements = Tuple(sa) arrs = ntuple(Val(fieldcount(ET))) do i - similar_type(sa, fieldtype(ET, i), Size(sa))(_getfields(elements, i)) + similar_type(sa, fieldtype(ET, i))(_getfields(elements, i)) end return StructArray{ET}(arrs) end @@ -60,9 +74,3 @@ end return map(Base.Fix2(getfield, i), x) end end - -Size(::Type{SA}) where {SA<:StructArray} = Size(fieldtype(array_types(SA), 1)) -is_staticarray_like(x::StructArray) = any(is_staticarray_like, components(x)) -function similar_type(::Type{SA}, ::Type{T}, s::Size{S}) where {SA<:StructArray, T, S} - return similar_type(fieldtype(array_types(SA), 1), T, s) -end