From e8a9a10ece5599c520b9cb74f13a7397cc5f7297 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 10 Sep 2021 00:55:00 +0000 Subject: [PATCH] Sibling PR of introduction of Setfield.jl in AbstractPPL.jl (#295) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is a sibling PR to https://github.com/TuringLang/AbstractPPL.jl/pull/26 fixing some issues + allowing us to do neat stuff. We also finally drop the passing of the `inds` around in the tilde-pipeline, which is not very useful now that we have the more general lenses in `VarName`. TODOs: - [X] ~Deprecate `*tilde_*` with `inds` argument appropriately.~ EDIT: On second thought, let's not. See comment for reason. - [x] It seems like the prob macro is now somehow broken :confused: - [X] ~(Maybe) Rewrite `@model` to not escape the entire expression.~ Deferred to #311 - [X] Figure out performance degradation. - Answer: `hash` for `Tuple` vs. `hash` for immutable struct :confused: ## Sample fields of structs ```julia julia> @model function demo(x, y) s ~ InverseGamma(2, 3) m ~ Normal(0, √s) for i in 2:length(x.a) - 1 x.a[i] ~ Normal(m, √s) end # Dynamic indexing x.a[begin] ~ Normal(-100.0, 1.0) x.a[end] ~ Normal(100.0, 1.0) # Immutable set y.a ~ Normal() # Dotted z = Vector{Float64}(undef, 3) z[1:2] .~ Normal() z[end:end] .~ Normal() return (; s, m, x, y, z) end julia> struct MyCoolStruct{T} a::T end julia> m = demo(MyCoolStruct([missing, missing]), MyCoolStruct(missing)); julia> m() (s = 3.483799020996254, m = -0.35566330762328, x = MyCoolStruct{Vector{Union{Missing, Float64}}}(Union{Missing, Float64}[-100.75592540694562, 98.61295291877542]), y = MyCoolStruct{Float64}(-2.1107980419121546), z = [-2.2868359094832584, -1.1378866583607443, 1.172250491861777]) ``` ## Sample fields of `DataFrame` ```julia julia> using DataFrames julia> using Setfield: ConstructionBase julia> function ConstructionBase.setproperties(df::DataFrame, patch::NamedTuple) # Only need `copy` because we'll replace entire columns columns = copy(DataFrames._columns(df)) colindex = DataFrames.index(df) for k in keys(patch) columns[colindex[k]] = patch[k] end return DataFrame(columns, colindex) end julia> @model function demo(x) s ~ InverseGamma(2, 3) m ~ Normal(0, √s) for i in 1:length(x.a) - 1 x.a[i] ~ Normal(m, √s) end x.a[end] ~ Normal(100.0, 1.0) return x end demo (generic function with 1 method) julia> m = demo(df, (a = missing, )); julia> m() 3×1 DataFrame Row │ a │ Float64? ─────┼────────── 1 │ 1.0 2 │ 2.0 3 │ 99.8838 julia> df 3×1 DataFrame Row │ a │ Float64? ─────┼─────────── 1 │ 1.0 2 │ 2.0 3 │ missing ``` # Benchmarks Unfortunately there does seem to be performance regression when using a very large number of varnames in a loop in the model (for broadcasting which uses the same number of varnames but does so "internally", there is no difference): ![image](https://user-images.githubusercontent.com/11074788/127791298-da3d0fb2-baab-428b-a555-3f4d2c63bd3b.png) The weird thing is that we're using less memory, indicating that type-inference might better?
0.31.1 ## 0.31.1 ## ### Setup ### ```julia using BenchmarkTools, DynamicPPL, Distributions, Serialization ``` ```julia import DynamicPPLBenchmarks: time_model_def, make_suite, typed_code, weave_child ``` ### Models ### #### `demo1` #### ```julia @model function demo1(x) m ~ Normal() x ~ Normal(m, 1) return (m = m, x = x) end model_def = demo1; data = 1.0; ``` ```julia @time model_def(data)(); ``` ``` 0.059594 seconds (115.76 k allocations: 6.982 MiB, 99.91% compilation tim e) ``` ```julia m = time_model_def(model_def, data); ``` ``` 0.000004 seconds (2 allocations: 48 bytes) ``` ```julia suite = make_suite(m); results = run(suite); ``` ```julia results["evaluation_untyped"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 619.000 ns … 19.678 μs ┊ GC (min … max): 0.00% … 0.0 0% Time (median): 654.000 ns ┊ GC (median): 0.00% Time (mean ± σ): 677.650 ns ± 333.145 ns ┊ GC (mean ± σ): 0.00% ± 0.0 0% ▅▆▇█▅▄▃ ▃▅███████▇▆▅▄▃▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂ ▃ 619 ns Histogram: frequency by time 945 ns < Memory estimate: 480 bytes, allocs estimate: 13. ``` ```julia results["evaluation_typed"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 249.000 ns … 11.048 μs ┊ GC (min … max): 0.00% … 0.0 0% Time (median): 264.000 ns ┊ GC (median): 0.00% Time (mean ± σ): 267.650 ns ± 137.452 ns ┊ GC (mean ± σ): 0.00% ± 0.0 0% ▂▄ ▆▇ █▇ ▇▄ ▂▂ ▂▂▂▁▂▂▁▃▃▁▅▅▁███▁██▁██▁██▁██▁▇▇▅▁▄▄▁▃▃▁▃▃▁▃▂▁▂▂▂▁▂▂▁▂▂▁▂▂▁▂▂▂ ▃ 249 ns Histogram: frequency by time 291 ns < Memory estimate: 0 bytes, allocs estimate: 0. ``` ```julia if WEAVE_ARGS[:include_typed_code] typed = typed_code(m) end ``` #### `demo2` #### ```julia @model function demo2(y) # Our prior belief about the probability of heads in a coin. p ~ Beta(1, 1) # The number of observations. N = length(y) for n in 1:N # Heads or tails of a coin are drawn from a Bernoulli distribution. y[n] ~ Bernoulli(p) end end model_def = demo2; data = rand(0:1, 10); ``` ```julia @time model_def(data)(); ``` ``` 0.067078 seconds (143.91 k allocations: 8.544 MiB, 99.91% compilation tim e) ``` ```julia m = time_model_def(model_def, data); ``` ``` 0.000002 seconds (1 allocation: 32 bytes) ``` ```julia suite = make_suite(m); results = run(suite); ``` ```julia results["evaluation_untyped"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 1.637 μs … 48.917 μs ┊ GC (min … max): 0.00% … 0.00% Time (median): 1.694 μs ┊ GC (median): 0.00% Time (mean ± σ): 1.746 μs ± 550.372 ns ┊ GC (mean ± σ): 0.00% ± 0.00% ▂█▇▃ ▁▄████▇▄▄▅▅▅▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂ 1.64 μs Histogram: frequency by time 2.23 μs < Memory estimate: 1.66 KiB, allocs estimate: 47. ``` ```julia results["evaluation_typed"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 506.000 ns … 10.733 μs ┊ GC (min … max): 0.00% … 0.0 0% Time (median): 546.000 ns ┊ GC (median): 0.00% Time (mean ± σ): 553.478 ns ± 118.542 ns ┊ GC (mean ± σ): 0.00% ± 0.0 0% ▃█ ▆▅ ▂▃██▇▇██▅▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁▂▂▁▁▁▁▁▂▂▁▂▂▁▁▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂ ▃ 506 ns Histogram: frequency by time 933 ns < Memory estimate: 0 bytes, allocs estimate: 0. ``` ```julia if WEAVE_ARGS[:include_typed_code] typed = typed_code(m) end ``` #### `demo3` #### ```julia @model function demo3(x) D, N = size(x) # Draw the parameters for cluster 1. μ1 ~ Normal() # Draw the parameters for cluster 2. μ2 ~ Normal() μ = [μ1, μ2] # Comment out this line if you instead want to draw the weights. w = [0.5, 0.5] # Draw assignments for each datum and generate it from a multivariate normal. k = Vector{Int}(undef, N) for i in 1:N k[i] ~ Categorical(w) x[:,i] ~ MvNormal([μ[k[i]], μ[k[i]]], 1.) end return k end model_def = demo3 # Construct 30 data points for each cluster. N = 30 # Parameters for each cluster, we assume that each cluster is Gaussian distributed in the example. μs = [-3.5, 0.0] # Construct the data points. data = mapreduce(c -> rand(MvNormal([μs[c], μs[c]], 1.), N), hcat, 1:2); ``` ```julia @time model_def(data)(); ``` ``` 0.097628 seconds (224.06 k allocations: 13.410 MiB, 99.79% compilation ti me) ``` ```julia m = time_model_def(model_def, data); ``` ``` 0.000002 seconds (1 allocation: 32 bytes) ``` ```julia suite = make_suite(m); results = run(suite); ``` ```julia results["evaluation_untyped"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 48.200 μs … 16.129 ms ┊ GC (min … max): 0.00% … 99.5 3% Time (median): 51.017 μs ┊ GC (median): 0.00% Time (mean ± σ): 60.128 μs ± 265.008 μs ┊ GC (mean ± σ): 7.61% ± 1.7 2% ▂▆█ ████▂▂▂▁▂▃▄▅▇▅▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂ 48.2 μs Histogram: frequency by time 101 μs < Memory estimate: 48.20 KiB, allocs estimate: 1042. ``` ```julia results["evaluation_typed"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 22.210 μs … 13.796 ms ┊ GC (min … max): 0.00% … 99.7 0% Time (median): 25.882 μs ┊ GC (median): 0.00% Time (mean ± σ): 27.536 μs ± 137.815 μs ┊ GC (mean ± σ): 5.00% ± 1.0 0% █▇▆▄▂ ▁▇▆▇▆▅▄▂ ▂▂▂▁ ▂ ████████████████████████▆▆▃▅▅▅▅▅▅▁▆▇▆▅▅▅▆▆▅▆▇▇▇▇▆▆▅▆▆▅▅▇█▇█▇ █ 22.2 μs Histogram: log(frequency) by time 51 μs < Memory estimate: 17.62 KiB, allocs estimate: 183. ``` ```julia if WEAVE_ARGS[:include_typed_code] typed = typed_code(m) end ``` #### `demo4`: loads of indexing #### ```julia @model function demo4(n, ::Type{TV}=Vector{Float64}) where {TV} m ~ Normal() x = TV(undef, n) for i in eachindex(x) x[i] ~ Normal(m, 1.0) end end model_def = demo4 data = (100_000, ); ``` ```julia @time model_def(data)(); ``` ``` 0.435154 seconds (3.12 M allocations: 192.275 MiB, 8.73% gc time, 1.84% c ompilation time) ``` ```julia m = time_model_def(model_def, data); ``` ``` 0.000002 seconds (2 allocations: 64 bytes) ``` ```julia suite = make_suite(m); results = run(suite); ``` ```julia results["evaluation_untyped"] ``` ``` BenchmarkTools.Trial: 62 samples with 1 evaluation. Range (min … max): 61.601 ms … 101.432 ms ┊ GC (min … max): 0.00% … 25.0 2% Time (median): 76.902 ms ┊ GC (median): 0.00% Time (mean ± σ): 77.276 ms ± 11.445 ms ┊ GC (mean ± σ): 6.48% ± 10.7 7% ▂ ▂ █ ▆ ▆▆██▄▄▁▄█▄▁▁▁▁▁▁▆▁█▁█▁▄████▁▄▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▄▁▁▆▁▁▄▆▄▁▄▁▆▄▁▄ ▁ 61.6 ms Histogram: frequency by time 101 ms < Memory estimate: 44.37 MiB, allocs estimate: 1357727. ``` ```julia results["evaluation_typed"] ``` ``` BenchmarkTools.Trial: 189 samples with 1 evaluation. Range (min … max): 23.796 ms … 40.845 ms ┊ GC (min … max): 0.00% … 0.00% Time (median): 24.838 ms ┊ GC (median): 0.00% Time (mean ± σ): 25.162 ms ± 1.434 ms ┊ GC (mean ± σ): 0.00% ± 0.00% ▁ ▂▂▃█▂ ▃▂ ▁ ▃▅█▃▇▅█▇██████▇█████▇▄▅▃▅▆▇█▃▆▃▃▄▅▁▄▁▃▁▆▅▄▁▁▁▃▁▃▄▁▁▃▃▁▁▁▁▃▄ ▃ 23.8 ms Histogram: frequency by time 27.8 ms < Memory estimate: 781.70 KiB, allocs estimate: 6. ``` ```julia if WEAVE_ARGS[:include_typed_code] typed = typed_code(m) end ``` ```julia @model function demo4_dotted(n, ::Type{TV}=Vector{Float64}) where {TV} m ~ Normal() x = TV(undef, n) x .~ Normal(m, 1.0) end model_def = demo4_dotted data = (100_000, ); ``` ```julia @time model_def(data)(); ``` ``` 1.476057 seconds (5.08 M allocations: 375.205 MiB, 5.02% gc time, 0.62% c ompilation time) ``` ```julia m = time_model_def(model_def, data); ``` ``` 0.000002 seconds (2 allocations: 64 bytes) ``` ```julia suite = make_suite(m); results = run(suite); ``` ```julia results["evaluation_untyped"] ``` ``` BenchmarkTools.Trial: 39 samples with 1 evaluation. Range (min … max): 112.078 ms … 350.311 ms ┊ GC (min … max): 11.20% … 4. 74% Time (median): 115.686 ms ┊ GC (median): 12.93% Time (mean ± σ): 122.722 ms ± 37.638 ms ┊ GC (mean ± σ): 12.96% ± 2. 85% █▅ ▁ ██▅█▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅ ▁ 112 ms Histogram: log(frequency) by time 350 ms < Memory estimate: 347.71 MiB, allocs estimate: 964550. ``` ```julia results["evaluation_typed"] ``` ``` BenchmarkTools.Trial: 59 samples with 1 evaluation. Range (min … max): 69.420 ms … 407.970 ms ┊ GC (min … max): 12.25% … 6.3 0% Time (median): 71.514 ms ┊ GC (median): 12.41% Time (mean ± σ): 78.481 ms ± 43.867 ms ┊ GC (mean ± σ): 12.80% ± 2.8 4% ▅▂█ █▅ ▇██████▅▅▄▁▅▁▁▁▁▁▁▁▁▁▁▁▁▁▅▁▁▁▄▄▁▁▁▄▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄ ▁ 69.4 ms Histogram: frequency by time 94.2 ms < Memory estimate: 337.55 MiB, allocs estimate: 399306. ``` ```julia if WEAVE_ARGS[:include_typed_code] typed = typed_code(m) end ```
This PR ## This PR ## ### Setup ### ```julia using BenchmarkTools, DynamicPPL, Distributions, Serialization ``` ```julia import DynamicPPLBenchmarks: time_model_def, make_suite, typed_code, weave_child ``` ### Models ### #### `demo1` #### ```julia @model function demo1(x) m ~ Normal() x ~ Normal(m, 1) return (m = m, x = x) end model_def = demo1; data = 1.0; ``` ```julia @time model_def(data)(); ``` ``` 1.063017 seconds (2.88 M allocations: 180.745 MiB, 4.19% gc time, 99.90% compilation time) ``` ```julia m = time_model_def(model_def, data); ``` ``` 0.000004 seconds (2 allocations: 48 bytes) ``` ```julia suite = make_suite(m); results = run(suite); ``` ```julia results["evaluation_untyped"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 615.000 ns … 13.280 ms ┊ GC (min … max): 0.00% … 0.0 0% Time (median): 650.000 ns ┊ GC (median): 0.00% Time (mean ± σ): 2.037 μs ± 132.793 μs ┊ GC (mean ± σ): 0.00% ± 0.0 0% ▅█▇▅▄▄▃▂▁▁ ▁ ███████████▇▇▇▆▆▆▆▃▄▆▆▅▆▇▆▆▇▆▆▇▆▆▆▆▅▆▆▅▅▅▅▄▄▅▅▃▅▅▃▅▄▅▅▅▅▅▄▅▆▅ █ 615 ns Histogram: log(frequency) by time 1.7 μs < Memory estimate: 480 bytes, allocs estimate: 13. ``` ```julia results["evaluation_typed"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 272.000 ns … 9.093 μs ┊ GC (min … max): 0.00% … 0.0 0% Time (median): 284.000 ns ┊ GC (median): 0.00% Time (mean ± σ): 310.535 ns ± 156.251 ns ┊ GC (mean ± σ): 0.00% ± 0.0 0% ▅█▆▄▃▃▂▁▁ ▁ ███████████▇▇▆▄▄▃▃▄▅▆▅▆▅▆▆▆▆▆▆▆▇▇▆▆▆▆▆▇▆▆▆▆▇▆▇▇▇▇▆▆▆▆▆▅▆▆▅▄▅▅ █ 272 ns Histogram: log(frequency) by time 643 ns < Memory estimate: 0 bytes, allocs estimate: 0. ``` ```julia if WEAVE_ARGS[:include_typed_code] typed = typed_code(m) end ``` #### `demo2` #### ```julia @model function demo2(y) # Our prior belief about the probability of heads in a coin. p ~ Beta(1, 1) # The number of observations. N = length(y) for n in 1:N # Heads or tails of a coin are drawn from a Bernoulli distribution. y[n] ~ Bernoulli(p) end end model_def = demo2; data = rand(0:1, 10); ``` ```julia @time model_def(data)(); ``` ``` 0.401535 seconds (863.20 k allocations: 51.771 MiB, 2.88% gc time, 99.90% compilation time) ``` ```julia m = time_model_def(model_def, data); ``` ``` 0.000003 seconds (1 allocation: 32 bytes) ``` ```julia suite = make_suite(m); results = run(suite); ``` ```julia results["evaluation_untyped"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 1.672 μs … 9.849 ms ┊ GC (min … max): 0.00% … 0.00% Time (median): 1.754 μs ┊ GC (median): 0.00% Time (mean ± σ): 2.835 μs ± 98.472 μs ┊ GC (mean ± σ): 0.00% ± 0.00% ▅██▇▆▆▅▄▄▃▂▂▁▁ ▁▁▁ ▁ ▂ ██████████████████▇▇▇▇▇▆▇▆▅▆▄▄▁▄▄▄▆▇██████████▆▆▇▇▇▇▆▇▆▆▆▆ █ 1.67 μs Histogram: log(frequency) by time 3.19 μs < Memory estimate: 1.50 KiB, allocs estimate: 37. ``` ```julia results["evaluation_typed"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 544.000 ns … 19.704 μs ┊ GC (min … max): 0.00% … 0.0 0% Time (median): 567.000 ns ┊ GC (median): 0.00% Time (mean ± σ): 578.671 ns ± 222.201 ns ┊ GC (mean ± σ): 0.00% ± 0.0 0% ▄█▇▅▂▃ ▃███████▅▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▂▁▁▁▁▁▂▁▂▁▂▁▁▁▁▁▂▂▂▂▂▂▂▁▂▂▂▂▂ ▃ 544 ns Histogram: frequency by time 888 ns < Memory estimate: 0 bytes, allocs estimate: 0. ``` ```julia if WEAVE_ARGS[:include_typed_code] typed = typed_code(m) end ``` #### `demo3` #### ```julia @model function demo3(x) D, N = size(x) # Draw the parameters for cluster 1. μ1 ~ Normal() # Draw the parameters for cluster 2. μ2 ~ Normal() μ = [μ1, μ2] # Comment out this line if you instead want to draw the weights. w = [0.5, 0.5] # Draw assignments for each datum and generate it from a multivariate normal. k = Vector{Int}(undef, N) for i in 1:N k[i] ~ Categorical(w) x[:,i] ~ MvNormal([μ[k[i]], μ[k[i]]], 1.) end return k end model_def = demo3 # Construct 30 data points for each cluster. N = 30 # Parameters for each cluster, we assume that each cluster is Gaussian distributed in the example. μs = [-3.5, 0.0] # Construct the data points. data = mapreduce(c -> rand(MvNormal([μs[c], μs[c]], 1.), N), hcat, 1:2); ``` ```julia @time model_def(data)(); ``` ``` 1.031824 seconds (2.34 M allocations: 139.934 MiB, 3.16% gc time, 99.96% compilation time) ``` ```julia m = time_model_def(model_def, data); ``` ``` 0.000004 seconds (1 allocation: 32 bytes) ``` ```julia suite = make_suite(m); results = run(suite); ``` ```julia results["evaluation_untyped"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 52.509 μs … 9.913 ms ┊ GC (min … max): 0.00% … 0.00 % Time (median): 53.706 μs ┊ GC (median): 0.00% Time (mean ± σ): 61.948 μs ± 210.490 μs ┊ GC (mean ± σ): 9.84% ± 3.27 % ▂▆██▇▆▅▄▄▃▃▂▂▁▁▁▁ ▁ ▂ █████████████████████▇█▇▇▇█████████▇▇▇▇▅▆▆▅▆▅▅▆▇▅▅▄▅▅▄▄▄▄▂▄▃ █ 52.5 μs Histogram: log(frequency) by time 71.3 μs < Memory estimate: 47.66 KiB, allocs estimate: 1007. ``` ```julia results["evaluation_typed"] ``` ``` BenchmarkTools.Trial: 10000 samples with 1 evaluation. Range (min … max): 25.046 μs … 7.474 ms ┊ GC (min … max): 0.00% … 99.4 0% Time (median): 25.591 μs ┊ GC (median): 0.00% Time (mean ± σ): 29.101 μs ± 105.160 μs ┊ GC (mean ± σ): 6.84% ± 1.9 8% ▇█▆▄▃▂▂▁ ▃▄▂▃▃▂▂▁ ▂ █████████▇█████████▇▆▅▆▆▇▇▇▇▅▆▆▅▃▅▄▃▂▂▃▂▄▄▅▄▄▅▄▅▅▅▆▆▅▅▅▅▆▇▇█ █ 25 μs Histogram: log(frequency) by time 46 μs < Memory estimate: 17.62 KiB, allocs estimate: 183. ``` ```julia if WEAVE_ARGS[:include_typed_code] typed = typed_code(m) end ``` #### `demo4`: lots of univariate random variables #### ```julia @model function demo4(n, ::Type{TV}=Vector{Float64}) where {TV} m ~ Normal() x = TV(undef, n) for i in eachindex(x) x[i] ~ Normal(m, 1.0) end end model_def = demo4 data = (100_000, ); ``` ```julia @time model_def(data)(); ``` ``` 0.835503 seconds (3.93 M allocations: 244.654 MiB, 10.38% gc time, 9.43% compilation time) ``` ```julia m = time_model_def(model_def, data); ``` ``` 0.000004 seconds (2 allocations: 64 bytes) ``` ```julia suite = make_suite(m); results = run(suite); ``` ```julia results["evaluation_untyped"] ``` ``` BenchmarkTools.Trial: 60 samples with 1 evaluation. Range (min … max): 68.149 ms … 104.358 ms ┊ GC (min … max): 0.00% … 0.00 % Time (median): 77.456 ms ┊ GC (median): 0.00% Time (mean ± σ): 80.173 ms ± 9.858 ms ┊ GC (mean ± σ): 6.67% ± 8.31 % ▆█ █▄ ▂▄ █▆██▁▁▄▁▁▁▁▁▁▁▁▁▁▆▆▁██▆▁▄▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆████▄▁▄▁▄ ▁ 68.1 ms Histogram: frequency by time 94.8 ms < Memory estimate: 42.78 MiB, allocs estimate: 1253404. ``` ```julia results["evaluation_typed"] ``` ``` BenchmarkTools.Trial: 145 samples with 1 evaluation. Range (min … max): 29.232 ms … 139.283 ms ┊ GC (min … max): 0.00% … 0.00 % Time (median): 30.997 ms ┊ GC (median): 0.00% Time (mean ± σ): 32.506 ms ± 9.228 ms ┊ GC (mean ± σ): 0.23% ± 1.93 % ▁▆█▇▆▃▁ ▃▆███████▅▄▃▃▃▃▃▃▁▅▅▄▃▁▁▃▁▃▃▃▁▁▁▃▃▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃ ▃ 29.2 ms Histogram: frequency by time 46.4 ms < Memory estimate: 781.86 KiB, allocs estimate: 7. ``` ```julia if WEAVE_ARGS[:include_typed_code] typed = typed_code(m) end ``` ```julia @model function demo4_dotted(n, ::Type{TV}=Vector{Float64}) where {TV} m ~ Normal() x = TV(undef, n) x .~ Normal(m, 1.0) end model_def = demo4_dotted data = (100_000, ); ``` ```julia @time model_def(data)(); ``` ``` 1.421197 seconds (5.08 M allocations: 375.131 MiB, 6.23% gc time, 0.62% c ompilation time) ``` ```julia m = time_model_def(model_def, data); ``` ``` 0.000002 seconds (2 allocations: 64 bytes) ``` ```julia suite = make_suite(m); results = run(suite); ``` ```julia results["evaluation_untyped"] ``` ``` BenchmarkTools.Trial: 39 samples with 1 evaluation. Range (min … max): 108.605 ms … 348.289 ms ┊ GC (min … max): 9.70% … 9. 23% Time (median): 118.470 ms ┊ GC (median): 15.38% Time (mean ± σ): 121.407 ms ± 37.585 ms ┊ GC (mean ± σ): 13.35% ± 3. 15% ▆ █ █▁█▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃ ▁ 109 ms Histogram: frequency by time 348 ms < Memory estimate: 347.69 MiB, allocs estimate: 963583. ``` ```julia results["evaluation_typed"] ``` ``` BenchmarkTools.Trial: 61 samples with 1 evaluation. Range (min … max): 66.380 ms … 350.632 ms ┊ GC (min … max): 9.01% … 4.7 7% Time (median): 73.635 ms ┊ GC (median): 16.29% Time (mean ± σ): 75.751 ms ± 35.996 ms ┊ GC (mean ± σ): 12.78% ± 3.8 9% █ ▄ ▃ ▇█▆▆▄▄▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▆████▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃ ▁ 66.4 ms Histogram: frequency by time 84.5 ms < Memory estimate: 337.55 MiB, allocs estimate: 399306. ``` ```julia if WEAVE_ARGS[:include_typed_code] typed = typed_code(m) end ```
--- Project.toml | 6 +- benchmarks/benchmark_body.jmd | 11 ++- benchmarks/benchmarks.jmd | 34 ++++++++ src/DynamicPPL.jl | 3 + src/compiler.jl | 120 +++++++++++++++----------- src/context_implementations.jl | 153 +++++++++++++-------------------- src/contexts.jl | 6 +- src/loglikelihoods.jl | 4 +- test/Project.toml | 4 +- test/compiler.jl | 50 +++++++++++ test/contexts.jl | 21 +++-- test/turing/Project.toml | 2 +- test/turing/varinfo.jl | 27 ------ 13 files changed, 252 insertions(+), 189 deletions(-) diff --git a/Project.toml b/Project.toml index f40c46c0e..9d26a90ce 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.15.2" +version = "0.16.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -12,16 +12,18 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractMCMC = "2, 3.0" -AbstractPPL = "0.2" +AbstractPPL = "0.3" BangBang = "0.3" Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9" ChainRulesCore = "0.9.7, 0.10, 1" Distributions = "0.23.8, 0.24, 0.25" MacroTools = "0.5.6" +Setfield = "0.7.1" ZygoteRules = "0.2" julia = "1.3" diff --git a/benchmarks/benchmark_body.jmd b/benchmarks/benchmark_body.jmd index f9c994dc9..9d9810dc2 100644 --- a/benchmarks/benchmark_body.jmd +++ b/benchmarks/benchmark_body.jmd @@ -8,8 +8,15 @@ m = time_model_def(model_def, data); ```julia suite = make_suite(m); -results = run(suite) -results +results = run(suite); +``` + +```julia +results["evaluation_untyped"] +``` + +```julia +results["evaluation_typed"] ``` ```julia; echo=false; results="hidden"; diff --git a/benchmarks/benchmarks.jmd b/benchmarks/benchmarks.jmd index 614afb2e9..5b86b261e 100644 --- a/benchmarks/benchmarks.jmd +++ b/benchmarks/benchmarks.jmd @@ -94,3 +94,37 @@ data = mapreduce(c -> rand(MvNormal([μs[c], μs[c]], 1.), N), hcat, 1:2); ```julia; echo=false weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS) ``` + +### `demo4`: loads of indexing + +```julia +@model function demo4(n, ::Type{TV}=Vector{Float64}) where {TV} + m ~ Normal() + x = TV(undef, n) + for i in eachindex(x) + x[i] ~ Normal(m, 1.0) + end +end + +model_def = demo4 +data = (100_000, ); +``` + +```julia; echo=false +weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS) +``` + +```julia +@model function demo4_dotted(n, ::Type{TV}=Vector{Float64}) where {TV} + m ~ Normal() + x = TV(undef, n) + x .~ Normal(m, 1.0) +end + +model_def = demo4_dotted +data = (100_000, ); +``` + +```julia; echo=false +weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS) +``` diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 1238add26..15ee22325 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -11,6 +11,9 @@ using MacroTools: MacroTools using ZygoteRules: ZygoteRules using BangBang: BangBang +using Setfield: Setfield +using BangBang: BangBang + using Random: Random import Base: diff --git a/src/compiler.jl b/src/compiler.jl index 5c2369f6c..8ad248622 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -18,7 +18,7 @@ function isassumption(expr::Union{Symbol,Expr}) vn = gensym(:vn) return quote - let $vn = $(varname(expr)) + let $vn = $(AbstractPPL.drop_escape(varname(expr))) if $(DynamicPPL.contextual_isassumption)(__context__, $vn) # Considered an assumption by `__context__` which means either: # 1. We hit the default implementation, e.g. using `DefaultContext`, @@ -133,17 +133,17 @@ variables. # Example ```jldoctest; setup=:(using Distributions, LinearAlgebra) -julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(ones(2), I), randn(2, 2), @varname(x)); string(vns[end]) -"x[:,2]" +julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(ones(2), I), randn(2, 2), @varname(x)); vns[end] +x[:,2] -julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); string(vns[end]) -"x[:][1,2]" +julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x)); vns[end] +x[1,2] -julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3), @varname(x[1])); string(vns[end]) -"x[1][3]" +julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); vns[end] +x[:][1,2] -julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2, 3), @varname(x)); string(vns[end]) -"x[1,2,3]" +julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3), @varname(x[1])); vns[end] +x[1][3] ``` """ unwrap_right_left_vns(right, left, vns) = right, left, vns @@ -158,7 +158,7 @@ function unwrap_right_left_vns( # for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`, # and we therefore add the `Colon()` below. vns = map(axes(left, 2)) do i - return VarName(vn, (vn.indexing..., (Colon(), i))) + return vn ∘ Setfield.IndexLens((Colon(), i)) end return unwrap_right_left_vns(right, left, vns) end @@ -168,7 +168,7 @@ function unwrap_right_left_vns( vn::VarName, ) vns = map(CartesianIndices(left)) do i - return VarName(vn, (vn.indexing..., Tuple(i))) + return vn ∘ Setfield.IndexLens(Tuple(i)) end return unwrap_right_left_vns(right, left, vns) end @@ -317,6 +317,10 @@ function generate_mainbody!(mod, found, expr::Expr, warn) # Do not touch interpolated expressions expr.head === :$ && return expr.args[1] + # Do we don't want escaped expressions because we unfortunately + # escape the entire body afterwards. + Meta.isexpr(expr, :escape) && return generate_mainbody(mod, found, expr.args[1], warn) + # If it's a macro, we expand it if Meta.isexpr(expr, :macrocall) return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn) @@ -349,6 +353,15 @@ function generate_mainbody!(mod, found, expr::Expr, warn) return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...) end +function generate_tilde_literal(left, right) + # If the LHS is a literal, it is always an observation + return quote + $(DynamicPPL.tilde_observe!)( + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ + ) + end +end + """ generate_tilde(left, right) @@ -356,31 +369,20 @@ Generate an `observe` expression for data variables and `assume` expression for variables. """ function generate_tilde(left, right) - # If the LHS is a literal, it is always an observation - if isliteral(left) - return quote - $(DynamicPPL.tilde_observe!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ - ) - end - end + isliteral(left) && return generate_tilde_literal(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation - @gensym vn inds isassumption + @gensym vn isassumption + + # HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact + # that in DynamicPPL we the entire function body. Instead we should be + # more selective with our escape. Until that's the case, we remove them all. return quote - $vn = $(varname(left)) - $inds = $(vinds(left)) + $vn = $(AbstractPPL.drop_escape(varname(left))) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption - $left = $(DynamicPPL.tilde_assume!)( - __context__, - $(DynamicPPL.unwrap_right_vn)( - $(DynamicPPL.check_tilde_rhs)($right), $vn - )..., - $inds, - __varinfo__, - ) + $(generate_tilde_assume(left, right, vn)) else # If `vn` is not in `argnames`, we need to make sure that the variable is defined. if !$(DynamicPPL.inargnames)($vn, __model__) @@ -392,44 +394,46 @@ function generate_tilde(left, right) $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn, - $inds, __varinfo__, ) end end end +function generate_tilde_assume(left, right, vn) + expr = :( + $left = $(DynamicPPL.tilde_assume!)( + __context__, + $(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)..., + __varinfo__, + ) + ) + + return if left isa Expr + AbstractPPL.drop_escape( + Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true) + ) + else + return expr + end +end + """ generate_dot_tilde(left, right) Generate the expression that replaces `left .~ right` in the model body. """ function generate_dot_tilde(left, right) - # If the LHS is a literal, it is always an observation - if isliteral(left) - return quote - $(DynamicPPL.dot_tilde_observe!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ - ) - end - end + isliteral(left) && return generate_tilde_literal(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation - @gensym vn inds isassumption + @gensym vn isassumption return quote - $vn = $(varname(left)) - $inds = $(vinds(left)) + $vn = $(AbstractPPL.drop_escape(varname(left))) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption - $left .= $(DynamicPPL.dot_tilde_assume!)( - __context__, - $(DynamicPPL.unwrap_right_left_vns)( - $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn - )..., - $inds, - __varinfo__, - ) + $(generate_dot_tilde_assume(left, right, vn)) else # If `vn` is not in `argnames`, we need to make sure that the variable is defined. if !$(DynamicPPL.inargnames)($vn, __model__) @@ -441,13 +445,27 @@ function generate_dot_tilde(left, right) $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn, - $inds, __varinfo__, ) end end end +function generate_dot_tilde_assume(left, right, vn) + # We don't need to use `Setfield.@set` here since + # `.=` is always going to be inplace + needs `left` to + # be something that supports `.=`. + return :( + $left .= $(DynamicPPL.dot_tilde_assume!)( + __context__, + $(DynamicPPL.unwrap_right_left_vns)( + $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn + )..., + __varinfo__, + ) + ) +end + const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}} hasmissing(T::Type{<:AbstractArray{TA}}) where {TA<:AbstractArray} = hasmissing(TA) hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 498e83492..19b5ce061 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -14,21 +14,9 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg))) require_gradient(spl::Sampler) = false require_particles(spl::Sampler) = false -_getindex(x, inds::Tuple) = _getindex(Base.maybeview(x, first(inds)...), Base.tail(inds)) -_getindex(x, inds::Tuple{}) = x -_getvalue(x, vn::VarName{sym}) where {sym} = _getindex(getproperty(x, sym), vn.indexing) -function _getvalue(x, vns::AbstractVector{<:VarName{sym}}) where {sym} - val = getproperty(x, sym) - - # This should work with both cartesian and linear indexing. - return map(vns) do vn - _getindex(val, vn) - end -end - # assume """ - tilde_assume(context::SamplingContext, right, vn, inds, vi) + tilde_assume(context::SamplingContext, right, vn, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value with a context associated @@ -36,18 +24,18 @@ with a sampler. Falls back to ```julia -tilde_assume(context.rng, context.context, context.sampler, right, vn, inds, vi) +tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) ``` """ -function tilde_assume(context::SamplingContext, right, vn, inds, vi) - return tilde_assume(context.rng, context.context, context.sampler, right, vn, inds, vi) +function tilde_assume(context::SamplingContext, right, vn, vi) + return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) end # Leaf contexts function tilde_assume(context::AbstractContext, args...) return tilde_assume(NodeTrait(tilde_assume, context), context, args...) end -function tilde_assume(::IsLeaf, context::AbstractContext, right, vn, vinds, vi) +function tilde_assume(::IsLeaf, context::AbstractContext, right, vn, vi) return assume(right, vn, vi) end function tilde_assume(::IsParent, context::AbstractContext, args...) @@ -57,44 +45,36 @@ end function tilde_assume(rng, context::AbstractContext, args...) return tilde_assume(NodeTrait(tilde_assume, context), rng, context, args...) end -function tilde_assume( - ::IsLeaf, rng, context::AbstractContext, sampler, right, vn, vinds, vi -) +function tilde_assume(::IsLeaf, rng, context::AbstractContext, sampler, right, vn, vi) return assume(rng, sampler, right, vn, vi) end function tilde_assume(::IsParent, rng, context::AbstractContext, args...) return tilde_assume(rng, childcontext(context), args...) end -function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) +function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + vi[vn] = vectorize(right, get(context.vars, vn)) settrans!(vi, false, vn) end - return tilde_assume(PriorContext(), right, vn, inds, vi) + return tilde_assume(PriorContext(), right, vn, vi) end function tilde_assume( - rng::Random.AbstractRNG, - context::PriorContext{<:NamedTuple}, - sampler, - right, - vn, - inds, - vi, + rng::Random.AbstractRNG, context::PriorContext{<:NamedTuple}, sampler, right, vn, vi ) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + vi[vn] = vectorize(right, get(context.vars, vn)) settrans!(vi, false, vn) end - return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) + return tilde_assume(rng, PriorContext(), sampler, right, vn, vi) end -function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) +function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + vi[vn] = vectorize(right, get(context.vars, vn)) settrans!(vi, false, vn) end - return tilde_assume(LikelihoodContext(), right, vn, inds, vi) + return tilde_assume(LikelihoodContext(), right, vn, vi) end function tilde_assume( rng::Random.AbstractRNG, @@ -102,42 +82,39 @@ function tilde_assume( sampler, right, vn, - inds, vi, ) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + vi[vn] = vectorize(right, get(context.vars, vn)) settrans!(vi, false, vn) end - return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, inds, vi) + return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi) end -function tilde_assume(::LikelihoodContext, right, vn, inds, vi) +function tilde_assume(::LikelihoodContext, right, vn, vi) return assume(NoDist(right), vn, vi) end -function tilde_assume( - rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, inds, vi -) +function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, vi) return assume(rng, sampler, NoDist(right), vn, vi) end -function tilde_assume(context::PrefixContext, right, vn, inds, vi) - return tilde_assume(context.context, right, prefix(context, vn), inds, vi) +function tilde_assume(context::PrefixContext, right, vn, vi) + return tilde_assume(context.context, right, prefix(context, vn), vi) end -function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) - return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), inds, vi) +function tilde_assume(rng, context::PrefixContext, sampler, right, vn, vi) + return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), vi) end """ - tilde_assume!(context, right, vn, inds, vi) + tilde_assume!(context, right, vn, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. -By default, calls `tilde_assume(context, right, vn, inds, vi)` and accumulates the log +By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log probability of `vi` with the returned value. """ -function tilde_assume!(context, right, vn, inds, vi) - value, logp = tilde_assume(context, right, vn, inds, vi) +function tilde_assume!(context, right, vn, vi) + value, logp = tilde_assume(context, right, vn, vi) acclogp!(vi, logp) return value end @@ -180,7 +157,7 @@ function tilde_observe(context::PrefixContext, right, left, vi) end """ - tilde_observe!(context, right, left, vname, vinds, vi) + tilde_observe!(context, right, left, vname, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. @@ -188,7 +165,7 @@ accumulate the log probability, and return the observed value. Falls back to `tilde_observe!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!(context, right, left, vname, vinds, vi) +function tilde_observe!(context, right, left, vname, vi) return tilde_observe!(context, right, left, vi) end @@ -260,7 +237,7 @@ end # assume """ - dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) + dot_tilde_assume(context::SamplingContext, right, left, vn, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value for a context @@ -268,12 +245,12 @@ associated with a sampler. Falls back to ```julia -dot_tilde_assume(context.rng, context.context, context.sampler, right, left, vn, inds, vi) +dot_tilde_assume(context.rng, context.context, context.sampler, right, left, vn, vi) ``` """ -function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) +function dot_tilde_assume(context::SamplingContext, right, left, vn, vi) return dot_tilde_assume( - context.rng, context.context, context.sampler, right, left, vn, inds, vi + context.rng, context.context, context.sampler, right, left, vn, vi ) end @@ -285,12 +262,10 @@ function dot_tilde_assume(rng, context::AbstractContext, args...) return dot_tilde_assume(rng, NodeTrait(dot_tilde_assume, context), context, args...) end -function dot_tilde_assume(::IsLeaf, ::AbstractContext, right, left, vns, inds, vi) +function dot_tilde_assume(::IsLeaf, ::AbstractContext, right, left, vns, vi) return dot_assume(right, left, vns, vi) end -function dot_tilde_assume( - ::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, inds, vi -) +function dot_tilde_assume(::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, vi) return dot_assume(rng, sampler, right, vns, left, vi) end @@ -301,22 +276,20 @@ function dot_tilde_assume(rng, ::IsParent, context::AbstractContext, args...) return dot_tilde_assume(rng, childcontext(context), args...) end -function dot_tilde_assume(rng, ::DefaultContext, sampler, right, left, vns, inds, vi) +function dot_tilde_assume(rng, ::DefaultContext, sampler, right, left, vns, vi) return dot_assume(rng, sampler, right, vns, left, vi) end # `LikelihoodContext` -function dot_tilde_assume( - context::LikelihoodContext{<:NamedTuple}, right, left, vn, inds, vi -) +function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left, vn, vi) return if haskey(context.vars, getsym(vn)) - var = _getindex(getfield(context.vars, getsym(vn)), inds) + var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, inds, vi) + dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi) else - dot_tilde_assume(LikelihoodContext(), right, left, vn, inds, vi) + dot_tilde_assume(LikelihoodContext(), right, left, vn, vi) end end function dot_tilde_assume( @@ -326,38 +299,37 @@ function dot_tilde_assume( right, left, vn, - inds, vi, ) return if haskey(context.vars, getsym(vn)) - var = _getindex(getfield(context.vars, getsym(vn)), inds) + var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, inds, vi) + dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi) else - dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, inds, vi) + dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi) end end -function dot_tilde_assume(context::LikelihoodContext, right, left, vn, inds, vi) +function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) return dot_assume(NoDist.(right), left, vn, vi) end function dot_tilde_assume( - rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, inds, vi + rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, vi ) return dot_assume(rng, sampler, NoDist.(right), vn, left, vi) end # `PriorContext` -function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, inds, vi) +function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, vi) return if haskey(context.vars, getsym(vn)) - var = _getindex(getfield(context.vars, getsym(vn)), inds) + var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(PriorContext(), _right, _left, _vns, inds, vi) + dot_tilde_assume(PriorContext(), _right, _left, _vns, vi) else - dot_tilde_assume(PriorContext(), right, left, vn, inds, vi) + dot_tilde_assume(PriorContext(), right, left, vn, vi) end end function dot_tilde_assume( @@ -367,41 +339,40 @@ function dot_tilde_assume( right, left, vn, - inds, vi, ) return if haskey(context.vars, getsym(vn)) - var = _getindex(getfield(context.vars, getsym(vn)), inds) + var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, inds, vi) + dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi) else - dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, inds, vi) + dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi) end end # `PrefixContext` -function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), inds, vi) +function dot_tilde_assume(context::PrefixContext, right, left, vn, vi) + return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), vi) end -function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, inds, vi) +function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, vi) return dot_tilde_assume( - rng, context.context, sampler, right, prefix.(Ref(context), vn), inds, vi + rng, context.context, sampler, right, prefix.(Ref(context), vn), vi ) end """ - dot_tilde_assume!(context, right, left, vn, inds, vi) + dot_tilde_assume!(context, right, left, vn, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `dot_tilde_assume(context, right, left, vn, inds, vi)`. +Falls back to `dot_tilde_assume(context, right, left, vn, vi)`. """ -function dot_tilde_assume!(context, right, left, vn, inds, vi) - value, logp = dot_tilde_assume(context, right, left, vn, inds, vi) +function dot_tilde_assume!(context, right, left, vn, vi) + value, logp = dot_tilde_assume(context, right, left, vn, vi) acclogp!(vi, logp) return value end @@ -598,7 +569,7 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi) end """ - dot_tilde_observe!(context, right, left, vname, vinds, vi) + dot_tilde_observe!(context, right, left, vname, vi) Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. @@ -606,7 +577,7 @@ accumulate the log probability, and return the observed value. Falls back to `dot_tilde_observe!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe!(context, right, left, vn, inds, vi) +function dot_tilde_observe!(context, right, left, vn, vi) return dot_tilde_observe!(context, right, left, vi) end diff --git a/src/contexts.jl b/src/contexts.jl index 98eb4b85d..03fc26245 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -246,9 +246,9 @@ end function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} if @generated - return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(vn.indexing)) + return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(getlens(vn))) else - VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) + VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(getlens(vn)) end end @@ -311,7 +311,7 @@ Return value of `vn` in `context`. function getvalue(context::AbstractContext, vn) return error("context $(context) does not contain value for $vn") end -getvalue(context::ConditionContext, vn) = _getvalue(context.values, vn) +getvalue(context::ConditionContext, vn) = get(context.values, vn) """ hasvalue_nested(context, vn) diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 4e221fd57..cd50811c1 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -71,7 +71,7 @@ function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) # Defer literal `observe` to child-context. return tilde_observe!(context.context, right, left, vi) end -function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, vinds, vi) +function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, vi) # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. # we have to intercept the call to `tilde_observe!`. logp = tilde_observe(context.context, right, left, vi) @@ -87,7 +87,7 @@ function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi # Defer literal `observe` to child-context. return dot_tilde_observe!(context.context, right, left, vi) end -function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) +function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, vi) # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. # we have to intercept the call to `dot_tilde_observe!`. diff --git a/test/Project.toml b/test/Project.toml index 948d8e5af..3af6ef22d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -13,6 +13,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" @@ -20,7 +21,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractMCMC = "2.1, 3.0" -AbstractPPL = "0.2" +AbstractPPL = "0.3" Bijectors = "0.9.5" Distributions = "0.25" DistributionsAD = "0.6.3" @@ -28,6 +29,7 @@ Documenter = "0.26.1, 0.27" ForwardDiff = "0.10.12" MCMCChains = "4.0.4, 5" MacroTools = "0.5.5" +Setfield = "0.7.1" StableRNGs = "1" Tracker = "0.2.11" Zygote = "0.5.4, 0.6" diff --git a/test/compiler.jl b/test/compiler.jl index 0f072b468..3140cf5b7 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -28,6 +28,11 @@ macro mymodel2(ex) end end +# Used to test sampling of immutable types. +struct MyCoolStruct{T} + a::T +end + @testset "compiler.jl" begin @testset "model macro" begin @model function testmodel_comp(x, y) @@ -235,6 +240,51 @@ end @test haskey(vi.metadata, :x) vi = VarInfo(gdemo(x)) @test haskey(vi.metadata, :x) + + # Non-array variables + @model function testmodel_nonarray(x, y) + s ~ InverseGamma(2, 3) + m ~ Normal(0, √s) + for i in 2:(length(x.a) - 1) + x.a[i] ~ Normal(m, √s) + end + + # Dynamic indexing + x.a[end] ~ Normal(100.0, 1.0) + + # Immutable set + y.a ~ Normal() + + # Dotted + z = Vector{Float64}(undef, 3) + z[1:2] .~ Normal() + z[end:end] .~ Normal() + + return (; s=s, m=m, x=x, y=y, z=z) + end + + m_nonarray = testmodel_nonarray( + MyCoolStruct([missing, missing]), MyCoolStruct(missing) + ) + result = m_nonarray() + @test !any(ismissing, result.x.a) + @test result.y.a !== missing + @test result.x.a[end] > 10 + + # Ensure that we can work with `Vector{Real}(undef, N)` which is the + # reason why we're using `BangBang.prefermutation` in `src/compiler.jl` + # rather than the default from Setfield.jl. + # Related: https://github.com/jw3126/Setfield.jl/issues/157 + @model function vdemo() + x = Vector{Real}(undef, 10) + for i in eachindex(x) + x[i] ~ Normal(0, sqrt(4)) + end + + return x + end + x = vdemo()() + @test all((isassigned(x, i) for i in eachindex(x))) end @testset "nested model" begin function makemodel(p) diff --git a/test/contexts.jl b/test/contexts.jl index c63535cb3..edf581d4d 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,4 +1,4 @@ -using Test, DynamicPPL +using Test, DynamicPPL, Setfield using DynamicPPL: leafcontext, setleafcontext, @@ -53,7 +53,7 @@ Return `vn` but now with the prefix removed. """ function remove_prefix(vn::VarName) return VarName{Symbol(split(string(vn), string(DynamicPPL.PREFIX_SEPARATOR))[end])}( - vn.indexing + getlens(vn) ) end @@ -65,11 +65,14 @@ e.g. `varnames(@varname(x), rand(2))` results in an iterator over `[@varname(x[1 """ varnames(vn::VarName, val::Real) = [vn] function varnames(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) - return (VarName(vn, (vn.indexing..., Tuple(I))) for I in CartesianIndices(val)) + return ( + VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))) for + I in CartesianIndices(val) + ) end function varnames(vn::VarName, val::AbstractArray) return Iterators.flatten( - varnames(VarName(vn, (vn.indexing..., Tuple(I))), val[I]) for + varnames(VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))), val[I]) for I in CartesianIndices(val) ) end @@ -183,7 +186,7 @@ end # Let's check elementwise. for vn_child in varnames(vn_without_prefix, val) - if DynamicPPL._getindex(val, vn_child.indexing) === missing + if get(val, getlens(vn_child)) === missing @test contextual_isassumption(context, vn_child) else @test !contextual_isassumption(context, vn_child) @@ -219,7 +222,7 @@ end @test hasvalue_nested(context, vn_child) # Value should be the same as extracted above. @test getvalue_nested(context, vn_child) === - DynamicPPL._getindex(val, vn_child.indexing) + get(val, getlens(vn_child)) end end end @@ -246,11 +249,11 @@ end vn = VarName{:x}() vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test vn_prefixed.indexing === vn.indexing + @test getlens(vn_prefixed) === getlens(vn) - vn = VarName{:x}((1,)) + vn = VarName{:x}(((1,),)) vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test vn_prefixed.indexing === vn.indexing + @test getlens(vn_prefixed) === getlens(vn) end end diff --git a/test/turing/Project.toml b/test/turing/Project.toml index fe186816f..9d75e2dcb 100644 --- a/test/turing/Project.toml +++ b/test/turing/Project.toml @@ -5,6 +5,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] -DynamicPPL = "0.15" +DynamicPPL = "0.16" Turing = "0.18" julia = "1.3" diff --git a/test/turing/varinfo.jl b/test/turing/varinfo.jl index cc8b61d04..892433779 100644 --- a/test/turing/varinfo.jl +++ b/test/turing/varinfo.jl @@ -184,29 +184,6 @@ chain = sample(priorsinarray(xs), HMC(0.01, 10), 10) end @testset "varname" begin - i, j, k = 1, 2, 3 - - vn1 = @varname x[1] - @test vn1 == VarName{:x}(((1,),)) - - # Symbol - v_sym = string(:x) - @test v_sym == "x" - - # Array - v_arr = @varname x[i] - @test v_arr.indexing == ((1,),) - - # Matrix - v_mat = @varname x[i, j] - @test v_mat.indexing == ((1, 2),) - - v_mat = @varname x[i, j, k] - @test v_mat.indexing == ((1, 2, 3),) - - v_mat = @varname x[1, 2][1 + 5][45][3][i] - @test v_mat.indexing == ((1, 2), (6,), (45,), (3,), (1,)) - @model function mat_name_test() p = Array{Any}(undef, 2, 2) for i in 1:2, j in 1:2 @@ -217,10 +194,6 @@ chain = sample(mat_name_test(), HMC(0.2, 4), 1000) check_numerical(chain, ["p[1,1]"], [0]; atol=0.25) - # Multi array - v_arrarr = @varname x[i][j] - @test v_arrarr.indexing == ((1,), (2,)) - @model function marr_name_test() p = Array{Array{Any}}(undef, 2) p[1] = Array{Any}(undef, 2)