Skip to content

Commit

Permalink
Add GPU broadcast support.
Browse files Browse the repository at this point in the history
  • Loading branch information
N5N3 committed Aug 24, 2022
1 parent 5db3da5 commit 2644c24
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
9 changes: 6 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,22 @@ version = "0.6.12"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
Adapt = "1, 2, 3"
DataAPI = "1"
StaticArraysCore = "1.1"
StaticArrays = "1.5.4"
StaticArraysCore = "1.3"
StaticArrays = "1.5.6"
GPUArraysCore = "~0.1.2"
Tables = "1"
julia = "1.6"

[extras]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand All @@ -26,4 +29,4 @@ TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"

[targets]
test = ["Test", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter"]
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter"]
8 changes: 8 additions & 0 deletions src/StructArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,12 @@ end
import Adapt
Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s)

# for GPU broadcast
import GPUArraysCore: backend
function backend(::Type{T}) where {T<:StructArray}
backs = map(backend, fieldtypes(array_types(T)))
all(Base.Fix2(===, backs[1]), tail(backs)) || error("backend mismatch!")
return backs[1]
end

end # module
11 changes: 11 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import Tables, PooledArrays, WeakRefStrings
using TypedTables: Table
using DataAPI: refarray, refvalue
using Adapt: adapt, Adapt
using JLArrays
using Test

using Documenter: doctest
Expand Down Expand Up @@ -1178,6 +1179,16 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
@test (@inferred bclog(s)) isa typeof(s)
test_allocated(bclog, s)
end

@testset "StructJLArray" begin
bcabs(a) = abs.(a)
bcmul2(a) = 2 .* a
a = StructArray(randn(ComplexF32, 10, 10))
sa = jl(a)
@test collect(@inferred(bcabs(sa))) == bcabs(a)
@test @inferred(bcmul2(sa)) isa StructArray
@test (sa .+= 1) isa StructArray
end
end

@testset "map" begin
Expand Down

0 comments on commit 2644c24

Please sign in to comment.