Skip to content

Commit

Permalink
Make Broadcast.result_style work on styles with fields. (JuliaLang#50938
Browse files Browse the repository at this point in the history
)

Fixes JuliaLang#50937.

---------

Co-authored-by: Jameson Nash <vtjnash@gmail.com>
  • Loading branch information
tpapp and vtjnash committed Aug 18, 2023
1 parent 710df70 commit 169f435
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
7 changes: 5 additions & 2 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,9 @@ Base.Broadcast.DefaultArrayStyle{1}()
function result_style end

result_style(s::BroadcastStyle) = s
result_style(s1::S, s2::S) where S<:BroadcastStyle = S()
function result_style(s1::S, s2::S) where S<:BroadcastStyle
s1 s2 ? s1 : error("inconsistent broadcast styles, custom rule needed")
end
# Test both orders so users typically only have to declare one order
result_style(s1, s2) = result_join(s1, s2, BroadcastStyle(s1, s2), BroadcastStyle(s2, s1))

Expand All @@ -457,7 +459,8 @@ result_join(::Any, ::Any, s::BroadcastStyle, ::Unknown) = s
result_join(::AbstractArrayStyle, ::AbstractArrayStyle, ::Unknown, ::Unknown) =
ArrayConflict()
# Fallbacks in case users define `rule` for both argument-orders (not recommended)
result_join(::Any, ::Any, ::S, ::S) where S<:BroadcastStyle = S()
result_join(::Any, ::Any, s1::S, s2::S) where S<:BroadcastStyle = result_style(s1, s2)

@noinline function result_join(::S, ::T, ::U, ::V) where {S,T,U,V}
error("""
conflicting broadcast rules defined
Expand Down
13 changes: 13 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1142,6 +1142,19 @@ end
@test CartesianIndex(1,2) .+ [CartesianIndex(3,4), CartesianIndex(5,6)] == [CartesianIndex(4, 6), CartesianIndex(6, 8)]
end

struct MyBroadcastStyleWithField <: Broadcast.BroadcastStyle
i::Int
end
# asymmetry intended
Base.BroadcastStyle(a::MyBroadcastStyleWithField, b::MyBroadcastStyleWithField) = a

@testset "issue #50937: styles that have fields" begin
@test Broadcast.result_style(MyBroadcastStyleWithField(1), MyBroadcastStyleWithField(1)) ==
MyBroadcastStyleWithField(1)
@test_throws ErrorException Broadcast.result_style(MyBroadcastStyleWithField(1),
MyBroadcastStyleWithField(2))
end

# test that `Broadcast` definition is defined as total and eligible for concrete evaluation
import Base.Broadcast: BroadcastStyle, DefaultArrayStyle
@test Base.infer_effects(BroadcastStyle, (DefaultArrayStyle{1},DefaultArrayStyle{2},)) |>
Expand Down

0 comments on commit 169f435

Please sign in to comment.