From 28bf738139dd4dc5bca69671f699912aa978d5a8 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 5 Jan 2024 11:46:00 +0100 Subject: [PATCH] Fix `eachvariate` with zero variates (#1819) --- Project.toml | 2 +- src/eachvariate.jl | 2 +- test/eachvariate.jl | 27 +++++++++++++++++++++++++++ 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 4e71ee9135..60ae8b2618 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Distributions" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" authors = ["JuliaStats"] -version = "0.25.104" +version = "0.25.105" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/eachvariate.jl b/src/eachvariate.jl index 701be99faa..583600f36d 100644 --- a/src/eachvariate.jl +++ b/src/eachvariate.jl @@ -7,7 +7,7 @@ end function EachVariate{V}(x::AbstractArray{<:Real,M}) where {V,M} ax = ntuple(i -> axes(x, i + V), Val(M - V)) - T = typeof(view(x, ntuple(i -> i <= V ? Colon() : firstindex(x, i), Val(M))...)) + T = Base.promote_op(view, typeof(x), ntuple(i -> i <= V ? Colon : eltype(axes(x, i)), Val(M))...) return EachVariate{V,typeof(x),typeof(ax),T,M-V}(x, ax) end diff --git a/test/eachvariate.jl b/test/eachvariate.jl index f41a5207d2..60cd1f68d3 100644 --- a/test/eachvariate.jl +++ b/test/eachvariate.jl @@ -1,12 +1,24 @@ +using Distributions using ChainRulesTestUtils using ChainRulesTestUtils: FiniteDifferences +using Random +using Test + # Without this, `to_vec` will also include the `axes` field of `EachVariate`. function FiniteDifferences.to_vec(xs::Distributions.EachVariate{V}) where {V} vals, vals_from_vec = FiniteDifferences.to_vec(xs.parent) return vals, x -> Distributions.EachVariate{V}(vals_from_vec(x)) end +# MWE in #1817 +struct FooEachvariate <: Sampleable{Multivariate, Continuous} end +Base.length(::FooEachvariate) = 3 +Base.eltype(::FooEachvariate) = Float64 +function Distributions._rand!(rng::AbstractRNG, ::FooEachvariate, x::AbstractVector{<:Real}) + return rand!(rng, x) +end + @testset "eachvariate.jl" begin @testset "ChainRules" begin xs = randn(2, 3, 4, 5) @@ -14,4 +26,19 @@ end test_rrule(Distributions.EachVariate{2}, xs) test_rrule(Distributions.EachVariate{3}, xs) end + + @testset "No variates (#1817)" begin + @test size(Distributions.eachvariate(rand(0), Univariate)) == (0,) + @test size(Distributions.eachvariate(rand(3, 0, 1), Univariate)) == (3, 0, 1) + @test size(Distributions.eachvariate(rand(3, 2, 0), Univariate)) == (3, 2, 0) + + @test size(Distributions.eachvariate(rand(4, 0), Multivariate)) == (0,) + @test size(Distributions.eachvariate(rand(4, 0, 2), Multivariate)) == (0, 2) + @test size(Distributions.eachvariate(rand(4, 5, 0), Multivariate)) == (5, 0) + @test size(Distributions.eachvariate(rand(4, 5, 0, 2), Multivariate)) == (5, 0, 2) + + draws = @inferred(rand(FooEachvariate(), 0)) + @test draws isa Matrix{Float64} + @test size(draws) == (3, 0) + end end