Skip to content

Commit

Permalink
update Zygote.@ignore to ChainRulesCore.@ignore_derivatives
Browse files Browse the repository at this point in the history
Fixes #745
  • Loading branch information
ChrisRackauckas committed Jul 16, 2022
1 parent 80c4247 commit 0c94b7f
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 27 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "1.51.2"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Expand Down
24 changes: 13 additions & 11 deletions src/DiffEqFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,27 @@ using Cassette
@reexport using Flux
@reexport using OptimizationOptimJL

import ChainRulesCore

gpu_or_cpu(x) = Array

# ForwardDiff integration

ZygoteRules.@adjoint function ForwardDiff.Dual{T}(x, ẋ::Tuple) where T
@assert length(ẋ) == 1
ForwardDiff.Dual{T}(x, ẋ), ḋ -> (ḋ.partials[1], (ḋ.value,))
ZygoteRules.@adjoint function ForwardDiff.Dual{T}(x, ẋ::Tuple) where {T}
@assert length(ẋ) == 1
ForwardDiff.Dual{T}(x, ẋ), ḋ -> (ḋ.partials[1], (ḋ.value,))
end

ZygoteRules.@adjoint ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:partials}) where T =
d.partials, ṗ -> (ForwardDiff.Dual{T}(ṗ[1], 0),)
ZygoteRules.@adjoint ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:partials}) where {T} =
d.partials, ṗ -> (ForwardDiff.Dual{T}(ṗ[1], 0),)

ZygoteRules.@adjoint ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:value}) where T =
d.value, ẋ -> (ForwardDiff.Dual{T}(0, ẋ),)
ZygoteRules.@adjoint ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:value}) where {T} =
d.value, ẋ -> (ForwardDiff.Dual{T}(0, ẋ),)

ZygoteRules.@adjoint ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:dl}) = A.dl,y -> Tridiagonal(dl,zeros(length(d)),zeros(length(du)),)
ZygoteRules.@adjoint ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:d}) = A.d,y -> Tridiagonal(zeros(length(dl)),d,zeros(length(du)),)
ZygoteRules.@adjoint ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:du}) = A.dl,y -> Tridiagonal(zeros(length(dl)),zeros(length(d),du),)
ZygoteRules.@adjoint Tridiagonal(dl, d, du) = Tridiagonal(dl, d, du), p̄ -> (diag(p̄[2:end,1:end-1]),diag(p̄),diag(p̄[1:end-1,2:end]))
ZygoteRules.@adjoint ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:dl}) = A.dl, y -> Tridiagonal(dl, zeros(length(d)), zeros(length(du)),)
ZygoteRules.@adjoint ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:d}) = A.d, y -> Tridiagonal(zeros(length(dl)), d, zeros(length(du)),)
ZygoteRules.@adjoint ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:du}) = A.dl, y -> Tridiagonal(zeros(length(dl)), zeros(length(d), du),)
ZygoteRules.@adjoint Tridiagonal(dl, d, du) = Tridiagonal(dl, d, du), p̄ -> (diag(p̄[2:end, 1:end-1]), diag(p̄), diag(p̄[1:end-1, 2:end]))

include("ffjord.jl")
include("train.jl")
Expand Down
32 changes: 16 additions & 16 deletions src/ffjord.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,27 +189,27 @@ end

function jacobian_fn(f, x::AbstractMatrix, args...)
y, back = Zygote.pullback(f, x)
z = Zygote.@ignore similar(y)
Zygote.@ignore fill!(z, zero(eltype(x)))
z = ChainRulesCore.@ignore_derivatives similar(y)
ChainRulesCore.@ignore_derivatives fill!(z, zero(eltype(x)))
vec = Zygote.Buffer(x, size(x, 1), size(x, 1), size(x, 2))
for i in 1:size(y, 1)
Zygote.@ignore z[i, :] .= one(eltype(x))
ChainRulesCore.@ignore_derivatives z[i, :] .= one(eltype(x))
vec[i, :, :] = back(z)[1]
Zygote.@ignore z[i, :] .= zero(eltype(x))
ChainRulesCore.@ignore_derivatives z[i, :] .= zero(eltype(x))
end
copy(vec)
end

function jacobian_fn(f::Lux.Chain, x::AbstractMatrix, args...)
p,st = args
y, back = Zygote.pullback((z,ps,s)->f(z,ps,s)[1], x, p, st)
z = Zygote.@ignore similar(y)
Zygote.@ignore fill!(z, zero(eltype(x)))
z = ChainRulesCore.@ignore_derivatives similar(y)
ChainRulesCore.@ignore_derivatives fill!(z, zero(eltype(x)))
vec = Zygote.Buffer(x, size(x, 1), size(x, 1), size(x, 2))
for i in 1:size(y, 1)
Zygote.@ignore z[i, :] .= one(eltype(x))
ChainRulesCore.@ignore_derivatives z[i, :] .= one(eltype(x))
vec[i, :, :] = back(z)[1]
Zygote.@ignore z[i, :] .= zero(eltype(x))
ChainRulesCore.@ignore_derivatives z[i, :] .= zero(eltype(x))
end
copy(vec)
end
Expand Down Expand Up @@ -282,17 +282,17 @@ function forward_ffjord(n::FFJORD, x, p=n.p, e=randn(eltype(x), size(x));
ffjord_(u, p, t) = ffjord(u, p, t, n.re, e, n.st; regularize, monte_carlo)
# ffjord_(u, p, t) = ffjord(u, p, t, n.re, e; regularize, monte_carlo)
if regularize
_z = Zygote.@ignore similar(x, 3, size(x, 2))
Zygote.@ignore fill!(_z, zero(eltype(x)))
_z = ChainRulesCore.@ignore_derivatives similar(x, 3, size(x, 2))
ChainRulesCore.@ignore_derivatives fill!(_z, zero(eltype(x)))
prob = ODEProblem{false}(ffjord_, vcat(x, _z), n.tspan, p)
pred = solve(prob, n.args...; sensealg, n.kwargs...)[:, :, end]
z = pred[1:end - 3, :]
delta_logp = pred[end - 2:end - 2, :]
λ₁ = pred[end - 1, :]
λ₂ = pred[end, :]
else
_z = Zygote.@ignore similar(x, 1, size(x, 2))
Zygote.@ignore fill!(_z, zero(eltype(x)))
_z = ChainRulesCore.@ignore_derivatives similar(x, 1, size(x, 2))
ChainRulesCore.@ignore_derivatives fill!(_z, zero(eltype(x)))
prob = ODEProblem{false}(ffjord_, vcat(x, _z), n.tspan, p)
pred = solve(prob, n.args...; sensealg, n.kwargs...)[:, :, end]
z = pred[1:end - 1, :]
Expand All @@ -313,14 +313,14 @@ function backward_ffjord(n::FFJORD, n_samples, p=n.p, e=randn(eltype(n.model[1].
sensealg = InterpolatingAdjoint()
ffjord_(u, p, t) = ffjord(u, p, t, n.re, e, n.st; regularize, monte_carlo)
if regularize
_z = Zygote.@ignore similar(x, 3, size(x, 2))
Zygote.@ignore fill!(_z, zero(eltype(x)))
_z = ChainRulesCore.@ignore_derivatives similar(x, 3, size(x, 2))
ChainRulesCore.@ignore_derivatives fill!(_z, zero(eltype(x)))
prob = ODEProblem{false}(ffjord_, vcat(x, _z), reverse(n.tspan), p)
pred = solve(prob, n.args...; sensealg, n.kwargs...)[:, :, end]
z = pred[1:end - 3, :]
else
_z = Zygote.@ignore similar(x, 1, size(x, 2))
Zygote.@ignore fill!(_z, zero(eltype(x)))
_z = ChainRulesCore.@ignore_derivatives similar(x, 1, size(x, 2))
ChainRulesCore.@ignore_derivatives fill!(_z, zero(eltype(x)))
prob = ODEProblem{false}(ffjord_, vcat(x, _z), reverse(n.tspan), p)
pred = solve(prob, n.args...; sensealg, n.kwargs...)[:, :, end]
z = pred[1:end - 1, :]
Expand Down

0 comments on commit 0c94b7f

Please sign in to comment.