diff --git a/Project.toml b/Project.toml index 9cf6fcd0..3427ba9b 100644 --- a/Project.toml +++ b/Project.toml @@ -5,19 +5,22 @@ version = "0.6.12" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] Adapt = "1, 2, 3" DataAPI = "1" -StaticArraysCore = "1.1" -StaticArrays = "1.5.4" +StaticArraysCore = "1.3" +StaticArrays = "1.5.6" +GPUArraysCore = "~0.1.2" Tables = "1" julia = "1.6" [extras] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -26,4 +29,4 @@ TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" [targets] -test = ["Test", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter"] +test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter"] diff --git a/src/StructArrays.jl b/src/StructArrays.jl index 6fed453e..f7431992 100644 --- a/src/StructArrays.jl +++ b/src/StructArrays.jl @@ -29,4 +29,12 @@ end import Adapt Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s) +# for GPU broadcast +import GPUArraysCore: backend +function backend(::Type{T}) where {T<:StructArray} + backs = map(backend, fieldtypes(array_types(T))) + all(Base.Fix2(===, backs[1]), tail(backs)) || error("backend mismatch!") + return backs[1] +end + end # module diff --git a/test/runtests.jl b/test/runtests.jl index ccbba8a4..a33f72d3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,6 +6,7 @@ import Tables, PooledArrays, WeakRefStrings using TypedTables: Table using DataAPI: refarray, refvalue using Adapt: adapt, Adapt +using JLArrays using Test using Documenter: doctest @@ -1178,6 +1179,16 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS @test (@inferred bclog(s)) isa typeof(s) test_allocated(bclog, s) end + + @testset "StructJLArray" begin + bcabs(a) = abs.(a) + bcmul2(a) = 2 .* a + a = StructArray(randn(ComplexF32, 10, 10)) + sa = jl(a) + @test collect(@inferred(bcabs(sa))) == bcabs(a) + @test @inferred(bcmul2(sa)) isa StructArray + @test (sa .+= 1) isa StructArray + end end @testset "map" begin