diff --git a/Project.toml b/Project.toml index d8049eef..0d8880ac 100644 --- a/Project.toml +++ b/Project.toml @@ -5,19 +5,22 @@ version = "0.6.13" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] Adapt = "1, 2, 3" DataAPI = "1" -StaticArraysCore = "1.1" -StaticArrays = "1.5.4" +GPUArraysCore = "0.1.2" +StaticArrays = "1.5.6" +StaticArraysCore = "1.3" Tables = "1" julia = "1.6" [extras] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -26,4 +29,4 @@ TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" [targets] -test = ["Test", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter"] +test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter"] diff --git a/src/StructArrays.jl b/src/StructArrays.jl index 6fed453e..27e234d5 100644 --- a/src/StructArrays.jl +++ b/src/StructArrays.jl @@ -29,4 +29,14 @@ end import Adapt Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s) +# for GPU broadcast +import GPUArraysCore +function GPUArraysCore.backend(::Type{T}) where {T<:StructArray} + backends = map_params(GPUArraysCore.backend, array_types(T)) + backend, others = backends[1], tail(backends) + isconsistent = mapfoldl(isequal(backend), &, others; init=true) + isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend")) + return backend +end + end # module diff --git a/src/interface.jl b/src/interface.jl index 461e2d49..010fc2f9 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -49,4 +49,10 @@ function createinstance(::Type{T}, args...) where {T} isconcretetype(T) ? bypass_constructor(T, args) : T(args...) end -createinstance(::Type{T}, args...) where {T<:Tup} = T(args) \ No newline at end of file +createinstance(::Type{T}, args...) where {T<:Tup} = T(args) + +struct Instantiator{T} end + +Instantiator(::Type{T}) where {T} = Instantiator{T}() + +(::Instantiator{T})(args...) where {T} = createinstance(T, args...) diff --git a/src/staticarrays_support.jl b/src/staticarrays_support.jl index 3fa9af98..1f898f82 100644 --- a/src/staticarrays_support.jl +++ b/src/staticarrays_support.jl @@ -1,4 +1,4 @@ -import StaticArraysCore: StaticArray, FieldArray, tuple_prod +using StaticArraysCore: StaticArray, FieldArray, tuple_prod """ StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T} @@ -27,3 +27,67 @@ StructArrays.component(s::StaticArray, i) = getindex(s, i) end StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i) StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(createinstance, Tuple{Type{<:Any}, Vararg}, T, args...) + +# Broadcast overload +using StaticArraysCore: StaticArrayStyle, similar_type +StructStaticArrayStyle{N} = StructArrayStyle{StaticArrayStyle{N}, N} +function Broadcast.instantiate(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M} + bc′ = Broadcast.instantiate(replace_structarray(bc)) + return convert(Broadcasted{StructStaticArrayStyle{M}}, bc′) +end +# This looks costly, but the compiler should be able to optimize them away +Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing) = axes(replace_structarray(bc)) + +to_staticstyle(@nospecialize(x::Type)) = x +to_staticstyle(::Type{StructStaticArrayStyle{N}}) where {N} = StaticArrayStyle{N} + +""" + replace_structarray(bc::Broadcasted) + +An internal function transforms the `Broadcasted` with `StructArray` into +an equivalent one without it. This is not a must if the root `BroadcastStyle` +supports `AbstractArray`. But some `BroadcastStyle` limits the input array types, +e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`. +""" +function replace_structarray(bc::Broadcasted{Style}) where {Style} + args = replace_structarray_args(bc.args) + return Broadcasted{to_staticstyle(Style)}(bc.f, args, nothing) +end +function replace_structarray(A::StructArray) + f = Instantiator(eltype(A)) + args = Tuple(components(A)) + return Broadcasted{StaticArrayStyle{ndims(A)}}(f, args, nothing) +end +replace_structarray(@nospecialize(A)) = A + +replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(tail(args))...) +replace_structarray_args(::Tuple{}) = () + +# StaticArrayStyle has no similar defined. +# Overload `Base.copy` instead. +@inline function Base.copy(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M} + sa = copy(convert(Broadcasted{StaticArrayStyle{M}}, bc)) + ET = eltype(sa) + isnonemptystructtype(ET) || return sa + elements = Tuple(sa) + @static if VERSION >= v"1.7" + arrs = ntuple(Val(fieldcount(ET))) do i + similar_type(sa, fieldtype(ET, i))(_getfields(elements, i)) + end + else + _fieldtype(::Type{T}) where {T} = i -> fieldtype(T, i) + __fieldtype = _fieldtype(ET) + arrs = ntuple(Val(fieldcount(ET))) do i + similar_type(sa, __fieldtype(i))(_getfields(elements, i)) + end + end + return StructArray{ET}(arrs) +end + +@inline function _getfields(x::Tuple, i::Int) + if @generated + return Expr(:tuple, (:(getfield(x[$j], i)) for j in 1:fieldcount(x))...) + else + return map(Base.Fix2(getfield, i), x) + end +end diff --git a/src/structarray.jl b/src/structarray.jl index d4bf529f..14c22e3f 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -486,7 +486,7 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T end # broadcast -import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle +import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end @@ -496,19 +496,39 @@ function StructArrayStyle{S, M}(::Val{N}) where {S, M, N} return StructArrayStyle{T, N}() end +# StructArrayStyle is a wrapped style. +# Here we try our best to resolve style conflict. +function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S, N}) where {S, N, M} + N′ = M === Any || N === Any ? Any : max(M, N) + S′ = Broadcast.result_style(S(), b) + return S′ isa StructArrayStyle ? typeof(S′)(Val{N′}()) : StructArrayStyle{typeof(S′), N′}() +end +BroadcastStyle(::StructArrayStyle, ::DefaultArrayStyle) = Unknown() + @inline combine_style_types(::Type{A}, args...) where {A<:AbstractArray} = combine_style_types(BroadcastStyle(A), args...) @inline combine_style_types(s::BroadcastStyle, ::Type{A}, args...) where {A<:AbstractArray} = combine_style_types(Broadcast.result_style(s, BroadcastStyle(A)), args...) +combine_style_types(::StructArrayStyle{S}) where {S} = S() # avoid nested StructArrayStyle combine_style_types(s::BroadcastStyle) = s Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).parameters...) BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{typeof(cst(SA)), ndims(SA)}() -function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S<:Union{DefaultArrayStyle,StructArrayStyle}, N, ElType} - ContainerType = isnonemptystructtype(ElType) ? StructArray{ElType} : Array{ElType} - return similar(ContainerType, axes(bc)) +# Here we use `similar` defined for `S` to build the dest Array. +function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S, N, ElType} + bc′ = convert(Broadcasted{S}, bc) + return isnonemptystructtype(ElType) ? buildfromschema(T -> similar(bc′, T), ElType) : similar(bc′, ElType) +end + +# Unwrapper to recover the behaviour defined by parent style. +@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N} + return copyto!(dest, convert(Broadcasted{S}, bc)) +end + +@inline function Broadcast.materialize!(::StructArrayStyle{S}, dest, bc::Broadcasted) where {S} + return Broadcast.materialize!(S(), dest, bc) end # for aliasing analysis during broadcast diff --git a/test/runtests.jl b/test/runtests.jl index 4693ca1b..9ac56f6f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,6 +6,7 @@ import Tables, PooledArrays, WeakRefStrings using TypedTables: Table using DataAPI: refarray, refvalue using Adapt: adapt, Adapt +using JLArrays using Test using Documenter: doctest @@ -1100,17 +1101,39 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs) @test t.b.d isa Array end -struct MyArray{T,N} <: AbstractArray{T,N} - A::Array{T,N} +# The following code defines `MyArray1/2/3` with different `BroadcastStyle`s. +# 1. `MyArray1` and `MyArray1` have `similar` defined. +# We use them to simulate `BroadcastStyle` overloading `Base.copyto!`. +# 2. `MyArray3` has no `similar` defined. +# We use it to simulate `BroadcastStyle` overloading `Base.copy`. +# 3. Their resolved style could be summaryized as (`-` means conflict) +# | MyArray1 | MyArray2 | MyArray3 | Array +# ------------------------------------------------------------- +# MyArray1 | MyArray1 | - | MyArray1 | MyArray1 +# MyArray2 | - | MyArray2 | - | MyArray2 +# MyArray3 | MyArray1 | - | MyArray3 | MyArray3 +# Array | MyArray1 | Array | MyArray3 | Array + +for S in (1, 2, 3) + MyArray = Symbol(:MyArray, S) + @eval begin + struct $MyArray{T,N} <: AbstractArray{T,N} + A::Array{T,N} + end + $MyArray{T}(::UndefInitializer, sz::Dims) where T = $MyArray(Array{T}(undef, sz)) + Base.IndexStyle(::Type{<:$MyArray}) = IndexLinear() + Base.getindex(A::$MyArray, i::Int) = A.A[i] + Base.setindex!(A::$MyArray, val, i::Int) = A.A[i] = val + Base.size(A::$MyArray) = Base.size(A.A) + Base.BroadcastStyle(::Type{<:$MyArray}) = Broadcast.ArrayStyle{$MyArray}() + end end -MyArray{T}(::UndefInitializer, sz::Dims) where T = MyArray(Array{T}(undef, sz)) -Base.IndexStyle(::Type{<:MyArray}) = IndexLinear() -Base.getindex(A::MyArray, i::Int) = A.A[i] -Base.setindex!(A::MyArray, val, i::Int) = A.A[i] = val -Base.size(A::MyArray) = Base.size(A.A) -Base.BroadcastStyle(::Type{<:MyArray}) = Broadcast.ArrayStyle{MyArray}() -Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{ElType}) where ElType = - MyArray{ElType}(undef, size(bc)) +Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray1}}, ::Type{ElType}) where ElType = + MyArray1{ElType}(undef, size(bc)) +Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray2}}, ::Type{ElType}) where ElType = + MyArray2{ElType}(undef, size(bc)) +Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyArray3}) = Broadcast.ArrayStyle{MyArray1}() +Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayStyle) = S @testset "broadcast" begin s = StructArray{ComplexF64}((rand(2,2), rand(2,2))) @@ -1128,19 +1151,44 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El # used inside of broadcast but we also test it here explicitly @test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N}) - s = StructArray{ComplexF64}((MyArray(rand(2)), MyArray(rand(2)))) - @test_throws MethodError s .+ s + @testset "style conflict check" begin + using StructArrays: StructArrayStyle + # Make sure we can handle style with similar defined + # And we can handle most conflicts + # `s1` and `s2` have similar defined, but `s3` does not + # `s2` conflicts with `s1` and `s3` and is weaker than `DefaultArrayStyle` + s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2)))) + s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2)))) + s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2)))) + s4 = StructArray{ComplexF64}((rand(2), rand(2))) + test_set = Any[s1, s2, s3, s4] + tested_style = Any[] + dotaddadd((a, b, c),) = @. a + b + c + for as in Iterators.product(test_set, test_set, test_set) + ares = map(a->a.re, as) + aims = map(a->a.im, as) + style = Broadcast.combine_styles(ares...) + @test Broadcast.combine_styles(as...) === StructArrayStyle{typeof(style),1}() + if !(style in tested_style) + push!(tested_style, style) + if style isa Broadcast.ArrayStyle{MyArray3} + @test_throws MethodError dotaddadd(as) + else + d = StructArray{ComplexF64}((dotaddadd(ares), dotaddadd(aims))) + @test @inferred(dotaddadd(as))::typeof(d) == d + end + end + end + @test length(tested_style) == 5 + end # test for dimensionality track + s = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2)))) @test Base.broadcasted(+, s, s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}} @test Base.broadcasted(+, s, 1:2) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}} @test Base.broadcasted(+, s, reshape(1:2,1,2)) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{2}} @test Base.broadcasted(+, reshape(1:2,1,1,2), s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}} - - a = StructArray([1;2+im]) - b = StructArray([1;;2+im]) - @test a .+ b == a .+ collect(b) == collect(a) .+ b == collect(a) .+ collect(b) - @test a .+ Any[1] isa StructArray + @test Base.broadcasted(+, s, MyArray1(rand(2))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}} # issue #185 A = StructArray(randn(ComplexF64, 3, 3)) @@ -1155,6 +1203,61 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El @test identity.(StructArray(x=StructArray(a=1:3)))::StructArray == [(x=(a=1,),), (x=(a=2,),), (x=(a=3,),)] @test (x -> x.x.a).(StructArray(x=StructArray(a=1:3))) == [1, 2, 3] + @test identity.(StructArray(x=StructArray(x=StructArray(a=1:3))))::StructArray == [(x=(x=(a=1,),),), (x=(x=(a=2,),),), (x=(x=(a=3,),),)] + @test (x -> x.x.x.a).(StructArray(x=StructArray(x=StructArray(a=1:3)))) == [1, 2, 3] + + @testset "ambiguity check" begin + test_set = Any[StructArray([1;2+im]), + 1:2, + (1,2), + StructArray(@SArray [1;1+2im]), + (@SArray [1 2]), + 1] + tested_style = StructArrayStyle[] + dotaddsub((a, b, c),) = @. a + b - c + for as in Iterators.product(test_set, test_set, test_set) + if any(a -> a isa StructArray, as) + style = Broadcast.combine_styles(as...) + if !(style in tested_style) + push!(tested_style, style) + @test @inferred(dotaddsub(as))::StructArray == dotaddsub(map(collect, as)) + end + end + end + @test length(tested_style) == 4 + end + + @testset "allocation test" begin + a = StructArray{ComplexF64}(undef, 1) + allocated(a) = @allocated a .+ 1 + @test allocated(a) == 2allocated(a.re) + end + + @testset "StructStaticArray" begin + bclog(s) = log.(s) + test_allocated(f, s) = @test (@allocated f(s)) == 0 + a = @SMatrix [float(i) for i in 1:10, j in 1:10] + b = @SMatrix [0. for i in 1:10, j in 1:10] + s = StructArray{ComplexF64}((a , b)) + @test (@inferred bclog(s)) isa typeof(s) + test_allocated(bclog, s) + @test abs.(s) .+ ((1,) .+ (1,2,3,4,5,6,7,8,9,10)) isa SMatrix + bc = Base.broadcasted(+, s, s); + bc = Base.broadcasted(+, bc, bc, s); + @test @inferred(Broadcast.axes(bc)) === axes(s) + end + + @testset "StructJLArray" begin + bcabs(a) = abs.(a) + bcmul2(a) = 2 .* a + a = StructArray(randn(ComplexF32, 10, 10)) + sa = jl(a) + backend = StructArrays.GPUArraysCore.backend + @test @inferred(backend(sa)) === backend(sa.re) === backend(sa.im) + @test collect(@inferred(bcabs(sa))) == bcabs(a) + @test @inferred(bcmul2(sa)) isa StructArray + @test (sa .+= 1) isa StructArray + end end @testset "map" begin