diff --git a/src/core/ad.jl b/src/core/ad.jl index 292ff4a91..88c468c6a 100644 --- a/src/core/ad.jl +++ b/src/core/ad.jl @@ -134,6 +134,13 @@ Tracker.@grad function poislogpdf(v::Tracker.TrackedReal, x::Int) Δ->(Δ * (x/v - 1), nothing) end +function binomlogpdf(n::Int, p::ForwardDiff.Dual{T}, x::Int) where {T} + FD = ForwardDiff.Dual{T} + val = ForwardDiff.value(p) + Δ = ForwardDiff.partials(p) + return FD(binomlogpdf(n, val, x), Δ * (x / val - (n - x) / (1 - val))) +end + function poislogpdf(v::ForwardDiff.Dual{T}, x::Int) where {T} FD = ForwardDiff.Dual{T} val = ForwardDiff.value(v) diff --git a/test/ad.jl/AD_compatibility_with_distributions.jl b/test/ad.jl/AD_compatibility_with_distributions.jl index d65a5b985..9b3e3b7fe 100644 --- a/test/ad.jl/AD_compatibility_with_distributions.jl +++ b/test/ad.jl/AD_compatibility_with_distributions.jl @@ -37,6 +37,12 @@ let rtol=1e-8, atol=1e-8, ) + @test isapprox( + Tracker.gradient(foo, 0.5)[1], + ForwardDiff.derivative(foo, 0.5); + rtol=1e-8, + atol=1e-8, + ) bar = p->logpdf(Binomial(10, p), 3) @test isapprox( @@ -45,22 +51,40 @@ let rtol=1e-8, atol=1e-8, ) + @test isapprox( + Tracker.gradient(bar, 0.5)[1], + ForwardDiff.derivative(bar, 0.5), + rtol=1e-8, + atol=1e-8, + ) end let - foo = p->poislogpdf(1, p) + foo = p->Turing.poislogpdf(p, 1) @test isapprox( Tracker.gradient(foo, 0.5)[1], central_fdm(5, 1)(foo, 0.5); rtol=1e-8, atol=1e-8, ) + @test isapprox( + Tracker.gradient(foo, 0.5)[1], + ForwardDiff.derivative(foo, 0.5); + rtol=1e-8, + atol=1e-8, + ) - bar = p->logpdf(Poisson(1), 3) + bar = p->logpdf(Poisson(p), 3) @test isapprox( Tracker.gradient(bar, 0.5)[1], central_fdm(5, 1)(bar, 0.5); rtol=1e-8, atol=1e-8, ) + @test isapprox( + Tracker.gradient(bar, 0.5)[1], + ForwardDiff.derivative(bar, 0.5); + rtol=1e-8, + atol=1e-8, + ) end diff --git a/test/ad.jl/adr.jl b/test/ad.jl/adr.jl index 6024428a2..7f7819ced 100644 --- a/test/ad.jl/adr.jl +++ b/test/ad.jl/adr.jl @@ -23,8 +23,8 @@ mvn = collect(Iterators.filter(vn -> vn.sym == :m, keys(vi)))[1] _s = getval(vi, svn)[1] _m = getval(vi, mvn)[1] -x = map(_->Float64(_), vi[nothing]) -∇E = gradient_reverse(x, vi, ad_test_f) +x = map(x->Float64(x), vi[nothing]) +∇E = gradient_reverse(x, vi, ad_test_f)[2] # println(vi.vns) # println(∇E) grad_Turing = sort(∇E) @@ -37,7 +37,7 @@ function logp(x::Vector) # s = invlink(dist_s, s) m = x[1] lik_dist = Normal(m, sqrt(s)) - lp = logpdf(dist_s, s, false) + logpdf(Normal(0,sqrt(s)), m, false) + lp = logpdf(dist_s, s) + logpdf(Normal(0,sqrt(s)), m) lp += logpdf(lik_dist, 1.5) + logpdf(lik_dist, 2.0) lp end diff --git a/test/ad.jl/skip_tests b/test/ad.jl/skip_tests index fac158d71..3d088bb70 100644 --- a/test/ad.jl/skip_tests +++ b/test/ad.jl/skip_tests @@ -1,3 +1 @@ # tests to skip -adr.jl -AD_compatibility_with_distributions.jl