diff --git a/README.md b/README.md index e255fe88..7fdc4424 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ The following table lists mathematical operations for a bijector and the corresp | `x ↦ b(x)` | `b(x)` | × | | `y ↦ b⁻¹(y)` | `inv(b)(y)` | × | | `x ↦ log|det J(b, x)|` | `logabsdetjac(b, x)` | AD | -| `x ↦ b(x), log|det J(b, x)|` | `forward(b, x)` | ✓ | +| `x ↦ b(x), log|det J(b, x)|` | `with_logabsdet_jacobian(b, x)` | ✓ | | `p ↦ q := b_* p` | `q = transformed(p, b)` | ✓ | | `y ∼ q` | `y = rand(q)` | ✓ | | `p ↦ b` such that `support(b_* p) = ℝᵈ` | `bijector(p)` | ✓ | @@ -221,18 +221,18 @@ true which is always the case for a differentiable bijection with differentiable inverse. Therefore if you want to compute `logabsdetjac(b⁻¹, y)` and we know that `logabsdetjac(b, b⁻¹(y))` is actually more efficient, we'll return `-logabsdetjac(b, b⁻¹(y))` instead. For some bijectors it might be easy to compute, say, the forward pass `b(x)`, but expensive to compute `b⁻¹(y)`. Because of this you might want to avoid doing anything "backwards", i.e. using `b⁻¹`. This is where `forward` comes to good use: ```julia -julia> forward(b, x) -(rv = -0.5369949942509267, logabsdetjac = 1.4575353795716655) +julia> with_logabsdet_jacobian(b, x) +(-0.5369949942509267, 1.4575353795716655) ``` Similarily ```julia julia> forward(inv(b), y) -(rv = 0.3688868996596376, logabsdetjac = -1.4575353795716655) +(0.3688868996596376, -1.4575353795716655) ``` -In fact, the purpose of `forward` is to just _do the right thing_, not necessarily "forward". In this function we'll have access to both the original value `x` and the transformed value `y`, so we can compute `logabsdetjac(b, x)` in either direction. Furthermore, in a lot of cases we can re-use a lot of the computation from `b(x)` in the computation of `logabsdetjac(b, x)`, or vice-versa. `forward(b, x)` will take advantage of such opportunities (if implemented). +In fact, the purpose of `with_logabsdet_jacobian` is to just _do the right thing_, not necessarily "forward". In this function we'll have access to both the original value `x` and the transformed value `y`, so we can compute `logabsdetjac(b, x)` in either direction. Furthermore, in a lot of cases we can re-use a lot of the computation from `b(x)` in the computation of `logabsdetjac(b, x)`, or vice-versa. `with_logabsdet_jacobian(b, x)` will take advantage of such opportunities (if implemented). #### Sampling from `TransformedDistribution` At this point we've only shown that we can replicate the existing functionality. But we said `TransformedDistribution isa Distribution`, so we also have `rand`: @@ -481,7 +481,7 @@ julia> Flux.params(flow) Params([[-1.05099; 0.502079] (tracked), [-0.216248; -0.706424] (tracked), [-4.33747] (tracked)]) ``` -Another useful function is the `forward(d::Distribution)` method. It is similar to `forward(b::Bijector)` in the sense that it does a forward pass of the entire process "sample then transform" and returns all the most useful quantities in process using the most efficent computation path. +Another useful function is the `forward(d::Distribution)` method. It is similar to `with_logabsdet_jacobian(b::Bijector, x)` in the sense that it does a forward pass of the entire process "sample then transform" and returns all the most useful quantities in process using the most efficent computation path. ```julia julia> x, y, logjac, logpdf_y = forward(flow) # sample + transform and returns all the useful quantities in one pass @@ -555,28 +555,29 @@ Tracked 2-element Array{Float64,1}: -1.546158373866469 -1.6098711387913573 -julia> forward(b, 0.6) # defaults to `(rv=b(x), logabsdetjac=logabsdetjac(b, x))` -(rv = 0.4054651081081642, logabsdetjac = 1.4271163556401458) +julia> with_logabsdet_jacobian(b, 0.6) # defaults to `(b(x), logabsdetjac(b, x))` +(0.4054651081081642, 1.4271163556401458) ``` -For further efficiency, one could manually implement `forward(b::Logit, x)`: +For further efficiency, one could manually implement `with_logabsdet_jacobian(b::Logit, x)`: ```julia julia> import Bijectors: forward, Logit +julia> import ChangesOfVariables: with_logabsdet_jacobian -julia> function forward(b::Logit{<:Real}, x) +julia> function with_logabsdet_jacobian(b::Logit{<:Real}, x) totally_worth_saving = @. (x - b.a) / (b.b - b.a) # spoiler: it's probably not y = logit.(totally_worth_saving) logjac = @. - log((b.b - x) * totally_worth_saving) - return (rv=y, logabsdetjac = logjac) + return (y, logjac) end forward (generic function with 16 methods) -julia> forward(b, 0.6) -(rv = 0.4054651081081642, logabsdetjac = 1.4271163556401458) +julia> with_logabsdet_jacobian(b, 0.6) +(0.4054651081081642, 1.4271163556401458) -julia> @which forward(b, 0.6) -forward(b::Logit{#s4} where #s4<:Real, x) in Main at REPL[43]:2 +julia> @which with_logabsdet_jacobian(b, 0.6) +with_logabsdet_jacobian(b::Logit{#s4} where #s4<:Real, x) in Main at REPL[43]:2 ``` As you can see it's a very contrived example, but you get the idea. @@ -715,7 +716,7 @@ The following methods are implemented by all subtypes of `Bijector`, this also i - `(b::Bijector)(x)`: implements the transform of the `Bijector` - `inv(b::Bijector)`: returns the inverse of `b`, i.e. `ib::Bijector` s.t. `(ib ∘ b)(x) ≈ x`. In most cases this is `Inverse{<:Bijector}`. - `logabsdetjac(b::Bijector, x)`: computes log(abs(det(jacobian(b, x)))). -- `forward(b::Bijector, x)`: returns named tuple `(rv=b(x), logabsdetjac=logabsdetjac(b, x))` in the most efficient manner. +- `with_logabsdet_jacobian(b::Bijector, x)`: returns named tuple `(b(x), logabsdetjac(b, x))` in the most efficient manner. - `∘`, `composel`, `composer`: convenient and type-safe constructors for `Composed`. `composel(bs...)` composes s.t. the resulting composition is evaluated left-to-right, while `composer(bs...)` is evaluated right-to-left. `∘` is right-to-left, as excepted from standard mathematical notation. - `jacobian(b::Bijector, x)` [OPTIONAL]: returns the Jacobian of the transformation. In some cases the analytical Jacobian has been implemented for efficiency. - `dimension(b::Bijector)`: returns the dimensionality of `b`. diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 36779fd4..2393283d 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -35,8 +35,9 @@ using MappedArrays using Base.Iterators: drop using LinearAlgebra: AbstractTriangular +import ChangesOfVariables: with_logabsdet_jacobian + import ChainRulesCore -import ChangesOfVariables import Functors import InverseFunctions import IrrationalConstants @@ -251,6 +252,8 @@ include("utils.jl") include("interface.jl") include("chainrules.jl") +Base.@deprecate forward(b::AbstractBijector, x) with_logabsdet_jacobian(b, x) + # Broadcasting here breaks Tracker for some reason maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...) maporbroadcast(f, x::AbstractArray...) = f.(x...) diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index f05b819c..6a02432e 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -179,8 +179,8 @@ function logabsdetjac(cb::Composed, x) y, logjac = forward(cb.ts[1], x) for i = 2:length(cb.ts) res = forward(cb.ts[i], y) - y = res.rv - logjac += res.logabsdetjac + y = res[1] + logjac += res[2] end return logjac @@ -195,8 +195,8 @@ end for i = 2:N - 1 temp = gensym(:res) push!(expr.args, :($temp = forward(cb.ts[$i], y))) - push!(expr.args, :(y = $temp.rv)) - push!(expr.args, :(logjac += $temp.logabsdetjac)) + push!(expr.args, :(y = $temp[1])) + push!(expr.args, :(logjac += $temp[2])) end # don't need to evaluate the last bijector, only it's `logabsdetjac` push!(expr.args, :(logjac += logabsdetjac(cb.ts[$N], y))) @@ -212,10 +212,10 @@ function forward(cb::Composed, x) for t in cb.ts[2:end] res = forward(t, rv) - rv = res.rv - logjac = res.logabsdetjac + logjac + rv = res[1] + logjac = res[2] + logjac end - return (rv=rv, logabsdetjac=logjac) + return (rv, logjac) end @@ -225,10 +225,10 @@ end for i = 2:length(T.parameters) temp = gensym(:temp) push!(expr.args, :($temp = forward(cb.ts[$i], y))) - push!(expr.args, :(y = $temp.rv)) - push!(expr.args, :(logjac += $temp.logabsdetjac)) + push!(expr.args, :(y = $temp[1])) + push!(expr.args, :(logjac += $temp[2])) end - push!(expr.args, :(return (rv = y, logabsdetjac = logjac))) + push!(expr.args, :(return (y, logjac))) return expr end diff --git a/src/bijectors/leaky_relu.jl b/src/bijectors/leaky_relu.jl index 65060c14..62413316 100644 --- a/src/bijectors/leaky_relu.jl +++ b/src/bijectors/leaky_relu.jl @@ -44,21 +44,21 @@ end logabsdetjac(b::LeakyReLU{<:Real, 0}, x::AbstractVector{<:Real}) = map(x -> logabsdetjac(b, x), x) -# We implement `forward` by hand since we can re-use the computation of +# We implement `with_logabsdet_jacobian` by hand since we can re-use the computation of # the Jacobian of the transformation. This will lead to faster sampling # when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`. -function forward(b::LeakyReLU{<:Any, 0}, x::Real) +function with_logabsdet_jacobian(b::LeakyReLU{<:Any, 0}, x::Real) mask = x < zero(x) J = mask * b.α + !mask * one(x) - return (rv=J * x, logabsdetjac=log(abs(J))) + return (J * x, log(abs(J))) end # Batched version -function forward(b::LeakyReLU{<:Any, 0}, x::AbstractVector) +function with_logabsdet_jacobian(b::LeakyReLU{<:Any, 0}, x::AbstractVector) J = let T = eltype(x), z = zero(T), o = one(T) @. (x < z) * b.α + (x > z) * o end - return (rv=J .* x, logabsdetjac=log.(abs.(J))) + return (J .* x, log.(abs.(J))) end # (N=1) Multivariate case @@ -84,7 +84,7 @@ end # We implement `forward` by hand since we can re-use the computation of # the Jacobian of the transformation. This will lead to faster sampling # when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`. -function forward(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) +function with_logabsdet_jacobian(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) # Is really diagonal of jacobian J = let T = eltype(x), z = zero(T), o = one(T) @. (x < z) * b.α + (x > z) * o @@ -97,5 +97,5 @@ function forward(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) end y = J .* x - return (rv=y, logabsdetjac=logjac) + return (y, logjac) end diff --git a/src/bijectors/named_bijector.jl b/src/bijectors/named_bijector.jl index 36f8691c..0543f91a 100644 --- a/src/bijectors/named_bijector.jl +++ b/src/bijectors/named_bijector.jl @@ -1,6 +1,6 @@ abstract type AbstractNamedBijector <: AbstractBijector end -forward(b::AbstractNamedBijector, x) = (rv = b(x), logabsdetjac = logabsdetjac(b, x)) +with_logabsdet_jacobian(b::AbstractNamedBijector, x) = (b(x), logabsdetjac(b, x)) ####################### ### `NamedBijector` ### @@ -125,8 +125,8 @@ function logabsdetjac(cb::NamedComposition, x) y, logjac = forward(cb.bs[1], x) for i = 2:length(cb.bs) res = forward(cb.bs[i], y) - y = res.rv - logjac += res.logabsdetjac + y = res[1] + logjac += res[2] end return logjac @@ -141,8 +141,8 @@ end for i = 2:N - 1 temp = gensym(:res) push!(expr.args, :($temp = forward(cb.bs[$i], y))) - push!(expr.args, :(y = $temp.rv)) - push!(expr.args, :(logjac += $temp.logabsdetjac)) + push!(expr.args, :(y = $temp[1])) + push!(expr.args, :(logjac += $temp[2])) end # don't need to evaluate the last bijector, only it's `logabsdetjac` push!(expr.args, :(logjac += logabsdetjac(cb.bs[$N], y))) @@ -158,10 +158,10 @@ function forward(cb::NamedComposition, x) for t in cb.bs[2:end] res = forward(t, rv) - rv = res.rv - logjac = res.logabsdetjac + logjac + rv = res[1] + logjac = res[2] + logjac end - return (rv=rv, logabsdetjac=logjac) + return (rv, logjac) end @@ -171,10 +171,10 @@ end for i = 2:length(T.parameters) temp = gensym(:temp) push!(expr.args, :($temp = forward(cb.bs[$i], y))) - push!(expr.args, :(y = $temp.rv)) - push!(expr.args, :(logjac += $temp.logabsdetjac)) + push!(expr.args, :(y = $temp[1])) + push!(expr.args, :(logjac += $temp[2])) end - push!(expr.args, :(return (rv = y, logabsdetjac = logjac))) + push!(expr.args, :(return (y, logjac))) return expr end diff --git a/src/bijectors/normalise.jl b/src/bijectors/normalise.jl index 81496468..defb447e 100644 --- a/src/bijectors/normalise.jl +++ b/src/bijectors/normalise.jl @@ -48,7 +48,7 @@ function Functors.functor(::Type{<:InvertibleBatchNorm}, x) return (b = x.b, logs = x.logs), reconstruct_invertiblebatchnorm end -function forward(bn::InvertibleBatchNorm, x) +function with_logabsdet_jacobian(bn::InvertibleBatchNorm, x) dims = ndims(x) size(x, dims - 1) == length(bn.b) || error("InvertibleBatchNorm expected $(length(bn.b)) channels, got $(size(x, dims - 1))") @@ -76,12 +76,12 @@ function forward(bn::InvertibleBatchNorm, x) logabsdetjac = ( fill(sum(logs - log.(v .+ bn.eps) / 2), size(x, dims)) ) - return (rv=rv, logabsdetjac=logabsdetjac) + return (rv, logabsdetjac) end -logabsdetjac(bn::InvertibleBatchNorm, x) = forward(bn, x).logabsdetjac +logabsdetjac(bn::InvertibleBatchNorm, x) = with_logabsdet_jacobian(bn, x)[2] -(bn::InvertibleBatchNorm)(x) = forward(bn, x).rv +(bn::InvertibleBatchNorm)(x) = with_logabsdet_jacobian(bn, x)[1] function forward(invbn::Inverse{<:InvertibleBatchNorm}, y) @assert !istraining() "`forward(::Inverse{InvertibleBatchNorm})` is only available in test mode." @@ -94,10 +94,10 @@ function forward(invbn::Inverse{<:InvertibleBatchNorm}, y) v = reshape(bn.v, as...) x = (y .- b) ./ s .* sqrt.(v .+ bn.eps) .+ m - return (rv=x, logabsdetjac=-logabsdetjac(bn, x)) + return (x, -logabsdetjac(bn, x)) end -(bn::Inverse{<:InvertibleBatchNorm})(y) = forward(bn, y).rv +(bn::Inverse{<:InvertibleBatchNorm})(y) = with_logabsdet_jacobian(bn, y)[1] function Base.show(io::IO, l::InvertibleBatchNorm) print(io, "InvertibleBatchNorm($(join(size(l.b), ", ")))") diff --git a/src/bijectors/planar_layer.jl b/src/bijectors/planar_layer.jl index 50070396..f2dbefff 100644 --- a/src/bijectors/planar_layer.jl +++ b/src/bijectors/planar_layer.jl @@ -101,7 +101,7 @@ function forward(flow::PlanarLayer, z::AbstractVecOrMat{<:Real}) b = first(flow.b) log_det_jacobian = log1p.(wT_û .* abs2.(sech.(_vec(wT_z) .+ b))) - return (rv = transformed, logabsdetjac = log_det_jacobian) + return (transformed, log_det_jacobian) end function (ib::Inverse{<:PlanarLayer})(y::AbstractVecOrMat{<:Real}) @@ -175,5 +175,5 @@ function find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:Real} return α0 end -logabsdetjac(flow::PlanarLayer, x) = forward(flow, x).logabsdetjac +logabsdetjac(flow::PlanarLayer, x) = forward(flow, x)[2] isclosedform(b::Inverse{<:PlanarLayer}) = false diff --git a/src/bijectors/radial_layer.jl b/src/bijectors/radial_layer.jl index 7c79712c..11c3d799 100644 --- a/src/bijectors/radial_layer.jl +++ b/src/bijectors/radial_layer.jl @@ -63,7 +63,7 @@ function forward(flow::RadialLayer, z::AbstractVecOrMat) (d - 1) * log(1 + β_hat * h_) + log(1 + β_hat * h_ + β_hat * (- h_ ^ 2) * r) ) # from eq(14) - return (rv = transformed, logabsdetjac = log_det_jacobian) + return (transformed, log_det_jacobian) end function (ib::Inverse{<:RadialLayer})(y::AbstractVector{<:Real}) @@ -123,4 +123,4 @@ function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat) return r end -logabsdetjac(flow::RadialLayer, x::AbstractVecOrMat) = forward(flow, x).logabsdetjac +logabsdetjac(flow::RadialLayer, x::AbstractVecOrMat) = forward(flow, x)[2] diff --git a/src/bijectors/rational_quadratic_spline.jl b/src/bijectors/rational_quadratic_spline.jl index 8f081fcd..ef34c436 100644 --- a/src/bijectors/rational_quadratic_spline.jl +++ b/src/bijectors/rational_quadratic_spline.jl @@ -343,7 +343,7 @@ function rqs_forward( T = promote_type(eltype(widths), eltype(heights), eltype(derivatives), eltype(x)) if (x ≤ -widths[end]) || (x ≥ widths[end]) - return (rv = one(T) * x, logabsdetjac = zero(T) * x) + return (one(T) * x, zero(T) * x) end # Find which bin `x` is in @@ -376,9 +376,9 @@ function rqs_forward( numerator_y = Δy * (s * ξ^2 + d_k * ξ * (1 - ξ)) y = h_k + numerator_y / denominator - return (rv = y, logabsdetjac = logjac) + return (y, logjac) end -function forward(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::Real) +function with_logabsdet_jacobian(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::Real) return rqs_forward(b.widths, b.heights, b.derivatives, x) end diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 5c3fdb6b..144c7817 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -136,7 +136,7 @@ end # logjac = sum(_logjac) # (y_2, _logjac) = forward(b.bs[2], x[b.ranges[2]]) # logjac += sum(_logjac) -# return (rv = vcat(y_1, y_2), logabsdetjac = logjac) +# return (vcat(y_1, y_2), logjac) # end @generated function forward(b::Stacked{<:Tuple{Vararg{<:Any, N}}, <:Tuple{Vararg{<:Any, N}}}, x::AbstractVector) where {N} expr = Expr(:block) @@ -156,7 +156,7 @@ end push!(y_names, y_name) end - push!(expr.args, :(return (rv = vcat($(y_names...)), logabsdetjac = logjac))) + push!(expr.args, :(return (vcat($(y_names...)), logjac))) return expr end @@ -169,5 +169,5 @@ function forward(sb::Stacked, x::AbstractVector) logjac += sum(l) y end - return (rv = vcat(yinit, ys), logabsdetjac = logjac) + return (vcat(yinit, ys), logjac) end diff --git a/src/interface.jl b/src/interface.jl index c1ca115c..d2e8be7b 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -89,17 +89,17 @@ Default implementation for `Inverse{<:Bijector}` is implemented as logabsdetjac(ib::Inverse{<:Bijector}, y) = - logabsdetjac(ib.orig, ib(y)) """ - forward(b::Bijector, x) + with_logabsdet_jacobian(b::Bijector, x) Computes both `transform` and `logabsdetjac` in one forward pass, and -returns a named tuple `(rv=b(x), logabsdetjac=logabsdetjac(b, x))`. +returns a named tuple `(b(x), logabsdetjac(b, x))`. This defaults to the call above, but often one can re-use computation in the computation of the forward pass and the computation of the `logabsdetjac`. `forward` allows the user to take advantange of such efficiencies, if they exist. """ -forward(b::Bijector, x) = (rv=b(x), logabsdetjac=logabsdetjac(b, x)) +with_logabsdet_jacobian(b::Bijector, x) = (b(x), logabsdetjac(b, x)) """ logabsdetjacinv(b::Bijector, y) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index 1712ba2e..48afd2ff 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -86,14 +86,14 @@ Base.size(td::Transformed) = size(td.dist) function logpdf(td::UnivariateTransformed, y::Real) res = forward(inv(td.transform), y) - return logpdf(td.dist, res.rv) + res.logabsdetjac + return logpdf(td.dist, res[1]) + res[2] end # TODO: implement more efficiently for flows in the case of `Matrix` function logpdf(td::MvTransformed, y::AbstractMatrix{<:Real}) # batch-implementation for multivariate res = forward(inv(td.transform), y) - return logpdf(td.dist, res.rv) + res.logabsdetjac + return logpdf(td.dist, res[1]) + res[2] end function logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractMatrix{<:Real}) @@ -101,12 +101,12 @@ function logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractMatrix{<:Real}) ϵ = _eps(T) res = forward(inv(td.transform), y) - return logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) + res.logabsdetjac + return logpdf(td.dist, mappedarray(x->x+ϵ, res[1])) + res[2] end function _logpdf(td::MvTransformed, y::AbstractVector{<:Real}) res = forward(inv(td.transform), y) - return logpdf(td.dist, res.rv) + res.logabsdetjac + return logpdf(td.dist, res[1]) + res[2] end function _logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) @@ -114,12 +114,12 @@ function _logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) ϵ = _eps(T) res = forward(inv(td.transform), y) - return logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) + res.logabsdetjac + return logpdf(td.dist, mappedarray(x->x+ϵ, res[1])) + res[2] end # TODO: should eventually drop using `logpdf_with_trans` and replace with # res = forward(inv(td.transform), y) -# logpdf(td.dist, res.rv) .- res.logabsdetjac +# logpdf(td.dist, res[1]) .- res[2] function _logpdf(td::MatrixTransformed, y::AbstractMatrix{<:Real}) return logpdf_with_trans(td.dist, inv(td.transform)(y), true) end @@ -164,18 +164,18 @@ and returns a tuple `(logpdf, logabsdetjac)`. """ function logpdf_with_jac(td::UnivariateTransformed, y::Real) res = forward(inv(td.transform), y) - return (logpdf(td.dist, res.rv) + res.logabsdetjac, res.logabsdetjac) + return (logpdf(td.dist, res[1]) + res[2], res[2]) end # TODO: implement more efficiently for flows in the case of `Matrix` function logpdf_with_jac(td::MvTransformed, y::AbstractVector{<:Real}) res = forward(inv(td.transform), y) - return (logpdf(td.dist, res.rv) + res.logabsdetjac, res.logabsdetjac) + return (logpdf(td.dist, res[1]) + res[2], res[2]) end function logpdf_with_jac(td::MvTransformed, y::AbstractMatrix{<:Real}) res = forward(inv(td.transform), y) - return (logpdf(td.dist, res.rv) + res.logabsdetjac, res.logabsdetjac) + return (logpdf(td.dist, res[1]) + res[2], res[2]) end function logpdf_with_jac(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) @@ -183,14 +183,14 @@ function logpdf_with_jac(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Rea ϵ = _eps(T) res = forward(inv(td.transform), y) - lp = logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) + res.logabsdetjac - return (lp, res.logabsdetjac) + lp = logpdf(td.dist, mappedarray(x->x+ϵ, res[1])) + res[2] + return (lp, res[2]) end # TODO: should eventually drop using `logpdf_with_trans` function logpdf_with_jac(td::MatrixTransformed, y::AbstractMatrix{<:Real}) res = forward(inv(td.transform), y) - return (logpdf_with_trans(td.dist, res.rv, true), res.logabsdetjac) + return (logpdf_with_trans(td.dist, res[1], true), res[2]) end """ diff --git a/test/bijectors/coupling.jl b/test/bijectors/coupling.jl index fcf1c402..ead12bdb 100644 --- a/test/bijectors/coupling.jl +++ b/test/bijectors/coupling.jl @@ -45,8 +45,8 @@ using Bijectors: @test logabsdetjac(cl1, x) == logabsdetjac(b, x[1:1]) # forward - @test forward(cl1, x) == (rv = cl1(x), logabsdetjac = logabsdetjac(cl1, x)) - @test forward(icl1, cl1(x)) == (rv = x, logabsdetjac = - logabsdetjac(cl1, x)) + @test forward(cl1, x) == (cl1(x), logabsdetjac(cl1, x)) + @test forward(icl1, cl1(x)) == (x, - logabsdetjac(cl1, x)) end @testset "Classic" begin diff --git a/test/bijectors/leaky_relu.jl b/test/bijectors/leaky_relu.jl index 63ba8c18..5a98f5d9 100644 --- a/test/bijectors/leaky_relu.jl +++ b/test/bijectors/leaky_relu.jl @@ -41,12 +41,12 @@ true_logabsdetjac(b::Bijector{1}, xs::AbstractMatrix) = mapreduce(z -> true_loga # Forward f = forward(b, xs) - @test f.logabsdetjac ≈ logabsdetjac(b, xs) - @test f.rv ≈ b(xs) + @test f[2] ≈ logabsdetjac(b, xs) + @test f[1] ≈ b(xs) f = forward(b, Float32.(xs)) - @test f.logabsdetjac == logabsdetjac(b, Float32.(xs)) - @test f.rv ≈ b(Float32.(xs)) + @test f[2] == logabsdetjac(b, Float32.(xs)) + @test f[1] ≈ b(Float32.(xs)) end @testset "0-dim parameter, 1-dim input" begin @@ -67,12 +67,12 @@ end # Forward f = forward(b, xs) - @test f.logabsdetjac ≈ logabsdetjac(b, xs) - @test f.rv ≈ b(xs) + @test f[2] ≈ logabsdetjac(b, xs) + @test f[1] ≈ b(xs) f = forward(b, Float32.(xs)) - @test f.logabsdetjac == logabsdetjac(b, Float32.(xs)) - @test f.rv ≈ b(Float32.(xs)) + @test f[2] == logabsdetjac(b, Float32.(xs)) + @test f[1] ≈ b(Float32.(xs)) # Mixing of types # 1. Changes in input-type diff --git a/test/bijectors/utils.jl b/test/bijectors/utils.jl index a0fdb6f2..98919af9 100644 --- a/test/bijectors/utils.jl +++ b/test/bijectors/utils.jl @@ -17,21 +17,21 @@ function test_bijector_reals( ires = isequal ? @inferred(forward(inv(b), y_true)) : @inferred(forward(inv(b), y)) # Always want the following to hold - @test ires.rv ≈ x_true atol=tol - @test ires.logabsdetjac ≈ -logjac atol=tol + @test ires[1] ≈ x_true atol=tol + @test ires[2] ≈ -logjac atol=tol if isequal @test y ≈ y_true atol=tol # forward @test (@inferred ib(y_true)) ≈ x_true atol=tol # inverse @test logjac ≈ logjac_true # logjac forward - @test res.rv ≈ y_true atol=tol # forward using `forward` - @test res.logabsdetjac ≈ logjac_true atol=tol # logjac using `forward` + @test res[1] ≈ y_true atol=tol # forward using `forward` + @test res[2] ≈ logjac_true atol=tol # logjac using `forward` else @test y ≠ y_true # forward @test (@inferred ib(y)) ≈ x_true atol=tol # inverse @test logjac ≠ logjac_true # logjac forward - @test res.rv ≠ y_true # forward using `forward` - @test res.logabsdetjac ≠ logjac_true # logjac using `forward` + @test res[1] ≠ y_true # forward using `forward` + @test res[2] ≠ logjac_true # logjac using `forward` end end @@ -54,25 +54,25 @@ function test_bijector_arrays( # always want the following to hold @test ys isa typeof(ys_true) @test logjacs isa typeof(logjacs_true) - @test mean(abs, ires.rv - xs_true) ≤ tol - @test mean(abs, ires.logabsdetjac + logjacs) ≤ tol + @test mean(abs, ires[1] - xs_true) ≤ tol + @test mean(abs, ires[2] + logjacs) ≤ tol if isequal @test mean(abs, ys - ys_true) ≤ tol # forward @test mean(abs, (ib(ys_true)) - xs_true) ≤ tol # inverse @test mean(abs, logjacs - logjacs_true) ≤ tol # logjac forward - @test mean(abs, res.rv - ys_true) ≤ tol # forward using `forward` - @test mean(abs, res.logabsdetjac - logjacs_true) ≤ tol # logjac `forward` - @test mean(abs, ires.logabsdetjac + logjacs_true) ≤ tol # inverse logjac `forward` + @test mean(abs, res[1] - ys_true) ≤ tol # forward using `forward` + @test mean(abs, res[2] - logjacs_true) ≤ tol # logjac `forward` + @test mean(abs, ires[2] + logjacs_true) ≤ tol # inverse logjac `forward` else # Don't want the following to be equal to their "true" values @test mean(abs, ys - ys_true) > tol # forward @test mean(abs, logjacs - logjacs_true) > tol # logjac forward - @test mean(abs, res.rv - ys_true) > tol # forward using `forward` + @test mean(abs, res[1] - ys_true) > tol # forward using `forward` # Still want the following to be equal to the COMPUTED values @test mean(abs, ib(ys) - xs_true) ≤ tol # inverse - @test mean(abs, res.logabsdetjac - logjacs) ≤ tol # logjac forward using `forward` + @test mean(abs, res[2] - logjacs) ≤ tol # logjac forward using `forward` end end diff --git a/test/interface.jl b/test/interface.jl index cee597f6..27da4517 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -200,15 +200,15 @@ end @test size(x_) == size(x) @test size(xs_) == size(xs) - @test size(result.rv) == size(x) - @test size(results.rv) == size(xs) + @test size(result[1]) == size(x) + @test size(results[1]) == size(xs) - @test size(iresult.rv) == size(y) - @test size(iresults.rv) == size(ys) + @test size(iresult[1]) == size(y) + @test size(iresults[1]) == size(ys) # Values @test ys ≈ hcat([b(xs[:, i]) for i = 1:size(xs, 2)]...) - @test ys ≈ results.rv + @test ys ≈ results[1] if D == 0 # Sizes @@ -220,8 +220,8 @@ end @test @inferred(logabsdetjac(b, param(xs))) isa Union{Array, TrackedArray} @test @inferred(logabsdetjac(ib, param(ys))) isa Union{Array, TrackedArray} - @test size(results.logabsdetjac) == size(xs, ) - @test size(iresults.logabsdetjac) == size(ys, ) + @test size(results[2]) == size(xs, ) + @test size(iresults[2]) == size(ys, ) # Values b_logjac_ad = [(log ∘ abs)(ForwardDiff.derivative(b, xs[i])) for i = 1:length(xs)] @@ -234,8 +234,8 @@ end @test logabsdetjac.(b, param(xs)) == @inferred(logabsdetjac(b, param(xs))) @test logabsdetjac.(ib, param(ys)) == @inferred(logabsdetjac(ib, param(ys))) - @test results.logabsdetjac ≈ vec(logabsdetjac.(b, xs)) - @test iresults.logabsdetjac ≈ vec(logabsdetjac.(ib, ys)) + @test results[2] ≈ vec(logabsdetjac.(b, xs)) + @test iresults[2] ≈ vec(logabsdetjac.(ib, ys)) elseif D == 1 @test y == ys[:, 1] # Comparing sizes instead of lengths ensures we catch errors s.t. @@ -247,15 +247,15 @@ end @test @inferred(logabsdetjac(b, param(xs))) isa Union{Array, TrackedArray} @test @inferred(logabsdetjac(ib, param(ys))) isa Union{Array, TrackedArray} - @test size(results.logabsdetjac) == (size(xs, 2), ) - @test size(iresults.logabsdetjac) == (size(ys, 2), ) + @test size(results[2]) == (size(xs, 2), ) + @test size(iresults[2]) == (size(ys, 2), ) # Test all values @test @inferred(logabsdetjac(b, xs)) ≈ vec(mapslices(z -> logabsdetjac(b, z), xs; dims = 1)) @test @inferred(logabsdetjac(ib, ys)) ≈ vec(mapslices(z -> logabsdetjac(ib, z), ys; dims = 1)) - @test results.logabsdetjac ≈ vec(mapslices(z -> logabsdetjac(b, z), xs; dims = 1)) - @test iresults.logabsdetjac ≈ vec(mapslices(z -> logabsdetjac(ib, z), ys; dims = 1)) + @test results[2] ≈ vec(mapslices(z -> logabsdetjac(b, z), xs; dims = 1)) + @test iresults[2] ≈ vec(mapslices(z -> logabsdetjac(ib, z), ys; dims = 1)) # FIXME: `SimplexBijector` results in ∞ gradient if not in the domain if !contains(t -> t isa SimplexBijector, b) @@ -575,17 +575,17 @@ end res1 = forward(sb1, [x, x, y, y]) @test sb1(param([x, x, y, y])) isa TrackedArray - @test sb1([x, x, y, y]) ≈ res1.rv + @test sb1([x, x, y, y]) ≈ res1[1] @test logabsdetjac(sb1, [x, x, y, y]) ≈ 0 atol=1e-6 - @test res1.logabsdetjac ≈ 0 atol=1e-6 + @test res1[2] ≈ 0 atol=1e-6 sb2 = Stacked([b, b, inv(b), inv(b)]) # <= Array res2 = forward(sb2, [x, x, y, y]) @test sb2(param([x, x, y, y])) isa TrackedArray - @test sb2([x, x, y, y]) ≈ res2.rv + @test sb2([x, x, y, y]) ≈ res2[1] @test logabsdetjac(sb2, [x, x, y, y]) ≈ 0.0 atol=1e-12 - @test res2.logabsdetjac ≈ 0.0 atol=1e-12 + @test res2[2] ≈ 0.0 atol=1e-12 # `logabsdetjac` with AD b = MyADBijector(d) @@ -595,17 +595,17 @@ end res1 = forward(sb1, [x, x, y, y]) @test sb1(param([x, x, y, y])) isa TrackedArray - @test sb1([x, x, y, y]) == res1.rv + @test sb1([x, x, y, y]) == res1[1] @test logabsdetjac(sb1, [x, x, y, y]) ≈ 0 atol=1e-12 - @test res1.logabsdetjac ≈ 0.0 atol=1e-12 + @test res1[2] ≈ 0.0 atol=1e-12 sb2 = Stacked([b, b, inv(b), inv(b)]) # <= Array res2 = forward(sb2, [x, x, y, y]) @test sb2(param([x, x, y, y])) isa TrackedArray - @test sb2([x, x, y, y]) == res2.rv + @test sb2([x, x, y, y]) == res2[1] @test logabsdetjac(sb2, [x, x, y, y]) ≈ 0.0 atol=1e-12 - @test res2.logabsdetjac ≈ 0.0 atol=1e-12 + @test res2[2] ≈ 0.0 atol=1e-12 # value-test x = ones(3) @@ -613,9 +613,9 @@ end res = forward(sb, x) @test sb(param(x)) isa TrackedArray @test sb(x) == [exp(x[1]), log(x[2]), x[3] + 5.0] - @test res.rv == [exp(x[1]), log(x[2]), x[3] + 5.0] + @test res[1] == [exp(x[1]), log(x[2]), x[3] + 5.0] @test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:3]) - @test res.logabsdetjac == logabsdetjac(sb, x) + @test res[2] == logabsdetjac(sb, x) # TODO: change when we have dimensionality in the type @@ -624,9 +624,9 @@ end res = @inferred forward(sb, x) @test sb(param(x)) isa TrackedArray @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] - @test res.rv == [exp(x[1]), sb.bs[2](x[2:3])...] + @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] @test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:2]) - @test res.logabsdetjac == logabsdetjac(sb, x) + @test res[2] == logabsdetjac(sb, x) x = ones(4) ./ 4.0 @test_throws AssertionError sb(x) @@ -637,9 +637,9 @@ end res = forward(sb, x) @test sb(param(x)) isa TrackedArray @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] - @test res.rv == [exp(x[1]), sb.bs[2](x[2:3])...] + @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] @test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:2]) - @test res.logabsdetjac == logabsdetjac(sb, x) + @test res[2] == logabsdetjac(sb, x) x = ones(4) ./ 4.0 @test_throws AssertionError sb(x) @@ -651,9 +651,9 @@ end res = forward(sb, x) @test sb(param(x)) isa TrackedArray @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] - @test res.rv == [exp(x[1]), sb.bs[2](x[2:3])...] + @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] @test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:2]) - @test res.logabsdetjac == logabsdetjac(sb, x) + @test res[2] == logabsdetjac(sb, x) x = ones(4) ./ 4.0 @test_throws AssertionError sb(x) @@ -664,9 +664,9 @@ end res = forward(sb, x) @test sb(param(x)) isa TrackedArray @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] - @test res.rv == [exp(x[1]), sb.bs[2](x[2:3])...] + @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] @test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:2]) - @test res.logabsdetjac == logabsdetjac(sb, x) + @test res[2] == logabsdetjac(sb, x) x = ones(4) ./ 4.0 @test_throws AssertionError sb(x) @@ -748,7 +748,7 @@ end x = [.5, 1.] @test sb(x) == x @test logabsdetjac(sb, x) == 0 - @test forward(sb, x) == (rv = x, logabsdetjac = zero(eltype(x))) + @test forward(sb, x) == (x, zero(eltype(x))) end end diff --git a/test/norm_flows.jl b/test/norm_flows.jl index dbbbc36f..9fadf573 100644 --- a/test/norm_flows.jl +++ b/test/norm_flows.jl @@ -26,7 +26,7 @@ end flow = PlanarLayer(2) z = randn(2, 20) forward_diff = log(abs(det(ForwardDiff.jacobian(t -> flow(t), z)))) - our_method = sum(forward(flow, z).logabsdetjac) + our_method = sum(forward(flow, z)[2]) @test our_method ≈ forward_diff @test inv(flow)(flow(z)) ≈ z @@ -74,7 +74,7 @@ end flow = RadialLayer(2) z = randn(2, 20) forward_diff = log(abs(det(ForwardDiff.jacobian(t -> flow(t), z)))) - our_method = sum(forward(flow, z).logabsdetjac) + our_method = sum(forward(flow, z)[2]) @test our_method ≈ forward_diff @test inv(flow)(flow(z)) ≈ z rtol=0.2 @@ -103,9 +103,9 @@ end x = rand(d) y = flow.transform(x) res = forward(flow.transform, x) - lp = logpdf_forward(flow, x, res.logabsdetjac) + lp = logpdf_forward(flow, x, res[2]) - @test res.rv ≈ y + @test res[1] ≈ y @test logpdf(flow, y) ≈ lp rtol=0.1 # flow with unconstrained-to-constrained