Skip to content

Commit

Permalink
Merge pull request FluxML#977 from mcabbott/tuplecast
Browse files Browse the repository at this point in the history
`unbroadcast` for tuples
  • Loading branch information
CarloLucibello authored May 22, 2021
2 parents f8191ca + 33f1d6d commit bcc3921
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr
end

trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
trim(x::Tuple, Δ) = ntuple(k -> Δ[k], length(x))

unbroadcast(x::AbstractArray, x̄) =
size(x) == size(x̄) ?:
Expand All @@ -55,6 +56,7 @@ unbroadcast(x::AbstractArray, x̄) =
unbroadcast(x::Number, x̄) = accum_sum(x̄)
unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)
unbroadcast(x::Tuple, x̄) = trim(x, length(x) == length(x̄) ?: accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1

unbroadcast(x::AbstractArray, x̄::Nothing) = nothing

Expand Down
12 changes: 12 additions & 0 deletions test/features.jl
Original file line number Diff line number Diff line change
Expand Up @@ -481,3 +481,15 @@ end
Zygote.gradient(loss_adjoint,[1.0])
@test x[1] == x[2]
end

@testset "tuples & broadcasting" begin
@test gradient(x -> sum(x .+ ones(2,2)), (1,2)) == ((2,2),)
@test gradient(x -> sum(x .+ ones(2,2)), (1,)) == ((4,),)
@test gradient(x -> sum(x .+ ones(2,1)), (1,2)) == ((1,1),)

# https://github.com/FluxML/Zygote.jl/issues/975
gt = gradient((x,p) -> prod(x .^ p), [3,4], (1,2))
gv = gradient((x,p) -> prod(x .^ p), [3,4], [1,2])
@test gt[1] == gv[1]
@test collect(gt[2]) gv[2]
end

0 comments on commit bcc3921

Please sign in to comment.