diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 8694b394c..df9fdc9b5 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -46,6 +46,7 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr end trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x)))) +trim(x::Tuple, Δ) = ntuple(k -> Δ[k], length(x)) unbroadcast(x::AbstractArray, x̄) = size(x) == size(x̄) ? x̄ : @@ -55,6 +56,7 @@ unbroadcast(x::AbstractArray, x̄) = unbroadcast(x::Number, x̄) = accum_sum(x̄) unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),) unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),) +unbroadcast(x::Tuple, x̄) = trim(x, length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1 unbroadcast(x::AbstractArray, x̄::Nothing) = nothing diff --git a/test/features.jl b/test/features.jl index 48df0c87c..6843acbf6 100644 --- a/test/features.jl +++ b/test/features.jl @@ -481,3 +481,15 @@ end Zygote.gradient(loss_adjoint,[1.0]) @test x[1] == x[2] end + +@testset "tuples & broadcasting" begin + @test gradient(x -> sum(x .+ ones(2,2)), (1,2)) == ((2,2),) + @test gradient(x -> sum(x .+ ones(2,2)), (1,)) == ((4,),) + @test gradient(x -> sum(x .+ ones(2,1)), (1,2)) == ((1,1),) + + # https://github.com/FluxML/Zygote.jl/issues/975 + gt = gradient((x,p) -> prod(x .^ p), [3,4], (1,2)) + gv = gradient((x,p) -> prod(x .^ p), [3,4], [1,2]) + @test gt[1] == gv[1] + @test collect(gt[2]) ≈ gv[2] +end