From 09540ac61c8956a4e48586e3031a6dad36d83683 Mon Sep 17 00:00:00 2001
From: Michael Abbott
Date: Thu, 20 Dec 2018 16:14:31 +0100
Subject: [PATCH] gradients for prod, cumsum, cumprod

---
 src/tracker/lib/array.jl | 91 +++++++++++++++++++++++++++++++++++++---
 test/tracker.jl          | 23 ++++++++++
 2 files changed, 107 insertions(+), 7 deletions(-)

diff --git a/src/tracker/lib/array.jl b/src/tracker/lib/array.jl
index a94323ca07..84c1b66702 100644
--- a/src/tracker/lib/array.jl
+++ b/src/tracker/lib/array.jl
@@ -279,15 +279,92 @@ Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
 @grad sum(xs; dims = :) = sum(data(xs), dims = dims),
   Δ -> (zero(xs) .+ Δ, )
 
-Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
-Base.prod(xs::TrackedArray) = track(prod, xs)
+Base.prod(xs::TrackedArray; dims=:) = track(prod, xs; dims=dims)
 Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
 
-@grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,)
-@grad prod(xs, dim) = prod(data(xs), dims = dim),
-  Δ -> (nobacksies(:sum,
-    reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ),
-    nothing)
+# Gradient of `prod`. `_prod` receives the raw data, the already-computed product `p`
+# and `dims`, and returns `(p, pullback)`.
+@grad prod(xs; dims=:) = _prod(xs.data, prod(xs.data, dims=dims), dims)
+_prod(xd, p, ::Colon) = p, Δ -> (nobacksies(:prod, ∇prod(xd, p, data(Δ)) ),)
+# With `dims` given: if no entry of `p` is zero use the direct formula `p ./ xd`;
+# otherwise fall back to `∇prod` on each slice, which handles zeros.
+_prod(xd, p, dims) = count(iszero, p) == 0 ?
+  (p, Δ -> (nobacksies(:prod, p ./ xd .* data(Δ) ),)) :
+  (p, Δ -> (nobacksies(:prod, mapslices(∇prod, xd; dims=dims) .* data(Δ)),))
+
+# Gradient of `prod(x)` scaled by `Δ`, choosing a branch by the number of zeros in `x`:
+# none -> `p ./ x .* Δ`; more than one -> all zeros; exactly one -> `∇prod_one`.
+function ∇prod(x, p=prod(x), Δ=1)
+  numzero = count(iszero, x)
+  if numzero == 0
+    return p ./ x .* Δ
+  elseif numzero > 1
+    return zero(x)
+  else
+    return ∇prod_one(x, Δ)
+  end
+end
+# Exactly one zero in `x`: only that position gets a nonzero gradient, namely the
+# product of all the other entries times `Δ`.
+function ∇prod_one(x::Array, Δ)
+  zloc = findfirst(iszero, x)
+  ∇ = copy(x)
+  ∇[zloc] = 1 # replace the zero so that prod(∇) is the product of the other entries
+  nonzero = prod(∇) * Δ
+  ∇ .= 0
+  ∇[zloc] = nonzero
+  ∇
+end
+# Other array types fall back to ForwardDiff.
+∇prod_one(x::AbstractArray, Δ) = ForwardDiff.gradient(y -> prod(y) * Δ, x)
+
+Base.cumsum(xs::TrackedArray; dims=1) = track(cumsum, xs; dims=dims)
+
+# Pullback of `cumsum` along `d`: reverse `Δ`, take its cumulative sum, reverse back.
+@grad cumsum(xs; dims=1) = _cumsum(xs.data, dims)
+_cumsum(xd::Array, d) = cumsum(xd; dims=d), Δ -> (reverse(cumsum(reverse(Δ,dims=d),dims=d),dims=d),)
+_cumsum(xd::AbstractArray, d) = cumsum(xd; dims=d), Δ -> (mapslices(reverse∘cumsum∘reverse,Δ, dims=d),)
+
+Base.cumprod(xs::TrackedArray; dims=nothing) = track(cumprod, xs; dims=dims)
+
+# Gradient of `cumprod`: `dims=nothing` calls `cumprod(xd)` and `∇cumprod` directly,
+# any other `dims` goes slice by slice through `∇cumprod_d`.
+@grad cumprod(xs; dims=nothing) = _cumprod(xs.data, dims)
+_cumprod(xd, ::Nothing, p = cumprod(xd)) = p, Δ -> (nobacksies(:cumprod, ∇cumprod(xd, p, data(Δ)) ),)
+_cumprod(xd, d, p = cumprod(xd, dims=d)) = p, Δ -> (nobacksies(:cumprod, ∇cumprod_d(xd, Val(d), p, data(Δ)) ),)
+
+# Gradient of `p = cumprod(x)` for a vector `x`, contracted with `Δ`.
+# `z` is the index of the first zero in `x`, or `len+1` if there is none.
+function ∇cumprod(x::Vector, p, Δ)
+  len = length(x)
+  z = something(findfirst(iszero, x), len+1)
+  ∇ = zero(x)
+  # Before the first zero, ∇[i] = sum(p[k]*Δ[k] for k=i:z-1) / x[i]. A running sum
+  # taken from the right gives all of these in one O(len) pass, not a nested loop.
+  acc = zero(eltype(∇))
+  @inbounds for i=z-1:-1:1
+    acc += p[i] * Δ[i]
+    ∇[i] = acc / x[i]
+  end
+  @inbounds if z != len+1
+    pk = z==1 ? one(p[1]) : p[z-1] # will be prod(x[j] for j=1:k if j!=z)
+    ∇[z] += pk * Δ[z]
+    for k=(z+1):len
+      pk *= x[k]
+      ∇[z] += pk * Δ[k]
+    end
+  end
+  ∇
+end
+∇cumprod(x::AbstractVector, p, Δ) = vec(Δ' * ForwardDiff.jacobian(cumprod, x))
+# Apply `∇cumprod` to every 1-d slice of `x`, `p`, `Δ` along dimension `d`.
+@noinline function ∇cumprod_d(x::AbstractArray{T,N}, ::Val{d}, p, Δ) where {T,N,d}
+  ∇ = similar(x)
+  for i in Iterators.product(ntuple(k -> k==d ? Ref(:) : axes(x,k), Val(N))...)
+    copyto!(view(∇,i...), ∇cumprod(x[i...], p[i...], Δ[i...]))
+  end
+  ∇ # roughly mapslices(∇cumprod, x,p,Δ; dims=d) if that existed
+end
 
 Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
 
diff --git a/test/tracker.jl b/test/tracker.jl
index 51f4ad964e..144b17cc7f 100644
--- a/test/tracker.jl
+++ b/test/tracker.jl
@@ -17,12 +17,35 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
 @test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
 @test gradtest((w, x) -> w'*x, randn(Float64,10, 2), randn(Float64,10))
 @test gradtest((w, x) -> w*x', randn(Float64,5,5), randn(Float64,5,5))
+
 @test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5))
 @test gradtest(x -> sum(x, dims = 1), randn(Float64,2,3))
 @test gradtest(x -> sum(x, dims = [1,2]), randn(Float64,2,3))
 @test gradtest(x -> sum(x), randn(Float64,2,3))
+
 @test gradtest(x -> prod(x, dims=(2, 3)), (3,4,5))
+@test gradtest(x -> prod(x, dims=1), (3,4,5))
+@test gradtest(x -> prod(x, dims=1), (3,))
 @test gradtest(x -> prod(x), (3,4,5))
+@test gradtest(x -> prod(x), (3,))
+
+rzero(dims...) = (x = rand(dims...); x[2]=0; x)
+@test gradtest(x -> prod(x, dims=(2, 3)), rzero(3,4,5))
+@test gradtest(x -> prod(x, dims=1), rzero(3,4,5))
+@test gradtest(x -> prod(x, dims=1), rzero(3,))
+@test gradtest(x -> prod(x), rzero(3,4,5))
+@test gradtest(x -> prod(x), rzero(3,))
+
+@test gradtest(x -> cumsum(x, dims=2), (3,4,5))
+@test gradtest(x -> cumsum(x, dims=1), (3,))
+@test gradtest(x -> cumsum(x), (3,))
+
+@test gradtest(x -> cumprod(x, dims=2), (3,4,5))
+@test gradtest(x -> cumprod(x, dims=1), (3,))
+@test gradtest(x -> cumprod(x), (3,))
+@test gradtest(x -> cumprod(x, dims=2), rzero(3,4,5))
+@test gradtest(x -> cumprod(x, dims=1), rzero(3,))
+@test gradtest(x -> cumprod(x), rzero(3,))
 
 @test gradtest(x -> softmax(x).*(1:3), 3)
 @test gradtest(x -> softmax(x).*(1:3), (3,5))