Update for DynamicPPL 0.33 and 0.34 (#2459)
* Update for DynamicPPL 0.33

* Don't remove import/export

* 0.34 too

* Update test compat too

* Remove upstream tests for `predict`
penelopeysm authored Jan 23, 2025
1 parent 24d5556 commit 8bf98e1
Showing 4 changed files with 4 additions and 246 deletions.
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.36.0"
version = "0.36.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -63,7 +63,7 @@ Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.32"
DynamicPPL = "0.33, 0.34"
EllipticalSliceSampling = "0.5, 1, 2"
ForwardDiff = "0.10.3"
Libtask = "0.8.8"
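For anyone mirroring this compat widening in their own project, the same entry can also be written programmatically. A minimal sketch using the stock Pkg API (the workflow is illustrative, not part of this commit):

```julia
# Illustrative sketch only: apply the same compat widening to the active
# environment via the standard Pkg API (available since Julia 1.8).
using Pkg

Pkg.compat("DynamicPPL", "0.33, 0.34")  # mirrors the [compat] change above
Pkg.resolve()                           # re-resolve the manifest against the new bound
```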
108 changes: 1 addition & 107 deletions src/mcmc/Inference.jl
@@ -396,7 +396,7 @@ function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
# this means that the code below will work with both linked and invlinked `vi`.
# Ref: https://github.com/TuringLang/Turing.jl/issues/2195
# NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
-vals = DynamicPPL.values_as_in_model(model, deepcopy(vi))
+vals = DynamicPPL.values_as_in_model(model, true, deepcopy(vi))

# Obtain an iterator over the flattened parameter names and values.
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
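The substantive change in this hunk is the extra positional `Bool` that `DynamicPPL.values_as_in_model` requires as of DynamicPPL 0.33. A minimal, hypothetical sketch of calling the new form directly; the reading that the flag controls whether `:=`-bound values are included in the result is an assumption here, not something this diff states:

```julia
# Hedged sketch of the DynamicPPL 0.33+ call used in the hunk above.
# ASSUMPTION: the second positional Bool toggles inclusion of `:=`-bound
# values in the returned dictionary; the diff itself only shows `true`.
using DynamicPPL, Distributions

@model function demo()
    x ~ Normal(0, 1)
    return nothing
end

model = demo()
vi = DynamicPPL.VarInfo(model)
vals = DynamicPPL.values_as_in_model(model, true, deepcopy(vi))
# `vals` maps each variable name (here just `x`) to its constrained value.
```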
@@ -612,112 +612,6 @@
DynamicPPL.getspace(spl::Sampler) = getspace(spl.alg)
DynamicPPL.inspace(vn::VarName, spl::Sampler) = inspace(vn, getspace(spl.alg))

"""
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
Execute `model` conditioned on each sample in `chain`, and return the resulting `Chains`.

If `include_all` is `false`, the returned `Chains` will contain only those variables
that were newly sampled, i.e. those not present in `chain`.

# Details
Internally calls `Turing.Inference.transitions_from_chain` to obtain the samples
and then converts these into a `Chains` object using `AbstractMCMC.bundle_samples`.

# Example
```jldoctest
julia> using Turing; Turing.setprogress!(false);
[ Info: [Turing]: progress logging is disabled globally

julia> @model function linear_reg(x, y, σ = 0.1)
           β ~ Normal(0, 1)
           for i ∈ eachindex(y)
               y[i] ~ Normal(β * x[i], σ)
           end
       end;

julia> σ = 0.1; f(x) = 2 * x + 0.1 * randn();

julia> Δ = 0.1; xs_train = 0:Δ:10; ys_train = f.(xs_train);

julia> xs_test = [10 + Δ, 10 + 2 * Δ]; ys_test = f.(xs_test);

julia> m_train = linear_reg(xs_train, ys_train, σ);

julia> chain_lin_reg = sample(m_train, NUTS(100, 0.65), 200);
┌ Info: Found initial step size
└   ϵ = 0.003125

julia> m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);

julia> predictions = predict(m_test, chain_lin_reg)
Object of type Chains, with data of type 100×2×1 Array{Float64,3}

Iterations        = 1:100
Thinning interval = 1
Chains            = 1
Samples per chain = 100
parameters        = y[1], y[2]

2-element Array{ChainDataFrame,1}

Summary Statistics
  parameters     mean     std  naive_se     mcse       ess   r_hat
  ──────────  ───────  ──────  ────────  ───────  ────────  ──────
        y[1]  20.1974  0.1007    0.0101  missing  101.0711  0.9922
        y[2]  20.3867  0.1062    0.0106  missing  101.4889  0.9903

Quantiles
  parameters     2.5%    25.0%    50.0%    75.0%    97.5%
  ──────────  ───────  ───────  ───────  ───────  ───────
        y[1]  20.0342  20.1188  20.2135  20.2588  20.4188
        y[2]  20.1870  20.3178  20.3839  20.4466  20.5895

julia> ys_pred = vec(mean(Array(group(predictions, :y)); dims = 1));

julia> sum(abs2, ys_test - ys_pred) ≤ 0.1
true
```
"""
function predict(model::Model, chain::MCMCChains.Chains; kwargs...)
return predict(Random.default_rng(), model, chain; kwargs...)
end
function predict(
rng::AbstractRNG, model::Model, chain::MCMCChains.Chains; include_all=false
)
# Don't need all the diagnostics
chain_parameters = MCMCChains.get_sections(chain, :parameters)

spl = DynamicPPL.SampleFromPrior()

# Sample transitions using `spl` conditioned on values in `chain`
transitions = transitions_from_chain(rng, model, chain_parameters; sampler=spl)

# Let the Turing internals handle everything else for you
chain_result = reduce(
MCMCChains.chainscat,
[
AbstractMCMC.bundle_samples(
transitions[:, chain_idx], model, spl, nothing, MCMCChains.Chains
) for chain_idx in 1:size(transitions, 2)
],
)

parameter_names = if include_all
names(chain_result, :parameters)
else
filter(
k -> ∉(k, names(chain_parameters, :parameters)),
names(chain_result, :parameters),
)
end

return chain_result[parameter_names]
end
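One detail worth calling out in the branch above: the `∉` in the `filter` (a character that is easy to lose in copy-paste) keeps only parameter names absent from the input chain, i.e. the newly predicted variables. A standalone sketch of that step with made-up names:

```julia
# Standalone illustration of the `include_all=false` filtering above.
# The symbols are made up for the example; in `predict` they come from
# `names(chain_result, :parameters)` and `names(chain_parameters, :parameters)`.
all_names = [:β, Symbol("y[1]"), Symbol("y[2]")]  # names in the combined result
in_chain  = [:β]                                  # names already present in `chain`

predicted = filter(k -> ∉(k, in_chain), all_names)
@assert predicted == [Symbol("y[1]"), Symbol("y[2]")]
```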

"""
transitions_from_chain(
2 changes: 1 addition & 1 deletion test/Project.toml
@@ -51,7 +51,7 @@ Combinatorics = "1"
Distributions = "0.25"
DistributionsAD = "0.6.3"
DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.32.2"
DynamicPPL = "0.33, 0.34"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12 - 0.10.32, 0.10"
HypothesisTests = "0.11"
136 changes: 0 additions & 136 deletions test/mcmc/utilities.jl
@@ -1,145 +1,9 @@
module MCMCUtilitiesTests

using ..Models: gdemo_default
using Distributions: Normal, sample, truncated
using LinearAlgebra: I, vec
using Random: Random
using Random: MersenneTwister
using Test: @test, @testset
using Turing

@testset "predict" begin
Random.seed!(100)

@model function linear_reg(x, y, σ=0.1)
β ~ Normal(0, 1)

for i in eachindex(y)
y[i] ~ Normal(β * x[i], σ)
end
end

@model function linear_reg_vec(x, y, σ=0.1)
β ~ Normal(0, 1)
return y ~ MvNormal(β .* x, σ^2 * I)
end

f(x) = 2 * x + 0.1 * randn()

Δ = 0.1
xs_train = 0:Δ:10
ys_train = f.(xs_train)
xs_test = [10 + Δ, 10 + 2 * Δ]
ys_test = f.(xs_test)

# Infer
m_lin_reg = linear_reg(xs_train, ys_train)
chain_lin_reg = sample(m_lin_reg, NUTS(100, 0.65), 200)

# Predict on two last indices
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)))
predictions = Turing.Inference.predict(m_lin_reg_test, chain_lin_reg)

ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))

@test sum(abs2, ys_test - ys_pred) ≤ 0.1

# Ensure that `rng` is respected
predictions1 = let rng = MersenneTwister(42)
predict(rng, m_lin_reg_test, chain_lin_reg[1:2])
end
predictions2 = let rng = MersenneTwister(42)
predict(rng, m_lin_reg_test, chain_lin_reg[1:2])
end
@test all(Array(predictions1) .== Array(predictions2))

# Predict on two last indices for vectorized
m_lin_reg_test = linear_reg_vec(xs_test, missing)
predictions_vec = Turing.Inference.predict(m_lin_reg_test, chain_lin_reg)
ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))

@test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1

# Multiple chains
chain_lin_reg = sample(m_lin_reg, NUTS(100, 0.65), MCMCThreads(), 200, 2)
m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)))
predictions = Turing.Inference.predict(m_lin_reg_test, chain_lin_reg)

@test size(chain_lin_reg, 3) == size(predictions, 3)

for chain_idx in MCMCChains.chains(chain_lin_reg)
ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
@test sum(abs2, ys_test - ys_pred) ≤ 0.1
end

# Predict on two last indices for vectorized
m_lin_reg_test = linear_reg_vec(xs_test, missing)
predictions_vec = Turing.Inference.predict(m_lin_reg_test, chain_lin_reg)

for chain_idx in MCMCChains.chains(chain_lin_reg)
ys_pred_vec = vec(mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1))
@test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1
end

# https://github.com/TuringLang/Turing.jl/issues/1352
@model function simple_linear1(x, y)
intercept ~ Normal(0, 1)
coef ~ MvNormal(zeros(2), I)
coef = reshape(coef, 1, size(x, 1))

mu = vec(intercept .+ coef * x)
error ~ truncated(Normal(0, 1), 0, Inf)
return y ~ MvNormal(mu, error^2 * I)
end

@model function simple_linear2(x, y)
intercept ~ Normal(0, 1)
coef ~ filldist(Normal(0, 1), 2)
coef = reshape(coef, 1, size(x, 1))

mu = vec(intercept .+ coef * x)
error ~ truncated(Normal(0, 1), 0, Inf)
return y ~ MvNormal(mu, error^2 * I)
end

@model function simple_linear3(x, y)
intercept ~ Normal(0, 1)
coef = Vector(undef, 2)
for i in axes(coef, 1)
coef[i] ~ Normal(0, 1)
end
coef = reshape(coef, 1, size(x, 1))

mu = vec(intercept .+ coef * x)
error ~ truncated(Normal(0, 1), 0, Inf)
return y ~ MvNormal(mu, error^2 * I)
end

@model function simple_linear4(x, y)
intercept ~ Normal(0, 1)
coef1 ~ Normal(0, 1)
coef2 ~ Normal(0, 1)
coef = [coef1, coef2]
coef = reshape(coef, 1, size(x, 1))

mu = vec(intercept .+ coef * x)
error ~ truncated(Normal(0, 1), 0, Inf)
return y ~ MvNormal(mu, error^2 * I)
end

# Some data
x = randn(2, 100)
y = [1 + 2 * a + 3 * b for (a, b) in eachcol(x)]

for model in [simple_linear1, simple_linear2, simple_linear3, simple_linear4]
m = model(x, y)
chain = sample(m, NUTS(), 100)
chain_predict = predict(model(x, missing), chain)
mean_prediction = [mean(chain_predict["y[$i]"].data) for i in 1:length(y)]
@test mean(abs2, mean_prediction - y) ≤ 1e-3
end
end

@testset "Timer" begin
chain = sample(gdemo_default, MH(), 1000)


2 comments on commit 8bf98e1

@penelopeysm (Member Author)

@JuliaRegistrator register

Release notes:

  • This is a release for compatibility with DynamicPPL 0.33 and 0.34.

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/123557

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.36.1 -m "<description of version>" 8bf98e1f29deb9e115f53ba8dccef590e3e24ce0
git push origin v0.36.1
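Once the registration PR is merged and the tag exists, downstream users pick the release up through the registry as usual; a hypothetical sketch with the stock Pkg API:

```julia
# Hypothetical: installing the newly registered release from the General registry.
using Pkg
Pkg.add(name="Turing", version="0.36.1")
```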
