Skip to content

Commit

Permalink
auto-collect in forward
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeInnes committed Feb 4, 2019
1 parent 8380709 commit cfe6859
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/tracker/back.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,10 @@ end

back(::Grads, ::Nothing, _) = return

collectmemaybe(xs) = xs

function forward(f, ps::Params)
y = f()
y = collectmemaybe(f())
y, function (Δ)
g = Grads(ps)
if istracked(y)
Expand Down
3 changes: 3 additions & 0 deletions src/tracker/lib/real.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,6 @@ end
function back_(g::Grads, c::Call{typeof(collect)}, Δ)
foreach((x, Δ) -> back(g, x, Δ), c.args[1], Δ)
end

collectmemaybe(xs::AbstractArray{>:TrackedReal}) = collect(xs)
collectmemaybe(xs::AbstractArray{<:TrackedReal}) = collect(xs)
5 changes: 5 additions & 0 deletions test/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -323,4 +323,9 @@ end
end == ([3, 2],)
end

@testset "Custom Sensitivities" begin
y, back = Tracker.forward(x -> [3x^2, 2x], 5)
@test back([1, 1]) == (32,)
end

end #testset

0 comments on commit cfe6859

Please sign in to comment.