Skip to content

Commit

Permalink
Use style dispatch in broadcast(!) (#393)
Browse files Browse the repository at this point in the history
Supersedes #295, relies on JuliaLang/julia#35620.
  • Loading branch information
N5N3 authored Mar 15, 2022
1 parent 9412fa1 commit 9f14a19
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 8 deletions.
18 changes: 10 additions & 8 deletions src/host/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export AbstractGPUArrayStyle

using Base.Broadcast

import Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle
import Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate

const BroadcastGPUArray{T} = Union{AnyGPUArray{T},
Base.RefValue{<:AbstractGPUArray{T}}}
Expand Down Expand Up @@ -47,7 +47,15 @@ end
copyto!(similar(bc, ElType), bc)
end

@inline function Base.copyto!(dest::BroadcastGPUArray, bc::Broadcasted{Nothing})
@inline function Base.materialize!(::Style, dest, bc::Broadcasted) where {Style<:AbstractGPUArrayStyle}
return _copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest))))
end

@inline Base.copyto!(dest::BroadcastGPUArray, bc::Broadcasted{Nothing}) = _copyto!(dest, bc) # Keep it for ArrayConflict

@inline Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractGPUArrayStyle}) = _copyto!(dest, bc)

@inline function _copyto!(dest::AbstractArray, bc::Broadcasted)
axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
isempty(dest) && return dest
bc′ = Broadcast.preprocess(dest, bc)
Expand All @@ -72,12 +80,6 @@ end
return dest
end

# Base defines this method as a performance optimization, but we don't know how to do
# `fill!` in general for all `BroadcastGPUArray` so we just go straight to the fallback
@inline Base.copyto!(dest::BroadcastGPUArray, bc::Broadcasted{<:Broadcast.AbstractArrayStyle{0}}) =
copyto!(dest, convert(Broadcasted{Nothing}, bc))


## map

allequal(x) = true
Expand Down
32 changes: 32 additions & 0 deletions test/testsuite/broadcasting.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
@testsuite "broadcasting" (AT, eltypes)->begin
broadcasting(AT, eltypes)
vec3(AT, eltypes)
unknown_wrapper(AT, eltypes)

@testset "type instabilities" begin
f(x) = x ? 1.0 : 0
Expand Down Expand Up @@ -205,3 +206,34 @@ function vec3(AT, eltypes)
@test all(map((a,b)-> all((1,2,3) .≈ (1,2,3)), Array(res2), res2c))
end
end

# A help struct to test style-based broadcast dispatch with unknown array wrapper.
# `WrapArray(A)` behaves like `A` during broadcast. But its not a `BroadcastGPUArray`.
struct WrapArray{T,N,P<:AbstractArray{T,N}} <: AbstractArray{T,N}
data::P
end
Base.@propagate_inbounds Base.getindex(A::WrapArray, i::Integer...) = A.data[i...]
Base.@propagate_inbounds Base.setindex!(A::WrapArray, v::Any, i::Integer...) = setindex!(A.data, v, i...)
Base.size(A::WrapArray) = size(A.data)
# For kernal support
Adapt.adapt_structure(to, s::WrapArray) = WrapArray(Adapt.adapt(to, s.data))
# For broadcast support
GPUArrays.backend(::Type{WrapArray{T,N,P}}) where {T,N,P} = GPUArrays.backend(P)
Broadcast.BroadcastStyle(::Type{WrapArray{T,N,P}}) where {T,N,P} = Broadcast.BroadcastStyle(P)

function unknown_wrapper(AT, eltypes)
for ET in eltypes
@views @testset "unknown wrapper $ET" begin
A = AT(rand(ET, 10, 10))
WA = WrapArray(A)
# test for dispatch with src's BroadcastStyle.
@test Array(WA .+ ET(1)) == Array(A .+ ET(1))
@test Array(WA .+ WA) == Array(WA .+ A) == Array(A .+ A)
@test Array(WA .+ A[:,1]) == Array(A .+ A[:,1])
@test Array(WA .+ A[1,:]) == Array(A .+ A[1,:])
# test for dispatch with dest's BroadcastStyle.
WA .= ET(1)
@test all(isequal(ET(1)), Array(A))
end
end
end

0 comments on commit 9f14a19

Please sign in to comment.