From 9f14a1940aab2e8540e21e1f443609f4c295a3bf Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Tue, 15 Mar 2022 22:49:42 +0800 Subject: [PATCH] Use style dispatch in broadcast(!) (#393) Supersedes #295, relies on JuliaLang/julia#35620. --- src/host/broadcast.jl | 18 ++++++++++-------- test/testsuite/broadcasting.jl | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/src/host/broadcast.jl b/src/host/broadcast.jl index 661e2dcd..7a037985 100644 --- a/src/host/broadcast.jl +++ b/src/host/broadcast.jl @@ -4,7 +4,7 @@ export AbstractGPUArrayStyle using Base.Broadcast -import Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle +import Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate const BroadcastGPUArray{T} = Union{AnyGPUArray{T}, Base.RefValue{<:AbstractGPUArray{T}}} @@ -47,7 +47,15 @@ end copyto!(similar(bc, ElType), bc) end -@inline function Base.copyto!(dest::BroadcastGPUArray, bc::Broadcasted{Nothing}) +@inline function Base.materialize!(::Style, dest, bc::Broadcasted) where {Style<:AbstractGPUArrayStyle} + return _copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest)))) +end + +@inline Base.copyto!(dest::BroadcastGPUArray, bc::Broadcasted{Nothing}) = _copyto!(dest, bc) # Keep it for ArrayConflict + +@inline Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractGPUArrayStyle}) = _copyto!(dest, bc) + +@inline function _copyto!(dest::AbstractArray, bc::Broadcasted) axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) isempty(dest) && return dest bc′ = Broadcast.preprocess(dest, bc) @@ -72,12 +80,6 @@ end return dest end -# Base defines this method as a performance optimization, but we don't know how to do -# `fill!` in general for all `BroadcastGPUArray` so we just go straight to the fallback -@inline Base.copyto!(dest::BroadcastGPUArray, bc::Broadcasted{<:Broadcast.AbstractArrayStyle{0}}) = - copyto!(dest, convert(Broadcasted{Nothing}, bc)) - - ## map allequal(x) = true diff --git a/test/testsuite/broadcasting.jl b/test/testsuite/broadcasting.jl index 939af21b..5eda7ae5 100644 --- a/test/testsuite/broadcasting.jl +++ b/test/testsuite/broadcasting.jl @@ -1,6 +1,7 @@ @testsuite "broadcasting" (AT, eltypes)->begin broadcasting(AT, eltypes) vec3(AT, eltypes) + unknown_wrapper(AT, eltypes) @testset "type instabilities" begin f(x) = x ? 1.0 : 0 @@ -205,3 +206,34 @@ function vec3(AT, eltypes) @test all(map((a,b)-> all((1,2,3) .≈ (1,2,3)), Array(res2), res2c)) end end + +# A help struct to test style-based broadcast dispatch with unknown array wrapper. +# `WrapArray(A)` behaves like `A` during broadcast. But its not a `BroadcastGPUArray`. +struct WrapArray{T,N,P<:AbstractArray{T,N}} <: AbstractArray{T,N} + data::P +end +Base.@propagate_inbounds Base.getindex(A::WrapArray, i::Integer...) = A.data[i...] +Base.@propagate_inbounds Base.setindex!(A::WrapArray, v::Any, i::Integer...) = setindex!(A.data, v, i...) +Base.size(A::WrapArray) = size(A.data) +# For kernal support +Adapt.adapt_structure(to, s::WrapArray) = WrapArray(Adapt.adapt(to, s.data)) +# For broadcast support +GPUArrays.backend(::Type{WrapArray{T,N,P}}) where {T,N,P} = GPUArrays.backend(P) +Broadcast.BroadcastStyle(::Type{WrapArray{T,N,P}}) where {T,N,P} = Broadcast.BroadcastStyle(P) + +function unknown_wrapper(AT, eltypes) + for ET in eltypes + @views @testset "unknown wrapper $ET" begin + A = AT(rand(ET, 10, 10)) + WA = WrapArray(A) + # test for dispatch with src's BroadcastStyle. + @test Array(WA .+ ET(1)) == Array(A .+ ET(1)) + @test Array(WA .+ WA) == Array(WA .+ A) == Array(A .+ A) + @test Array(WA .+ A[:,1]) == Array(A .+ A[:,1]) + @test Array(WA .+ A[1,:]) == Array(A .+ A[1,:]) + # test for dispatch with dest's BroadcastStyle. + WA .= ET(1) + @test all(isequal(ET(1)), Array(A)) + end + end +end