From 07f1f94dac577e411932616c9afd8bca04fa692e Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 19 May 2021 13:55:05 -0400 Subject: [PATCH 1/2] tuple un-broadcast --- src/lib/broadcast.jl | 2 ++ test/features.jl | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 1219d7883..ddadfcbcd 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -46,6 +46,7 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr end trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x)))) +trim(x::Tuple, Δ) = ntuple(k -> Δ[k], length(x)) unbroadcast(x::AbstractArray, x̄) = size(x) == size(x̄) ? x̄ : @@ -55,6 +56,7 @@ unbroadcast(x::AbstractArray, x̄) = unbroadcast(x::Number, x̄) = accum_sum(x̄) unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),) unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),) +unbroadcast(x::Tuple, x̄) = trim(x, accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1 unbroadcast(x::AbstractArray, x̄::Nothing) = nothing diff --git a/test/features.jl b/test/features.jl index 48df0c87c..254034202 100644 --- a/test/features.jl +++ b/test/features.jl @@ -481,3 +481,14 @@ end Zygote.gradient(loss_adjoint,[1.0]) @test x[1] == x[2] end + +@testset "tuples & broadcasting" begin + @test gradient(x -> sum(x .+ ones(2,2)), (1,2)) == ((2,2),) + @test gradient(x -> sum(x .+ ones(2,2)), (1,)) == ((4,),) + + # https://github.com/FluxML/Zygote.jl/issues/975 + gt = gradient((x,p) -> prod(x .^ p), [3,4], (1,2)) + gv = gradient((x,p) -> prod(x .^ p), [3,4], [1,2]) + @test gt[1] == gv[1] + @test collect(gt[2]) ≈ gv[2] +end From 33f1d6de9f1fddae3a0fc166f66e921b2e219d9b Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 20 May 2021 08:37:01 -0400 Subject: [PATCH 2/2] skip the sum, sometimes --- src/lib/broadcast.jl | 2 +- test/features.jl | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index ddadfcbcd..451d8794c 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -56,7 +56,7 @@ unbroadcast(x::AbstractArray, x̄) = unbroadcast(x::Number, x̄) = accum_sum(x̄) unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),) unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),) -unbroadcast(x::Tuple, x̄) = trim(x, accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1 +unbroadcast(x::Tuple, x̄) = trim(x, length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1 unbroadcast(x::AbstractArray, x̄::Nothing) = nothing diff --git a/test/features.jl b/test/features.jl index 254034202..6843acbf6 100644 --- a/test/features.jl +++ b/test/features.jl @@ -485,6 +485,7 @@ end @testset "tuples & broadcasting" begin @test gradient(x -> sum(x .+ ones(2,2)), (1,2)) == ((2,2),) @test gradient(x -> sum(x .+ ones(2,2)), (1,)) == ((4,),) + @test gradient(x -> sum(x .+ ones(2,1)), (1,2)) == ((1,1),) # https://github.com/FluxML/Zygote.jl/issues/975 gt = gradient((x,p) -> prod(x .^ p), [3,4], (1,2))