-
Notifications
You must be signed in to change notification settings - Fork 220
/
Copy pathad.jl
248 lines (213 loc) · 7.93 KB
/
ad.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
##############################
# Global variables/constants #
##############################
# Currently active AD backend: `:forward_diff` (ForwardDiff.jl) or
# `:reverse_diff` (Tracker-based reverse mode).
const ADBACKEND = Ref(:forward_diff)

"""
    setadbackend(backend_sym)

Set the global AD backend to `backend_sym` (`:forward_diff` or `:reverse_diff`).

When switching to `:forward_diff` while the global chunk size is unset
(`CHUNKSIZE[] == 0`), a default chunk size of 40 is installed via `setchunksize`.

Throws an `ArgumentError` for any other symbol.
"""
function setadbackend(backend_sym)
    # Validate explicitly instead of `@assert`: assertions are for internal
    # invariants and may be disabled at higher optimization levels.
    backend_sym in (:forward_diff, :reverse_diff) ||
        throw(ArgumentError("unsupported AD backend: $backend_sym"))
    backend_sym == :forward_diff && CHUNKSIZE[] == 0 && setchunksize(40)
    ADBACKEND[] = backend_sym
end
# Global flag enabling extra AD safety checks.
const ADSAFE = Ref(false)

"""
    setadsafe(switch::Bool)

Set the global `ADSAFE` flag to `switch`, logging the new value.
"""
function setadsafe(switch::Bool)
    ADSAFE[] = switch
    @info("[Turing]: global ADSAFE is set as $switch")
    return switch
end
# Default chunk size used by ForwardDiff-based AD.
const CHUNKSIZE = Ref(40)

"""
    setchunksize(chunk_size::Int)

Set the global ForwardDiff chunk size to `chunk_size`, logging only when the
value actually changes.
"""
function setchunksize(chunk_size::Int)
    # `!=` replaces the original `~(a == b)`: bitwise NOT happens to work on
    # Bool but is not the idiomatic boolean negation.
    if CHUNKSIZE[] != chunk_size
        @info("[Turing]: AD chunk size is set as $chunk_size")
        CHUNKSIZE[] = chunk_size
    end
end
# Interface type for selecting an automatic-differentiation backend.
abstract type ADBackend end
# Forward-mode backend; the chunk size is carried in the type parameter so it
# can be recovered from a type alone, without an instance.
struct ForwardDiffAD{chunk} <: ADBackend end
# Recover the chunk size from a `ForwardDiffAD` instance or type.
getchunksize(::T) where {T <: ForwardDiffAD} = getchunksize(T)
getchunksize(::Type{ForwardDiffAD{chunk}}) where chunk = chunk
# Delegate from a sampler (or its type) to its type parameter — presumably
# the algorithm type, which in turn resolves to a backend's chunk size.
getchunksize(::T) where {T <: Sampler} = getchunksize(T)
getchunksize(::Type{<:Sampler{T}}) where {T} = getchunksize(T)
# No sampler given: fall back to the global default chunk size.
getchunksize(::Nothing) = getchunksize(Nothing)
getchunksize(::Type{Nothing}) = CHUNKSIZE[]
# Reverse-mode backend backed by Flux's Tracker.
struct FluxTrackerAD <: ADBackend end

# Resolve the currently configured backend symbol to a backend type.
ADBackend() = ADBackend(ADBACKEND[])
ADBackend(T::Symbol) = ADBackend(Val(T))
# Dispatch on the lifted symbol: `:forward_diff` picks ForwardDiff with the
# current global chunk size; anything else falls through to Tracker.
ADBackend(::Val{:forward_diff}) = ForwardDiffAD{CHUNKSIZE[]}
ADBackend(::Val) = FluxTrackerAD
"""
getADtype(alg)
Finds the autodifferentiation type of the algorithm `alg`.
"""
getADtype(::Nothing) = getADtype(Nothing)
getADtype(::Type{Nothing}) = getADtype()
getADtype() = ADBackend()
getADtype(s::Sampler) = getADtype(typeof(s))
getADtype(s::Type{<:Sampler{TAlg}}) where {TAlg} = getADtype(TAlg)
"""
gradient(
θ::AbstractVector{<:Real},
vi::VarInfo,
model::Model,
sampler::Union{Nothing, Sampler}=nothing,
)
Computes the gradient of the log joint of `θ` for the model specified by
`(vi, sampler, model)` using whichever automatic differentation tool is currently active.
"""
function gradient(
θ::AbstractVector{<:Real},
vi::VarInfo,
model::Model,
sampler::TS,
) where {TS <: Sampler}
ad_type = getADtype(TS)
if ad_type <: ForwardDiffAD
return gradient_forward(θ, vi, model, sampler)
else ad_type <: FluxTrackerAD
return gradient_reverse(θ, vi, model, sampler)
end
end
"""
gradient_forward(
θ::AbstractVector{<:Real},
vi::VarInfo,
model::Model,
spl::Union{Nothing, Sampler}=nothing,
)
Computes the gradient of the log joint of `θ` for the model specified by `(vi, spl, model)`
using forwards-mode AD from ForwardDiff.jl.
"""
function gradient_forward(
θ::AbstractVector{<:Real},
vi::VarInfo,
model::Model,
sampler::Union{Nothing, Sampler}=nothing,
)
# Record old parameters.
vals_old, logp_old = copy(vi.vals), copy(vi.logp)
# Define function to compute log joint.
function f(θ)
vi[sampler] = θ
return -runmodel!(model, vi, sampler).logp
end
chunk_size = getchunksize(sampler)
# Set chunk size and do ForwardMode.
chunk = ForwardDiff.Chunk(min(length(θ), chunk_size))
config = ForwardDiff.GradientConfig(f, θ, chunk)
∂l∂θ = ForwardDiff.gradient!(similar(θ), f, θ, config)
l = vi.logp.value
# Replace old parameters to ensure this function doesn't mutate `vi`.
vi.vals, vi.logp = vals_old, logp_old
# Strip tracking info from θ to avoid mutating it.
θ .= ForwardDiff.value.(θ)
return l, ∂l∂θ
end
"""
gradient_reverse(
θ::AbstractVector{<:Real},
vi::VarInfo,
model::Model,
sampler::Union{Nothing, Sampler}=nothing,
)
Computes the gradient of the log joint of `θ` for the model specified by
`(vi, sampler, model)` using reverse-mode AD from Flux.jl.
"""
function gradient_reverse(
θ::AbstractVector{<:Real},
vi::VarInfo,
model::Model,
sampler::Union{Nothing, Sampler}=nothing,
)
vals_old, logp_old = copy(vi.vals), copy(vi.logp)
# Specify objective function.
function f(θ)
vi[sampler] = θ
return -runmodel!(model, vi, sampler).logp
end
# Compute forward and reverse passes.
l_tracked, ȳ = Tracker.forward(f, θ)
l, ∂l∂θ = Tracker.data(l_tracked), Tracker.data(ȳ(1)[1])
# Remove tracking info from variables in model (because mutable state).
vi.vals, vi.logp = vals_old, logp_old
# Strip tracking info from θ to avoid mutating it.
θ .= Tracker.data.(θ)
# Return non-tracked gradient value
return l, ∂l∂θ
end
"""
    verifygrad(grad::AbstractVector{<:Real})

Return `true` when every entry of `grad` is a finite number; otherwise emit
warnings describing the numerical error and return `false`.
"""
function verifygrad(grad::AbstractVector{<:Real})
    numerical_error = any(g -> isnan(g) || isinf(g), grad)
    if numerical_error
        @warn("Numerical error has been found in gradients.")
        @warn("grad = $(grad)")
    end
    return !numerical_error
end
import StatsFuns: binomlogpdf
# Track `p` so Tracker records the binomial logpdf call on the tape.
binomlogpdf(n::Int, p::Tracker.TrackedReal, x::Int) = Tracker.track(binomlogpdf, n, p, x)
Tracker.@grad function binomlogpdf(n::Int, p::Tracker.TrackedReal, x::Int)
# Forward value on untracked data; pullback uses the analytic derivative
# ∂/∂p log[C(n,x) p^x (1-p)^(n-x)] = x/p - (n-x)/(1-p). `n` and `x` are
# integers and get no gradient.
# NOTE(review): the pullback uses the tracked `p` rather than
# `Tracker.data(p)`, so the returned gradient may itself be tracked —
# confirm this is intended.
return binomlogpdf(n, Tracker.data(p), x),
Δ->(nothing, Δ * (x / p - (n - x) / (1 - p)), nothing)
end
import StatsFuns: nbinomlogpdf
# Note: the definition of NegativeBinomial in Julia differs from Wikipedia's.
# See the docstring of NegativeBinomial: `r` is the number of successes and
# `k` is the number of failures.
# Gradient of the negative-binomial logpdf w.r.t. `r` (number of successes),
# using the finite-sum form of ψ(r + k) - ψ(r) plus log(p).
# Fixed: for k == 0 the original `sum` over an empty generator threw an
# ArgumentError; zero failures is a valid observation, so return just log(p).
_nbinomlogpdf_grad_1(r, p, k) = k == 0 ? log(p) : sum(1 / (k + r - i) for i in 1:k) + log(p)
# Gradient of the negative-binomial logpdf w.r.t. `p` (success probability).
_nbinomlogpdf_grad_2(r, p, k) = -k / (1 - p) + r / p
# Track whichever of the first two arguments is a TrackedReal so Tracker
# records the nbinomlogpdf call; the matching @grad methods below supply
# the pullbacks.
nbinomlogpdf(n::Tracker.TrackedReal, p::Tracker.TrackedReal, x::Int) = Tracker.track(nbinomlogpdf, n, p, x)
nbinomlogpdf(n::Real, p::Tracker.TrackedReal, x::Int) = Tracker.track(nbinomlogpdf, n, p, x)
nbinomlogpdf(n::Tracker.TrackedReal, p::Real, x::Int) = Tracker.track(nbinomlogpdf, n, p, x)
# Tracker pullbacks for nbinomlogpdf: forward values are computed on
# untracked data, gradients come from the analytic helpers above, and the
# integer count `k` gets no gradient (`nothing`).
# NOTE(review): the pullbacks close over the tracked `r`/`p` rather than
# their untracked data, so the returned gradients may themselves be
# tracked — confirm this is intended.
Tracker.@grad function nbinomlogpdf(r::Tracker.TrackedReal, p::Tracker.TrackedReal, k::Int)
return nbinomlogpdf(Tracker.data(r), Tracker.data(p), k),
Δ->(Δ * _nbinomlogpdf_grad_1(r, p, k), Δ * _nbinomlogpdf_grad_2(r, p, k), nothing)
end
Tracker.@grad function nbinomlogpdf(r::Real, p::Tracker.TrackedReal, k::Int)
return nbinomlogpdf(Tracker.data(r), Tracker.data(p), k),
# `r` is untracked here: return a zero of matching type for its slot.
Δ->(Tracker._zero(r), Δ * _nbinomlogpdf_grad_2(r, p, k), nothing)
end
Tracker.@grad function nbinomlogpdf(r::Tracker.TrackedReal, p::Real, k::Int)
return nbinomlogpdf(Tracker.data(r), Tracker.data(p), k),
# `p` is untracked here: return a zero of matching type for its slot.
Δ->(Δ * _nbinomlogpdf_grad_1(r, p, k), Tracker._zero(p), nothing)
end
import StatsFuns: poislogpdf
# Track `v` (the Poisson rate) so Tracker records the call on the tape.
poislogpdf(v::Tracker.TrackedReal, x::Int) = Tracker.track(poislogpdf, v, x)
Tracker.@grad function poislogpdf(v::Tracker.TrackedReal, x::Int)
# Analytic derivative: ∂/∂v [x log v - v - log x!] = x/v - 1; the integer
# count `x` gets no gradient.
# NOTE(review): the pullback uses the tracked `v` rather than
# `Tracker.data(v)`, so the gradient may itself be tracked — confirm.
return poislogpdf(Tracker.data(v), x),
Δ->(Δ * (x/v - 1), nothing)
end
# Forward-mode rule for the binomial logpdf: propagate the dual's partials
# through the analytic derivative ∂/∂p = x/p - (n - x)/(1 - p).
function binomlogpdf(n::Int, p::ForwardDiff.Dual{T}, x::Int) where {T}
    p_val = ForwardDiff.value(p)
    dp = ForwardDiff.partials(p) * (x / p_val - (n - x) / (1 - p_val))
    return ForwardDiff.Dual{T}(binomlogpdf(n, p_val, x), dp)
end
# Forward-mode rule when both `r` and `p` carry partials: combine the two
# chain-rule contributions from the analytic gradient helpers.
function nbinomlogpdf(r::ForwardDiff.Dual{T}, p::ForwardDiff.Dual{T}, k::Int) where {T}
    rv, pv = ForwardDiff.value(r), ForwardDiff.value(p)
    partials = ForwardDiff.partials(r) * _nbinomlogpdf_grad_1(rv, pv, k) +
               ForwardDiff.partials(p) * _nbinomlogpdf_grad_2(rv, pv, k)
    return ForwardDiff.Dual{T}(nbinomlogpdf(rv, pv, k), partials)
end
# Forward-mode rule when only `p` carries partials: scale them by the
# analytic ∂/∂p of the logpdf.
function nbinomlogpdf(r::Real, p::ForwardDiff.Dual{T}, k::Int) where {T}
    pv = ForwardDiff.value(p)
    return ForwardDiff.Dual{T}(
        nbinomlogpdf(r, pv, k),
        ForwardDiff.partials(p) * _nbinomlogpdf_grad_2(r, pv, k),
    )
end
# Forward-mode rule when only `r` carries partials: scale them by the
# analytic ∂/∂r of the logpdf.
function nbinomlogpdf(r::ForwardDiff.Dual{T}, p::Real, k::Int) where {T}
    rv = ForwardDiff.value(r)
    return ForwardDiff.Dual{T}(
        nbinomlogpdf(rv, p, k),
        ForwardDiff.partials(r) * _nbinomlogpdf_grad_1(rv, p, k),
    )
end
# Forward-mode rule for the Poisson logpdf: ∂/∂v poislogpdf(v, x) = x/v - 1.
function poislogpdf(v::ForwardDiff.Dual{T}, x::Int) where {T}
    vv = ForwardDiff.value(v)
    return ForwardDiff.Dual{T}(poislogpdf(vv, x), ForwardDiff.partials(v) * (x / vv - 1))
end