diff --git a/src/tracker/back.jl b/src/tracker/back.jl index 4baaef87c8..ef65ecb656 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -147,8 +147,10 @@ end back(::Grads, ::Nothing, _) = return +collectmemaybe(xs) = xs + function forward(f, ps::Params) - y = f() + y = collectmemaybe(f()) y, function (Δ) g = Grads(ps) if istracked(y) diff --git a/src/tracker/lib/real.jl b/src/tracker/lib/real.jl index be6f62f0dc..ec57f0d3d6 100644 --- a/src/tracker/lib/real.jl +++ b/src/tracker/lib/real.jl @@ -155,3 +155,6 @@ end function back_(g::Grads, c::Call{typeof(collect)}, Δ) foreach((x, Δ) -> back(g, x, Δ), c.args[1], Δ) end + +collectmemaybe(xs::AbstractArray{>:TrackedReal}) = collect(xs) +collectmemaybe(xs::AbstractArray{<:TrackedReal}) = collect(xs) diff --git a/test/tracker.jl b/test/tracker.jl index 02aff6dd92..bb64f01a7e 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -323,4 +323,9 @@ end end == ([3, 2],) end +@testset "Custom Sensitivities" begin + y, back = Tracker.forward(x -> [3x^2, 2x], 5) + @test back([1, 1]) == (32,) +end + end #testset