Skip to content

Commit

Permalink
Improve implementation of transformations
Browse files Browse the repository at this point in the history
  • Loading branch information
xukai92 committed Feb 16, 2018
1 parent b311346 commit c4ad9a1
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 67 deletions.
6 changes: 3 additions & 3 deletions benchmarks/lda.run.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ println("Stan time: ", lda_time)
setchunksize(60)
setadbackend(:reverse_diff)

tbenchmark("HMC(2, 0.025, 10)", "ldamodel", "data=ldastandata[1]")
# tbenchmark("HMC(2, 0.025, 10)", "ldamodel", "data=ldastandata[1]")

turnprogress(false)

for (modelc, modeln) in zip([
"ldamodel_vec",
# "ldamodel_vec",
"ldamodel"
], [
"LDA-vec",
# "LDA-vec",
"LDA"
])
tbenchmark("HMC(2, 0.025, 10)", modelc, "data=ldastandata[1]")
Expand Down
14 changes: 14 additions & 0 deletions benchmarks/profile.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using HDF5, JLD, ProfileView

using Turing
# Profiling script for the LDA example model: runs the sampler once briefly, then
# profiles a long run and writes the result to an SVG file.

# Select the :reverse_diff AD backend; turnprogress(false) presumably switches off
# progress output during sampling (NOTE(review): confirm against Turing's docs).
setadbackend(:reverse_diff)
turnprogress(false)

# Load the LDA data (`ldastandata`) and model (`ldamodel`) from the example-models
# directory of the installed Turing package.
include(Pkg.dir("Turing")*"/example-models/stan-models/lda-stan.data.jl")
include(Pkg.dir("Turing")*"/example-models/stan-models/lda.model.jl")

# Short un-profiled run (HMC(2, ...) vs. HMC(2000, ...) below) -- presumably a warm-up
# so that compilation does not show up in the profile; NOTE(review): confirm intent.
sample(ldamodel(data=ldastandata[1]), HMC(2, 0.025, 10))
# Drop any profile samples collected so far, then profile the long run.
Profile.clear()
@profile sample(ldamodel(data=ldastandata[1]), HMC(2000, 0.025, 10))

# Write the collected profile to ldamodel.svg.
ProfileView.svgwrite("ldamodel.svg")
24 changes: 17 additions & 7 deletions example-models/stan-models/lda.model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,29 +39,39 @@
# phi[k] = Vector{Real}(V)
# end

# phi_dot_theta = [Vector{Real}(V) for m=1:M]

@model ldamodel(K, V, M, N, w, doc, beta, alpha) = begin
theta = Vector{Vector{Real}}(M)
for m = 1:M
theta[m] ~ Dirichlet(alpha)
@simd for m = 1:M
@inbounds theta[m] ~ Dirichlet(alpha)
end

phi = Vector{Vector{Real}}(K)
for k = 1:K
phi[k] ~ Dirichlet(beta)
@simd for k = 1:K
@inbounds phi[k] ~ Dirichlet(beta)
end

# z = tzeros(Int, N)
# for n = 1:N
# z[n] ~ Categorical(theta[doc[n]])
# end

phi_dot_theta = [log.([dot(map(p -> p[i], phi), theta[m]) for i = 1:V]) for m=1:M]
phi_dot_theta = [log.([dot(map(p -> p[v], phi), theta[m]) for v = 1:V]) for m = 1:M]
# @simd for m=1:M
# @inbounds phi_dot_theta[m] = log.([dot(map(p -> p[v], phi), theta[m]) for v = 1:V])
# end

#for n = 1:N
# # phi_dot_theta = [dot(map(p -> p[i], phi), theta[doc[n]]) for i = 1:V]
# # phi_dot_theta = [dot(map(p -> p[v], phi), theta[doc[n]]) for v = 1:V]
# # w[n] ~ Categorical(phi_dot_theta)
# Turing.acclogp!(vi, phi_dot_theta[doc[n]][w[n]])
#end
lp = mapreduce(n->phi_dot_theta[doc[n]][w[n]], +, 1:N)
# lp = mapreduce(n->phi_dot_theta[doc[n]][w[n]], +, 1:N)
lp = phi_dot_theta[doc[1]][w[1]]
@simd for n = 2:N
@inbounds lp += phi_dot_theta[doc[n]][w[n]]
end
Turing.acclogp!(vi, lp)

end
Expand Down
58 changes: 37 additions & 21 deletions src/core/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type VarInfo
vns :: Vector{VarName}
ranges :: Vector{UnitRange{Int}}
vals :: Vector{Real}
rs :: Dict{VarName,Any}
rs :: Dict{Union{VarName,Vector{VarName}},Any}
dists :: Vector{Distributions.Distribution}
gids :: Vector{Int}
logp :: Real
Expand All @@ -58,7 +58,7 @@ type VarInfo

VarInfo() = begin
vals = Vector{Real}()
rs = Dict{VarName,Any}()
rs = Dict{Union{VarName,Vector{VarName}},Any}()
logp = zero(Real)
pred = Dict{Symbol,Any}()
flags = Dict{String,Vector{Bool}}()
Expand Down Expand Up @@ -162,35 +162,51 @@ syms(vi::VarInfo) = map(vn -> vn.sym, vns(vi)) # get all symbols
# Return the value stored in `vi` for the variable `vn`.
# The flat values from `getval(vi, vn)` are rebuilt with `reconstruct`/`reconstruct!`
# for the variable's distribution, and `invlink`/`invlink!` is applied on top when
# `istrans(vi, vn)` is true. The assertion fails if `vi` has no entry for `vn`.
Base.getindex(vi::VarInfo, vn::VarName) = begin
@assert haskey(vi, vn) "[Turing] attempted to replay unexisting variables in VarInfo"
dist = getdist(vi, vn)
# if isa(dist, SimplexDistribution) # Reduce memory allocation for distributions with simplex constraints
# if vn in keys(vi.rs)
# r = vi.rs[vn]
# else
# r = reconstruct(dist, getval(vi, vn))
# r_real = similar(r, Real)
# r_real[eachindex(r_real)] = r
# vi.rs[vn] = r_real
# r = r_real
# end
# reconstruct!(r, dist, getval(vi, vn))
# istrans(vi, vn) ?
# invlink!(r, dist, r) :
# r
# else
if isa(dist, SimplexDistribution) # Reduce memory allocation for distributions with simplex constraints
if vn in keys(vi.rs)
# A buffer for this variable is already cached in vi.rs: refill it from the
# current values instead of allocating a new array.
r = vi.rs[vn]
reconstruct!(r, dist, getval(vi, vn))
else
# First access: build the value, copy it into a Real-typed array of the same
# shape, and cache that array in vi.rs for later calls.
r = reconstruct(dist, getval(vi, vn))
r_real = similar(r, Real)
r_real[:] = r
vi.rs[vn] = r_real
r = r_real
end
# invlink! receives r as both output and input, so the cached buffer itself is
# overwritten and returned.
# NOTE(review): the returned array is the same object held in vi.rs, so a later
# getindex for the same vn will overwrite it -- confirm callers do not keep it.
istrans(vi, vn) ?
invlink!(r, dist, r) :
r
else
# Other distributions: allocate a fresh value on every call.
istrans(vi, vn) ?
invlink(dist, reconstruct(dist, getval(vi, vn))) :
reconstruct(dist, getval(vi, vn))
# end
end
end

Base.setindex!(vi::VarInfo, val::Any, vn::VarName) = setval!(vi, val, vn)

Base.getindex(vi::VarInfo, vns::Vector{VarName}) = begin
@assert haskey(vi, vns[1]) "[Turing] attempted to replay unexisting variables in VarInfo"
dist = getdist(vi, vns[1])
istrans(vi, vns[1]) ?
invlink(dist, reconstruct(dist, getval(vi, vns), length(vns))) :
reconstruct(dist, getval(vi, vns), length(vns))
# if isa(dist, SimplexDistribution) # Reduce memory allocation for distributions with simplex constraints
# if vns in keys(vi.rs)
# r = vi.rs[vns]
# reconstruct!(r, dist, getval(vi, vn))
# else
# r = reconstruct(dist, getval(vi, vns), length(vns))
# r_real = similar(r, Real)
# r_real[:] = r
# vi.rs[vns] = r_real
# r = r_real
# end
# istrans(vi, vns[1]) ?
# invlink!(r, dist, r) :
# r
# else
istrans(vi, vns[1]) ?
invlink(dist, reconstruct(dist, getval(vi, vns), length(vns))) :
reconstruct(dist, getval(vi, vns), length(vns))
# end
end

# NOTE: vi[vview] will just return what is inside vi (no transformations applied)
Expand Down
4 changes: 2 additions & 2 deletions src/helper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
@inline reconstruct(d::MatrixDistribution, val::Union{Vector,SubArray}, T::Type) = Array{T, 2}(reshape(val, size(d)...))

@inline reconstruct!(r, d::Distribution, val::Union{Vector,SubArray}) = reconstruct!(r, d, val, typeof(val[1]))
@inline reconstruct!(r, d::MultivariateDistribution, val::Union{Vector,SubArray}, T::Type) = (r[eachindex(r)] = val; r)
@inline reconstruct!(r, d::MultivariateDistribution, val::Union{Vector,SubArray}, T::Type) = (r[:] = val; r)


@inline reconstruct(d::Distribution, val::Union{Vector,SubArray}, n::Int) = reconstruct(d, val, typeof(val[1]), n)
Expand All @@ -55,4 +55,4 @@
end

@inline reconstruct!(r, d::Distribution, val::Union{Vector,SubArray}, n::Int) = reconstruct!(r, d, val, typeof(val[1]), n)
@inline reconstruct!(r, d::MultivariateDistribution, val::Union{Vector,SubArray}, T::Type, n::Int) = (r[eachindex(r)] = val; r)
@inline reconstruct!(r, d::MultivariateDistribution, val::Union{Vector,SubArray}, T::Type, n::Int) = (r[:] = val; r)
1 change: 0 additions & 1 deletion src/samplers/support/helper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
value = Dict{Symbol, Any}() # value is named here because of Sample has a field called value
for vn in keys(vi)
value[sym(vn)] = realpart(vi[vn])
# value[sym(vn)] = realpart(vi[vn])
end

# NOTE: do we need to check if lp is 0?
Expand Down
97 changes: 64 additions & 33 deletions src/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,20 +133,29 @@ link{T}(d::SimplexDistribution, x::Vector{T}) = link!(similar(x), d, x)
link!{T}(y, d::SimplexDistribution, x::Vector{T}) = begin
K = length(x)

key = (:cache_vec, T, K - 1)
if key in keys(TRANS_CACHE)
z = TRANS_CACHE[key]
else
z = Vector{T}(K - 1)
TRANS_CACHE[key] = z
end

for k in 1:K-1
z[k] = x[k] / (one(T) - sum(x[1:k-1]))
end

@simd for k = 1:K-1
@inbounds y[k] = logit(z[k]) - log(one(T) / (K-k))
# key = (:cache_vec, T, K - 1)
# if key in keys(TRANS_CACHE)
# z = TRANS_CACHE[key]
# else
# z = Vector{T}(K - 1)
# TRANS_CACHE[key] = z
# end

# for k in 1:K-1
# z[k] = x[k] / (one(T) - sum(x[1:k-1]))
# end

# @simd for k = 1:K-1
# @inbounds y[k] = logit(z[k]) - log(one(T) / (K-k))
# end

sum_tmp = zero(T)
z = x[1]
y[1] = logit(z) - log(one(T) / (K-1))
@simd for k in 2:K-1
@inbounds sum_tmp += x[k-1]
@inbounds z = x[k] / (one(T) - sum_tmp)
@inbounds y[k] = logit(z) - log(one(T) / (K-k))
end

y
Expand Down Expand Up @@ -179,24 +188,28 @@ end
invlink{T}(d::SimplexDistribution, y::Vector{T}) = invlink!(similar(y), d, y)
invlink!{T}(x, d::SimplexDistribution, y::Vector{T}) = begin
K = length(y)

key = (:cache_vec, T, K - 1)
if key in keys(TRANS_CACHE)
z = TRANS_CACHE[key]
else
z = Vector{T}(K - 1)
TRANS_CACHE[key] = z
end

@simd for k = 1:K-1
@inbounds z[k] = invlogit(y[k] + log(one(T) / (K - k)))
end
# @simd for k = 1:K-1
# @inbounds z[k] = invlogit(y[k] + log(one(T) / (K - k)))
# end

for k in 1:K-1
x[k] = (one(T) - sum(x[1:k-1])) * z[k]
# for k in 1:K-1
# x[k] = (one(T) - sum(x[1:k-1])) * z[k]
# end
# x[K] = one(T) - sum(x[1:K-1])

z = invlogit(y[1] + log(one(T) / (K - 1)))
x[1] = z
sum_tmp = zero(T)
@simd for k = 2:K-1
@inbounds z = invlogit(y[k] + log(one(T) / (K - k)))
@inbounds sum_tmp += x[k-1]
@inbounds x[k] = (one(T) - sum_tmp) * z
end
x[K] = one(T) - sum(x[1:K-1])
@assert sum(x) == 1 "[Turing] invlink of simplex distribution invalid"
sum_tmp += x[K-1]
x[K] = one(T) - sum_tmp

# @assert sum(x) == 1 "[Turing.invlink!] simplex distribution invalid; x = $x"
x
end

Expand All @@ -222,18 +235,36 @@ invlink!{T<:Real}(X, d::SimplexDistribution, Y::Matrix{T}) = begin
end
X[K,:] = one(T) - sum(X[1:K-1,:], 1)

# X[1,:] = Z[1,:]'
# sum_tmp = 0
# for k = 2:K-1
# sum_tmp += X[k-1,:]
# X[k,:] = (one(T) - sum_tmp') .* Z[k,:]
# end
# sum_tmp += X[K-1,:]
# X[K,:] = one(T) - sum_tmp'

X
end

logpdf_with_trans{T}(d::SimplexDistribution, x::Vector{T}, transform::Bool) = begin
lp = logpdf(d, x)
if transform
K = length(x)
z = Vector{T}(K-1)
@simd for k in 1:K-1
@inbounds z[k] = x[k] / (one(T) - sum(x[1:k-1]))

# @simd for k in 1:K-1
# @inbounds z[k] = x[k] / (one(T) - sum(x[1:k-1]))
# end
# lp += sum([log(z[k]) + log(one(T) - z[k]) + log(one(T) - sum(x[1:k-1])) for k in 1:K-1])

sum_tmp = zero(T)
z = x[1]
lp += log(z) + log(one(T) - z)
@simd for k in 2:K-1
@inbounds sum_tmp += x[k-1]
@inbounds z = x[k] / (one(T) - sum_tmp)
@inbounds lp += log(z) + log(one(T) - z) + log(one(T) - sum_tmp)
end
lp += sum([log(z[k]) + log(one(T) - z[k]) + log(one(T) - sum(x[1:k-1])) for k in 1:K-1])
end
lp
end
Expand Down

0 comments on commit c4ad9a1

Please sign in to comment.