Skip to content

Commit

Permalink
Fix and add more AD tests (#625)
Browse files Browse the repository at this point in the history
* add binomlogpdf ForwardDiff.Dual workaround

* fix broken AD tests and add some more

* remove FDM from REQUIRE
  • Loading branch information
mohamed82008 authored and yebai committed Dec 14, 2018
1 parent 3c8ea47 commit 873aaec
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 7 deletions.
7 changes: 7 additions & 0 deletions src/core/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,13 @@ Tracker.@grad function poislogpdf(v::Tracker.TrackedReal, x::Int)
Δ->* (x/v - 1), nothing)
end

function binomlogpdf(n::Int, p::ForwardDiff.Dual{T}, x::Int) where {T}
FD = ForwardDiff.Dual{T}
val = ForwardDiff.value(p)
Δ = ForwardDiff.partials(p)
return FD(binomlogpdf(n, val, x), Δ * (x / val - (n - x) / (1 - val)))
end

function poislogpdf(v::ForwardDiff.Dual{T}, x::Int) where {T}
FD = ForwardDiff.Dual{T}
val = ForwardDiff.value(v)
Expand Down
28 changes: 26 additions & 2 deletions test/ad.jl/AD_compatibility_with_distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ let
rtol=1e-8,
atol=1e-8,
)
@test isapprox(
Tracker.gradient(foo, 0.5)[1],
ForwardDiff.derivative(foo, 0.5);
rtol=1e-8,
atol=1e-8,
)

bar = p->logpdf(Binomial(10, p), 3)
@test isapprox(
Expand All @@ -45,22 +51,40 @@ let
rtol=1e-8,
atol=1e-8,
)
@test isapprox(
Tracker.gradient(bar, 0.5)[1],
ForwardDiff.derivative(bar, 0.5),
rtol=1e-8,
atol=1e-8,
)
end

let
foo = p->poislogpdf(1, p)
foo = p->Turing.poislogpdf(p, 1)
@test isapprox(
Tracker.gradient(foo, 0.5)[1],
central_fdm(5, 1)(foo, 0.5);
rtol=1e-8,
atol=1e-8,
)
@test isapprox(
Tracker.gradient(foo, 0.5)[1],
ForwardDiff.derivative(foo, 0.5);
rtol=1e-8,
atol=1e-8,
)

bar = p->logpdf(Poisson(1), 3)
bar = p->logpdf(Poisson(p), 3)
@test isapprox(
Tracker.gradient(bar, 0.5)[1],
central_fdm(5, 1)(bar, 0.5);
rtol=1e-8,
atol=1e-8,
)
@test isapprox(
Tracker.gradient(bar, 0.5)[1],
ForwardDiff.derivative(bar, 0.5);
rtol=1e-8,
atol=1e-8,
)
end
6 changes: 3 additions & 3 deletions test/ad.jl/adr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ mvn = collect(Iterators.filter(vn -> vn.sym == :m, keys(vi)))[1]
_s = getval(vi, svn)[1]
_m = getval(vi, mvn)[1]

x = map(_->Float64(_), vi[nothing])
∇E = gradient_reverse(x, vi, ad_test_f)
x = map(x->Float64(x), vi[nothing])
∇E = gradient_reverse(x, vi, ad_test_f)[2]
# println(vi.vns)
# println(∇E)
grad_Turing = sort(∇E)
Expand All @@ -37,7 +37,7 @@ function logp(x::Vector)
# s = invlink(dist_s, s)
m = x[1]
lik_dist = Normal(m, sqrt(s))
lp = logpdf(dist_s, s, false) + logpdf(Normal(0,sqrt(s)), m, false)
lp = logpdf(dist_s, s) + logpdf(Normal(0,sqrt(s)), m)
lp += logpdf(lik_dist, 1.5) + logpdf(lik_dist, 2.0)
lp
end
Expand Down
2 changes: 0 additions & 2 deletions test/ad.jl/skip_tests
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
# tests to skip
adr.jl
AD_compatibility_with_distributions.jl

0 comments on commit 873aaec

Please sign in to comment.