From b10c5a6a673b113dfaf7c85edaf8c8a3769b89eb Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 15 Aug 2022 11:14:08 -0700 Subject: [PATCH 1/2] remove NTuple --- Project.toml | 2 +- src/rulesets/Base/broadcast.jl | 4 ++-- test/rulesets/Base/broadcast.jl | 4 ++++ 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index a420d3ed9..15ce2ddb5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.44.1" +version = "1.44.2" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index be11eb76a..ddf4dc426 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -328,13 +328,13 @@ end unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero) = dx function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N} - val = if length(x) == length(dx) + val = if N == length(dx) dx else sum(dx; dims=2:ndims(dx)) end eltype(val) <: AbstractZero && return NoTangent() - return ProjectTo(x)(NTuple{length(x)}(val)) # Tangent + return ProjectTo(x)(Tuple{Vararg{Any,N}}(val)) # Tangent end unbroadcast(x::Tuple, dx::AbstractZero) = dx diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl index 68d47a7d4..b153b7cc5 100644 --- a/test/rulesets/Base/broadcast.jl +++ b/test/rulesets/Base/broadcast.jl @@ -173,4 +173,8 @@ BT1 = Broadcast.BroadcastStyle(Tuple) test_rrule(copy∘broadcasted, complex, rand()) end end + + @testset "bugs" begin + @test ChainRules.unbroadcast((1,2,[3]), [4,5,[6]]) isa Tangent # earlier, NTuple demanded same type + end end \ No newline at end of file From c1a66ee211d0eb03b9922786b6039cd76be6c454 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 15 Aug 2022 12:22:02 -0700 Subject: [PATCH 2/2] spaces Co-authored-by: Frames Catherine White --- test/rulesets/Base/broadcast.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl index b153b7cc5..219b45a71 100644 --- a/test/rulesets/Base/broadcast.jl +++ b/test/rulesets/Base/broadcast.jl @@ -175,6 +175,6 @@ BT1 = Broadcast.BroadcastStyle(Tuple) end @testset "bugs" begin - @test ChainRules.unbroadcast((1,2,[3]), [4,5,[6]]) isa Tangent # earlier, NTuple demanded same type + @test ChainRules.unbroadcast((1, 2, [3]), [4, 5, [6]]) isa Tangent # earlier, NTuple demanded same type end end \ No newline at end of file