Skip to content

Commit

Permalink
ad workaround for nbinomlogpdf
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamed82008 committed Jan 27, 2019
1 parent aed8f5c commit 7b2d934
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/core/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,12 @@ Tracker.@grad function binomlogpdf(n::Int, p::Tracker.TrackedReal, x::Int)
Δ->(nothing, Δ * (x / p - (n - x) / (1 - p)), nothing)
end

import StatsFuns: nbinomlogpdf
nbinomlogpdf(n::Tracker.TrackedReal, p::Tracker.TrackedReal, x::Int) = Tracker.track(nbinomlogpdf, n, p, x)
Tracker.@grad function nbinomlogpdf(r::Tracker.TrackedReal, p::Tracker.TrackedReal, k::Int)
return nbinomlogpdf(Tracker.data(r), Tracker.data(p), k),
Δ->* (sum(1 / (k + r - i) for i in 1:k) + log(1 - p)), Δ * (-r / (1 - p) + k / p), nothing)
end

import StatsFuns: poislogpdf
poislogpdf(v::Tracker.TrackedReal, x::Int) = Tracker.track(poislogpdf, v, x)
Expand All @@ -195,6 +201,17 @@ function binomlogpdf(n::Int, p::ForwardDiff.Dual{T}, x::Int) where {T}
return FD(binomlogpdf(n, val, x), Δ * (x / val - (n - x) / (1 - val)))
end

function nbinomlogpdf(r::ForwardDiff.Dual{T}, p::ForwardDiff.Dual{T}, k::Int) where {T}
FD = ForwardDiff.Dual{T}
val_p = ForwardDiff.value(p)
val_r = ForwardDiff.value(r)

Δ_p = ForwardDiff.partials(p) * (-val_r / (1 - val_p) + k / val_p)
Δ_r = ForwardDiff.partials(r) * (sum(1 / (k + val_r - i) for i in 1:k) + log(1 - val_p))
Δ = Δ_p + Δ_r
return FD(nbinomlogpdf(val_r, val_p, k), Δ)
end

function poislogpdf(v::ForwardDiff.Dual{T}, x::Int) where {T}
FD = ForwardDiff.Dual{T}
val = ForwardDiff.value(v)
Expand Down

0 comments on commit 7b2d934

Please sign in to comment.