Commit d641f42

Resolve slowness caused by use of vi.logp

xukai92 committed Feb 19, 2018
1 parent 78b6ed6 commit d641f42
Showing 9 changed files with 67 additions and 36 deletions.
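
The central change: instead of accumulating the log-density into the abstractly typed vi.logp field on every ~ statement (via acclogp!), the generated model body keeps a local _lp accumulator and writes it back once at the end. Below is a minimal sketch of why that can matter, assuming the slowdown comes from repeatedly reading and writing a Real-typed field under reverse-mode AD; Info, f_field, and f_local are illustrative names, not Turing code.

using ReverseDiff

mutable struct Info
    logp::Real                  # abstractly typed, like VarInfo's logp field
end

# Accumulate through the field: every += round-trips the tracked value
# through an abstract slot.
function f_field(x)
    vi = Info(0)
    for xi in x
        vi.logp += xi^2
    end
    vi.logp
end

# Accumulate in a local and return it once, as the commit does with _lp.
function f_local(x)
    lp = zero(Real)
    for xi in x
        lp += xi^2
    end
    lp
end

x = randn(100)
ReverseDiff.gradient(f_field, x) ≈ ReverseDiff.gradient(f_local, x)  # true; same gradient, without the field round-trips
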
18 changes: 12 additions & 6 deletions example-models/stan-models/lda.model.jl
@@ -57,7 +57,8 @@
# z[n] ~ Categorical(theta[doc[n]])
# end

phi_dot_theta = [log.([dot(map(p -> p[v], phi), theta[m]) for v = 1:V]) for m = 1:M]
# phi_dot_theta = [log.([dot(map(p -> p[v], phi), theta[m]) for v = 1:V]) for m = 1:M]

# @simd for m=1:M
# @inbounds phi_dot_theta[m] = log.([dot(map(p -> p[v], phi), theta[m]) for v = 1:V])
# end
@@ -67,13 +68,18 @@
# # w[n] ~ Categorical(phi_dot_theta)
# Turing.acclogp!(vi, phi_dot_theta[doc[n]][w[n]])
#end

# lp = mapreduce(n->phi_dot_theta[doc[n]][w[n]], +, 1:N)
lp = phi_dot_theta[doc[1]][w[1]]
@simd for n = 2:N
@inbounds lp += phi_dot_theta[doc[n]][w[n]]
end
Turing.acclogp!(vi, lp)

# lp = phi_dot_theta[doc[1]][w[1]]
# @simd for n = 2:N
# @inbounds lp += phi_dot_theta[doc[n]][w[n]]
# end

phi_dot_theta = log.(hcat(phi...) * hcat(theta...))
_lp += mapreduce(n->phi_dot_theta[w[n], doc[n]], +, 1:N)

# Turing.acclogp!(vi, lp)
end
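
The replacement computes every document-word log-probability with one matrix multiply instead of V×M scalar dot products: hcat(phi...) is V×K, hcat(theta...) is K×M, and their product holds dot(map(p -> p[v], phi), theta[m]) at entry (v, m). A quick equivalence check with made-up sizes (a sketch, assuming Julia 1.x, where dot lives in LinearAlgebra):

using LinearAlgebra

K, V, M = 3, 5, 4                  # topics, vocabulary size, documents
phi   = [rand(V) for _ in 1:K]     # per-topic word weights
theta = [rand(K) for _ in 1:M]     # per-document topic weights

old = [log.([dot([p[v] for p in phi], theta[m]) for v in 1:V]) for m in 1:M]
new = log.(hcat(phi...) * hcat(theta...))   # V×M, one BLAS call

all(new[v, m] ≈ old[m][v] for v in 1:V, m in 1:M)   # true
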


4 changes: 2 additions & 2 deletions src/core/ad.jl
@@ -117,7 +117,7 @@ gradient_r(theta::Vector, vi::Turing.VarInfo, model::Function, spl::Union{Void,
for i = 1:length(ipts)
vi_spl[i] = ipts[i]
end
vi.logp = 0
# vi.logp = 0
-runmodel(model, vi, spl).logp
end
gtape = GradientTape(f_r, inputs)
@@ -138,7 +138,7 @@ gradient_r(theta::Vector, vi::Turing.VarInfo, model::Function, spl::Union{Void,
# grad = ReverseDiff.gradient(x -> (vi[spl] = x; -runmodel(model, vi, spl).logp), inputs)

# vi[spl] = realpart(vi[spl])
vi.logp = 0
# vi.logp = 0

grad
end
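
With the model body now ending in an unconditional vi.logp = _lp, zeroing the field before or after a gradient evaluation is redundant: every run overwrites whatever was there. A toy analogue of the closure's shape (names are illustrative, not Turing's API):

mutable struct ToyVI
    logp::Real
end

run_toy!(vi, x) = (vi.logp = sum(abs2, x); vi)   # overwrites, never accumulates

vi = ToyVI(123.0)              # a stale logp is harmless now
f(x) = -run_toy!(vi, x).logp   # same shape as -runmodel(model, vi, spl).logp
f([1.0, 2.0])                  # -5.0; the stale 123.0 was simply replaced
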
17 changes: 11 additions & 6 deletions src/core/compiler.jl
@@ -40,7 +40,7 @@ macro ~(left, right)
# Call observe
esc(
quote
Turing.observe(
_lp += Turing.observe(
sampler,
$(right), # Distribution
$(left), # Data point
@@ -60,7 +60,7 @@
esc(
quote
# Call observe
Turing.observe(
_lp += Turing.observe(
sampler,
$(right), # Distribution
$(left), # Data point
@@ -86,20 +86,22 @@
assume_ex = quote
vn = Turing.VarName(vi, $syms, "")
if isa($(right), Vector)
$(left) = Turing.assume(
$(left), __lp = Turing.assume(
sampler,
$(right), # dist
vn, # VarName
$(left),
vi # VarInfo
)
_lp += __lp
else
$(left) = Turing.assume(
$(left), __lp = Turing.assume(
sampler,
$(right), # dist
vn, # VarName
vi # VarInfo
)
_lp += __lp
end
end
else
@@ -108,12 +110,13 @@
csym_str = string(Turing._compiler_[:fname]) * string(@__LINE__)
indexing = reduce(*, "", map(idx -> string(idx), idcs))
vn = Turing.VarName(vi, Symbol(csym_str), sym, indexing)
$left = Turing.assume(
$(left), __lp = Turing.assume(
sampler,
$right, # dist
vn, # VarName
vi # VarInfo
)
_lp += __lp
end
end
esc(assume_ex)
@@ -208,7 +211,7 @@ macro model(fexpr)

# Modify fbody, so that we always return VarInfo
fbody_inner = deepcopy(fbody)

return_ex = fbody.args[end] # get last statement of defined model
if typeof(return_ex) == Symbol
pop!(fbody_inner.args)
@@ -232,6 +235,8 @@ macro model(fexpr)
# NOTE: code above is commented out to disable explicit return
end

unshift!(fbody_inner.args, :(_lp = zero(Real)))
push!(fbody_inner.args, :(vi.logp = _lp))
push!(fbody_inner.args, Expr(:return, :vi)) # always return vi in the end of function body

dprintln(1, fbody_inner)
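
Taken together, the compiler changes make @model generate a body of the following shape: _lp = zero(Real) is unshifted to the front, each ~ adds the return value of assume/observe to _lp, and the total is stored once via vi.logp = _lp before vi is returned. A runnable sketch with toy stand-ins for the Turing calls (toy_assume and toy_observe are illustrative, not the real API):

using Distributions

toy_assume(d)         = (r = rand(d); (r, logpdf(d, r)))   # value plus log-prior
toy_observe(d, value) = logpdf(d, value)                    # log-likelihood

function toy_model(y)
    _lp = zero(Real)                          # unshift!'d to the top of the body
    s, __lp = toy_assume(InverseGamma(2, 3))  # s ~ InverseGamma(2, 3)
    _lp += __lp
    m, __lp = toy_assume(Normal(0, sqrt(s)))  # m ~ Normal(0, sqrt(s))
    _lp += __lp
    _lp += toy_observe(Normal(m, sqrt(s)), y) # y ~ Normal(m, sqrt(s))
    _lp                                       # the real macro stores this in vi.logp
end

toy_model(1.5)
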
9 changes: 5 additions & 4 deletions src/samplers/hmc.jl
@@ -168,8 +168,9 @@ assume{T<:Hamiltonian}(spl::Sampler{T}, dist::Distribution, vn::VarName, vi::Var
dprintln(2, "assuming...")
updategid!(vi, vn, spl)
r = vi[vn]
acclogp!(vi, logpdf_with_trans(dist, r, istrans(vi, vn)))
r
# acclogp!(vi, logpdf_with_trans(dist, r, istrans(vi, vn)))
# r
r, logpdf_with_trans(dist, r, istrans(vi, vn))
end

assume{A<:Hamiltonian,D<:Distribution}(spl::Sampler{A}, dists::Vector{D}, vn::VarName, var::Any, vi::VarInfo) = begin
@@ -181,7 +182,7 @@ assume{A<:Hamiltonian,D<:Distribution}(spl::Sampler{A}, dists::Vector{D}, vn::Va

rs = vi[vns] # NOTE: inside Turing the Julia conversion should be adhered to

acclogp!(vi, sum(logpdf_with_trans(dist, rs, istrans(vi, vns[1]))))
# acclogp!(vi, sum(logpdf_with_trans(dist, rs, istrans(vi, vns[1]))))

if isa(dist, UnivariateDistribution) || isa(dist, MatrixDistribution)
@assert size(var) == size(rs) "[assume] variable and random number dimension unmatched"
@@ -200,7 +201,7 @@ assume{A<:Hamiltonian,D<:Distribution}(spl::Sampler{A}, dists::Vector{D}, vn::Va
end
end

var
var, sum(logpdf_with_trans(dist, rs, istrans(vi, vns[1])))
end

observe{A<:Hamiltonian}(spl::Sampler{A}, d::Distribution, value::Any, vi::VarInfo) =
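
Both assume methods above now return the log-prior term instead of accumulating it, and logpdf_with_trans(dist, r, istrans(vi, vn)) supplies that term with the change-of-variables correction included when the value is stored unconstrained. A hand-rolled analogue for a positive variable linked through y = log(x), where the correction is the log-Jacobian log|dx/dy| = y (a sketch of the idea, not Turing's implementation):

using Distributions

d = Gamma(2.0, 1.0)    # support (0, ∞)
x = 1.7
y = log(x)             # unconstrained representation

logpdf(d, x)           # density in constrained space
logpdf(d, exp(y)) + y  # density of the transformed variable, Jacobian included
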
8 changes: 5 additions & 3 deletions src/samplers/is.jl
@@ -53,8 +53,10 @@ end
assume(spl::Sampler{IS}, dist::Distribution, vn::VarName, vi::VarInfo) = begin
r = rand(dist)
push!(vi, vn, r, dist, 0)
r
r, zero(Real)
end

observe(spl::Sampler{IS}, dist::Distribution, value::Any, vi::VarInfo) =
acclogp!(vi, logpdf(dist, value))
observe(spl::Sampler{IS}, dist::Distribution, value::Any, vi::VarInfo) = begin
# acclogp!(vi, logpdf(dist, value))
logpdf(dist, value)
end
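
For importance sampling, assume reports zero(Real): the draw comes from the prior itself, so the prior density cancels between target and proposal, and the importance weight reduces to the likelihood, which is exactly what observe now returns. Schematically (toy values):

using Distributions

prior = Normal(0, 1)
x  = rand(prior)                  # proposal q == prior p, so log p(x) - log q(x) == 0
lw = logpdf(Normal(x, 0.5), 0.7)  # log-weight = log-likelihood of the observation alone
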
4 changes: 2 additions & 2 deletions src/samplers/mh.jl
@@ -213,8 +213,8 @@ assume(spl::Sampler{MH}, dist::Distribution, vn::VarName, vi::VarInfo) = begin
r = vi[vn]
end

acclogp!(vi, logpdf(dist, r)) # accumulate pdf of prior
r
# acclogp!(vi, logpdf(dist, r)) # accumulate pdf of prior
r, logpdf(dist, r)
end

assume{D<:Distribution}(spl::Sampler{MH}, dists::Vector{D}, vn::VarName, var::Any, vi::VarInfo) =
10 changes: 5 additions & 5 deletions src/samplers/pgibbs.jl
@@ -147,17 +147,15 @@ assume{T<:Union{PG,SMC}}(spl::Sampler{T}, dist::Distribution, vn::VarName, _::Va
r = rand(dist)
push!(vi, vn, r, dist, spl.alg.gid)
spl.info[:cache_updated] = CACHERESET # sanity flag mask for getidcs and getranges
r
elseif is_flagged(vi, vn, "del")
unset_flag!(vi, vn, "del")
r = rand(dist)
vi[vn] = vectorize(dist, r)
setgid!(vi, spl.alg.gid, vn)
setorder!(vi, vn, vi.num_produce)
r
else
updategid!(vi, vn, spl)
vi[vn]
r = vi[vn]
end
else # vn belongs to other sampler <=> conditioning on vn
if haskey(vi, vn)
@@ -167,15 +165,17 @@ assume{T<:Union{PG,SMC}}(spl::Sampler{T}, dist::Distribution, vn::VarName, _::Va
push!(vi, vn, r, dist, -1)
end
acclogp!(vi, logpdf_with_trans(dist, r, istrans(vi, vn)))
r
end
r, zero(Real)
end

assume{A<:Union{PG,SMC},D<:Distribution}(spl::Sampler{A}, dists::Vector{D}, vn::VarName, var::Any, vi::VarInfo) =
error("[Turing] PG and SMC doesn't support vectorizing assume statement")

observe{T<:Union{PG,SMC}}(spl::Sampler{T}, dist::Distribution, value, vi) =
observe{T<:Union{PG,SMC}}(spl::Sampler{T}, dist::Distribution, value, vi) = begin
produce(logpdf(dist, value))
zero(Real)
end

observe{A<:Union{PG,SMC},D<:Distribution}(spl::Sampler{A}, ds::Vector{D}, value::Any, vi::VarInfo) =
error("[Turing] PG and SMC doesn't support vectorizing observe statement")
17 changes: 10 additions & 7 deletions src/samplers/sampler.jl
@@ -43,8 +43,8 @@ assume(spl::Void, dist::Distribution, vn::VarName, vi::VarInfo) = begin
end
# NOTE: The importance weight is not correctly computed here because
# r is generated from some uniform distribution which is different from the prior
acclogp!(vi, logpdf_with_trans(dist, r, istrans(vi, vn)))
r
# acclogp!(vi, logpdf_with_trans(dist, r, istrans(vi, vn)))
r, logpdf_with_trans(dist, r, istrans(vi, vn))
end

assume{T<:Distribution}(spl::Void, dists::Vector{T}, vn::VarName, var::Any, vi::VarInfo) = begin
@@ -83,23 +83,26 @@ assume{T<:Distribution}(spl::Void, dists::Vector{T}, vn::VarName, var::Any, vi::
end
end

acclogp!(vi, sum(logpdf_with_trans(dist, rs, istrans(vi, vns[1]))))
# acclogp!(vi, sum(logpdf_with_trans(dist, rs, istrans(vi, vns[1]))))

var
var, sum(logpdf_with_trans(dist, rs, istrans(vi, vns[1])))
end

observe(spl::Void, dist::Distribution, value::Any, vi::VarInfo) = begin
vi.num_produce += 1
acclogp!(vi, logpdf(dist, value))
# acclogp!(vi, logpdf(dist, value))
logpdf(dist, value)
end

observe{T<:Distribution}(spl::Void, dists::Vector{T}, value::Any, vi::VarInfo) = begin
@assert length(dists) == 1 "[observe] Turing only supports vectorizing i.i.d. distributions"
dist = dists[1]
@assert isa(dist, UnivariateDistribution) || isa(dist, MultivariateDistribution) "[observe] vectorizing matrix distribution is not supported"
if isa(dist, UnivariateDistribution) # only univariate distributions support the broadcast operation (logpdf.) in Distributions.jl
acclogp!(vi, sum(logpdf.(dist, value)))
# acclogp!(vi, sum(logpdf.(dist, value)))
sum(logpdf.(dist, value))
else
acclogp!(vi, sum(logpdf(dist, value)))
# acclogp!(vi, sum(logpdf(dist, value)))
sum(logpdf(dist, value))
end
end
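
The two vectorised observe branches differ only in how Distributions.jl exposes pointwise densities: univariate distributions support broadcasting logpdf., while multivariate ones return a vector of per-column log-densities for a matrix of observations. Sketched directly against Distributions (illustrative data):

using Distributions

xs = [0.1, -0.2, 0.3]
sum(logpdf.(Normal(0, 1), xs))               # univariate i.i.d.: broadcast, then sum

d = MvNormal(zeros(2), [1.0 0.0; 0.0 1.0])
X = rand(d, 3)                               # 2×3, one observation per column
sum(logpdf(d, X))                            # multivariate: per-column logpdfs, then sum
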
16 changes: 15 additions & 1 deletion src/samplers/support/hmc_core.jl
@@ -44,7 +44,21 @@ leapfrog(_θ::Union{Vector,SubArray}, p::Vector{Float64}, τ::Int, ϵ::Float64,
elseif ADBACKEND == :reverse_diff
grad = gradient_r(θ, vi, model, spl)
end
verifygrad(grad) || (vi[spl] = θ_old; setlogp!(vi, old_logp); θ = θ_old; p = p_old; break)
# verifygrad(grad) || (vi[spl] = θ_old; setlogp!(vi, old_logp); θ = θ_old; p = p_old; break)
if ~verifygrad(grad)
if ADBACKEND == :forward_diff
vi[spl] = θ_old
elseif ADBACKEND == :reverse_diff
vi_spl = vi[spl]
for i = 1:length(θ_old)
vi_spl[i].value = θ_old[i]
end
end
setlogp!(vi, old_logp)
θ = θ_old
p = p_old
break
end

p -= ϵ * grad / 2

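
The expanded bail-out keeps a bad gradient from corrupting the trajectory: if any component is non-finite, position, momentum, and the cached log-density are restored and integration stops; the reverse-diff branch restores elementwise so the tracked values keep their tape. A minimal stand-in for the check itself (assuming verifygrad tests finiteness, which its use here implies):

verify(grad::AbstractVector) = all(isfinite, grad)

verify([0.3, -1.2])   # true  -> take the leapfrog step
verify([0.3, NaN])    # false -> restore θ_old, p_old, old_logp and break
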