diff --git a/benchmarks/lda.run.jl b/benchmarks/lda.run.jl index c229d049d..2f0f839a2 100644 --- a/benchmarks/lda.run.jl +++ b/benchmarks/lda.run.jl @@ -14,15 +14,15 @@ println("Stan time: ", lda_time) setchunksize(60) setadbackend(:reverse_diff) -tbenchmark("HMC(2, 0.025, 10)", "ldamodel", "data=ldastandata[1]") +# tbenchmark("HMC(2, 0.025, 10)", "ldamodel", "data=ldastandata[1]") turnprogress(false) for (modelc, modeln) in zip([ - "ldamodel_vec", + # "ldamodel_vec", "ldamodel" ], [ - "LDA-vec", + # "LDA-vec", "LDA" ]) tbenchmark("HMC(2, 0.025, 10)", modelc, "data=ldastandata[1]") diff --git a/benchmarks/profile.jl b/benchmarks/profile.jl new file mode 100644 index 000000000..82c50dfa2 --- /dev/null +++ b/benchmarks/profile.jl @@ -0,0 +1,14 @@ +using HDF5, JLD, ProfileView + +using Turing +setadbackend(:reverse_diff) +turnprogress(false) + +include(Pkg.dir("Turing")*"/example-models/stan-models/lda-stan.data.jl") +include(Pkg.dir("Turing")*"/example-models/stan-models/lda.model.jl") + +sample(ldamodel(data=ldastandata[1]), HMC(2, 0.025, 10)) +Profile.clear() +@profile sample(ldamodel(data=ldastandata[1]), HMC(2000, 0.025, 10)) + +ProfileView.svgwrite("ldamodel.svg") \ No newline at end of file diff --git a/example-models/stan-models/lda.model.jl b/example-models/stan-models/lda.model.jl index c705ba4fe..c4b7c39a2 100644 --- a/example-models/stan-models/lda.model.jl +++ b/example-models/stan-models/lda.model.jl @@ -39,15 +39,17 @@ # phi[k] = Vector{Real}(V) # end +# phi_dot_theta = [Vector{Real}(V) for m=1:M] + @model ldamodel(K, V, M, N, w, doc, beta, alpha) = begin theta = Vector{Vector{Real}}(M) - for m = 1:M - theta[m] ~ Dirichlet(alpha) + @simd for m = 1:M + @inbounds theta[m] ~ Dirichlet(alpha) end phi = Vector{Vector{Real}}(K) - for k = 1:K - phi[k] ~ Dirichlet(beta) + @simd for k = 1:K + @inbounds phi[k] ~ Dirichlet(beta) end # z = tzeros(Int, N) @@ -55,13 +57,21 @@ # z[n] ~ Categorical(theta[doc[n]]) # end - phi_dot_theta = [log.([dot(map(p -> p[i], phi), theta[m]) for i = 1:V]) for m=1:M] + phi_dot_theta = [log.([dot(map(p -> p[v], phi), theta[m]) for v = 1:V]) for m = 1:M] + # @simd for m=1:M + # @inbounds phi_dot_theta[m] = log.([dot(map(p -> p[v], phi), theta[m]) for v = 1:V]) + # end + #for n = 1:N - # # phi_dot_theta = [dot(map(p -> p[i], phi), theta[doc[n]]) for i = 1:V] + # # phi_dot_theta = [dot(map(p -> p[v], phi), theta[doc[n]]) for v = 1:V] # # w[n] ~ Categorical(phi_dot_theta) # Turing.acclogp!(vi, phi_dot_theta[doc[n]][w[n]]) #end - lp = mapreduce(n->phi_dot_theta[doc[n]][w[n]], +, 1:N) + # lp = mapreduce(n->phi_dot_theta[doc[n]][w[n]], +, 1:N) + lp = phi_dot_theta[doc[1]][w[1]] + @simd for n = 2:N + @inbounds lp += phi_dot_theta[doc[n]][w[n]] + end Turing.acclogp!(vi, lp) end diff --git a/src/core/varinfo.jl b/src/core/varinfo.jl index 191d827f9..9c57889ca 100644 --- a/src/core/varinfo.jl +++ b/src/core/varinfo.jl @@ -47,7 +47,7 @@ type VarInfo vns :: Vector{VarName} ranges :: Vector{UnitRange{Int}} vals :: Vector{Real} - rs :: Dict{VarName,Any} + rs :: Dict{Union{VarName,Vector{VarName}},Any} dists :: Vector{Distributions.Distribution} gids :: Vector{Int} logp :: Real @@ -58,7 +58,7 @@ type VarInfo VarInfo() = begin vals = Vector{Real}() - rs = Dict{VarName,Any}() + rs = Dict{Union{VarName,Vector{VarName}},Any}() logp = zero(Real) pred = Dict{Symbol,Any}() flags = Dict{String,Vector{Bool}}() @@ -162,25 +162,25 @@ syms(vi::VarInfo) = map(vn -> vn.sym, vns(vi)) # get all symbols Base.getindex(vi::VarInfo, vn::VarName) = begin @assert haskey(vi, vn) "[Turing] attempted to replay unexisting variables in VarInfo" dist = getdist(vi, vn) - # if isa(dist, SimplexDistribution) # Reduce memory allocation for distributions with simplex constraints - # if vn in keys(vi.rs) - # r = vi.rs[vn] - # else - # r = reconstruct(dist, getval(vi, vn)) - # r_real = similar(r, Real) - # r_real[eachindex(r_real)] = r - # vi.rs[vn] = r_real - # r = r_real - # end - # reconstruct!(r, dist, getval(vi, vn)) - # istrans(vi, vn) ? - # invlink!(r, dist, r) : - # r - # else + if isa(dist, SimplexDistribution) # Reduce memory allocation for distributions with simplex constraints + if vn in keys(vi.rs) + r = vi.rs[vn] + reconstruct!(r, dist, getval(vi, vn)) + else + r = reconstruct(dist, getval(vi, vn)) + r_real = similar(r, Real) + r_real[:] = r + vi.rs[vn] = r_real + r = r_real + end + istrans(vi, vn) ? + invlink!(r, dist, r) : + r + else istrans(vi, vn) ? invlink(dist, reconstruct(dist, getval(vi, vn))) : reconstruct(dist, getval(vi, vn)) - # end + end end Base.setindex!(vi::VarInfo, val::Any, vn::VarName) = setval!(vi, val, vn) @@ -188,9 +188,25 @@ Base.setindex!(vi::VarInfo, val::Any, vn::VarName) = setval!(vi, val, vn) Base.getindex(vi::VarInfo, vns::Vector{VarName}) = begin @assert haskey(vi, vns[1]) "[Turing] attempted to replay unexisting variables in VarInfo" dist = getdist(vi, vns[1]) - istrans(vi, vns[1]) ? - invlink(dist, reconstruct(dist, getval(vi, vns), length(vns))) : - reconstruct(dist, getval(vi, vns), length(vns)) + # if isa(dist, SimplexDistribution) # Reduce memory allocation for distributions with simplex constraints + # if vns in keys(vi.rs) + # r = vi.rs[vns] + # reconstruct!(r, dist, getval(vi, vn)) + # else + # r = reconstruct(dist, getval(vi, vns), length(vns)) + # r_real = similar(r, Real) + # r_real[:] = r + # vi.rs[vns] = r_real + # r = r_real + # end + # istrans(vi, vns[1]) ? + # invlink!(r, dist, r) : + # r + # else + istrans(vi, vns[1]) ? + invlink(dist, reconstruct(dist, getval(vi, vns), length(vns))) : + reconstruct(dist, getval(vi, vns), length(vns)) + # end end # NOTE: vi[vview] will just return what insdie vi (no transformations applied) diff --git a/src/helper.jl b/src/helper.jl index 980da0e63..6d4cefcd6 100644 --- a/src/helper.jl +++ b/src/helper.jl @@ -38,7 +38,7 @@ @inline reconstruct(d::MatrixDistribution, val::Union{Vector,SubArray}, T::Type) = Array{T, 2}(reshape(val, size(d)...)) @inline reconstruct!(r, d::Distribution, val::Union{Vector,SubArray}) = reconstruct!(r, d, val, typeof(val[1])) -@inline reconstruct!(r, d::MultivariateDistribution, val::Union{Vector,SubArray}, T::Type) = (r[eachindex(r)] = val; r) +@inline reconstruct!(r, d::MultivariateDistribution, val::Union{Vector,SubArray}, T::Type) = (r[:] = val; r) @inline reconstruct(d::Distribution, val::Union{Vector,SubArray}, n::Int) = reconstruct(d, val, typeof(val[1]), n) @@ -55,4 +55,4 @@ end @inline reconstruct!(r, d::Distribution, val::Union{Vector,SubArray}, n::Int) = reconstruct!(r, d, val, typeof(val[1]), n) -@inline reconstruct!(r, d::MultivariateDistribution, val::Union{Vector,SubArray}, T::Type, n::Int) = (r[eachindex(r)] = val; r) \ No newline at end of file +@inline reconstruct!(r, d::MultivariateDistribution, val::Union{Vector,SubArray}, T::Type, n::Int) = (r[:] = val; r) \ No newline at end of file diff --git a/src/samplers/support/helper.jl b/src/samplers/support/helper.jl index 1cb3d5af8..b33ac58a8 100644 --- a/src/samplers/support/helper.jl +++ b/src/samplers/support/helper.jl @@ -7,7 +7,6 @@ value = Dict{Symbol, Any}() # value is named here because of Sample has a field called value for vn in keys(vi) value[sym(vn)] = realpart(vi[vn]) - # value[sym(vn)] = realpart(vi[vn]) end # NOTE: do we need to check if lp is 0? diff --git a/src/transform.jl b/src/transform.jl index 216e72f4d..27d02e142 100644 --- a/src/transform.jl +++ b/src/transform.jl @@ -133,20 +133,29 @@ link{T}(d::SimplexDistribution, x::Vector{T}) = link!(similar(x), d, x) link!{T}(y, d::SimplexDistribution, x::Vector{T}) = begin K = length(x) - key = (:cache_vec, T, K - 1) - if key in keys(TRANS_CACHE) - z = TRANS_CACHE[key] - else - z = Vector{T}(K - 1) - TRANS_CACHE[key] = z - end - - for k in 1:K-1 - z[k] = x[k] / (one(T) - sum(x[1:k-1])) - end - - @simd for k = 1:K-1 - @inbounds y[k] = logit(z[k]) - log(one(T) / (K-k)) + # key = (:cache_vec, T, K - 1) + # if key in keys(TRANS_CACHE) + # z = TRANS_CACHE[key] + # else + # z = Vector{T}(K - 1) + # TRANS_CACHE[key] = z + # end + + # for k in 1:K-1 + # z[k] = x[k] / (one(T) - sum(x[1:k-1])) + # end + + # @simd for k = 1:K-1 + # @inbounds y[k] = logit(z[k]) - log(one(T) / (K-k)) + # end + + sum_tmp = zero(T) + z = x[1] + y[1] = logit(z) - log(one(T) / (K-1)) + @simd for k in 2:K-1 + @inbounds sum_tmp += x[k-1] + @inbounds z = x[k] / (one(T) - sum_tmp) + @inbounds y[k] = logit(z) - log(one(T) / (K-k)) end y @@ -179,24 +188,28 @@ end invlink{T}(d::SimplexDistribution, y::Vector{T}) = invlink!(similar(y), d, y) invlink!{T}(x, d::SimplexDistribution, y::Vector{T}) = begin K = length(y) - - key = (:cache_vec, T, K - 1) - if key in keys(TRANS_CACHE) - z = TRANS_CACHE[key] - else - z = Vector{T}(K - 1) - TRANS_CACHE[key] = z - end - @simd for k = 1:K-1 - @inbounds z[k] = invlogit(y[k] + log(one(T) / (K - k))) - end + # @simd for k = 1:K-1 + # @inbounds z[k] = invlogit(y[k] + log(one(T) / (K - k))) + # end - for k in 1:K-1 - x[k] = (one(T) - sum(x[1:k-1])) * z[k] + # for k in 1:K-1 + # x[k] = (one(T) - sum(x[1:k-1])) * z[k] + # end + # x[K] = one(T) - sum(x[1:K-1]) + + z = invlogit(y[1] + log(one(T) / (K - 1))) + x[1] = z + sum_tmp = zero(T) + @simd for k = 2:K-1 + @inbounds z = invlogit(y[k] + log(one(T) / (K - k))) + @inbounds sum_tmp += x[k-1] + @inbounds x[k] = (one(T) - sum_tmp) * z end - x[K] = one(T) - sum(x[1:K-1]) - @assert sum(x) == 1 "[Turing] invlink of simplex distribution invalid" + sum_tmp += x[K-1] + x[K] = one(T) - sum_tmp + + # @assert sum(x) == 1 "[Turing.invlink!] simplex distribution invalid; x = $x" x end @@ -222,6 +235,15 @@ invlink!{T<:Real}(X, d::SimplexDistribution, Y::Matrix{T}) = begin end X[K,:] = one(T) - sum(X[1:K-1,:], 1) + # X[1,:] = Z[1,:]' + # sum_tmp = 0 + # for k = 2:K-1 + # sum_tmp += X[k-1,:] + # X[k,:] = (one(T) - sum_tmp') .* Z[k,:] + # end + # sum_tmp += X[K-1,:] + # X[K,:] = one(T) - sum_tmp' + X end @@ -229,11 +251,20 @@ logpdf_with_trans{T}(d::SimplexDistribution, x::Vector{T}, transform::Bool) = be lp = logpdf(d, x) if transform K = length(x) - z = Vector{T}(K-1) - @simd for k in 1:K-1 - @inbounds z[k] = x[k] / (one(T) - sum(x[1:k-1])) + + # @simd for k in 1:K-1 + # @inbounds z[k] = x[k] / (one(T) - sum(x[1:k-1])) + # end + # lp += sum([log(z[k]) + log(one(T) - z[k]) + log(one(T) - sum(x[1:k-1])) for k in 1:K-1]) + + sum_tmp = zero(T) + z = x[1] + lp += log(z) + log(one(T) - z) + @simd for k in 2:K-1 + @inbounds sum_tmp += x[k-1] + @inbounds z = x[k] / (one(T) - sum_tmp) + @inbounds lp += log(z) + log(one(T) - z) + log(one(T) - sum_tmp) end - lp += sum([log(z[k]) + log(one(T) - z[k]) + log(one(T) - sum(x[1:k-1])) for k in 1:K-1]) end lp end