Replace forward by with_logabsdet_jacobian
oschulz committed Dec 10, 2021
1 parent 31adbe8 commit 000c83f
Showing 17 changed files with 139 additions and 135 deletions.
33 changes: 17 additions & 16 deletions README.md
@@ -24,7 +24,7 @@ The following table lists mathematical operations for a bijector and the corresp
| `x ↦ b(x)` | `b(x)` | × |
| `y ↦ b⁻¹(y)` | `inv(b)(y)` | × |
| `x ↦ log|det J(b, x)|` | `logabsdetjac(b, x)` | AD |
| `x ↦ b(x), log|det J(b, x)|` | `forward(b, x)` | ✓ |
| `x ↦ b(x), log|det J(b, x)|` | `with_logabsdet_jacobian(b, x)` | ✓ |
| `p ↦ q := b_* p` | `q = transformed(p, b)` | ✓ |
| `y ∼ q` | `y = rand(q)` | ✓ |
| `p ↦ b` such that `support(b_* p) = ℝᵈ` | `bijector(p)` | ✓ |
@@ -221,18 +221,18 @@ true
which is always the case for a differentiable bijection with differentiable inverse. Therefore if you want to compute `logabsdetjac(b⁻¹, y)` and we know that `logabsdetjac(b, b⁻¹(y))` is actually more efficient, we'll return `-logabsdetjac(b, b⁻¹(y))` instead. For some bijectors it might be easy to compute, say, the forward pass `b(x)`, but expensive to compute `b⁻¹(y)`. Because of this you might want to avoid doing anything "backwards", i.e. using `b⁻¹`. This is where `forward` comes to good use:

```julia
julia> forward(b, x)
(rv = -0.5369949942509267, logabsdetjac = 1.4575353795716655)
julia> with_logabsdet_jacobian(b, x)
(-0.5369949942509267, 1.4575353795716655)
```

Similarly

```julia
julia> forward(inv(b), y)
(rv = 0.3688868996596376, logabsdetjac = -1.4575353795716655)
(0.3688868996596376, -1.4575353795716655)
```

In fact, the purpose of `forward` is to just _do the right thing_, not necessarily "forward". In this function we'll have access to both the original value `x` and the transformed value `y`, so we can compute `logabsdetjac(b, x)` in either direction. Furthermore, in a lot of cases we can re-use a lot of the computation from `b(x)` in the computation of `logabsdetjac(b, x)`, or vice-versa. `forward(b, x)` will take advantage of such opportunities (if implemented).
In fact, the purpose of `with_logabsdet_jacobian` is to just _do the right thing_, not necessarily "forward". In this function we'll have access to both the original value `x` and the transformed value `y`, so we can compute `logabsdetjac(b, x)` in either direction. Furthermore, in a lot of cases we can re-use a lot of the computation from `b(x)` in the computation of `logabsdetjac(b, x)`, or vice-versa. `with_logabsdet_jacobian(b, x)` will take advantage of such opportunities (if implemented).
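
As a rough illustration of the tuple return convention, here is a minimal, package-free sketch. `Scale` is a hypothetical toy bijector, and the `with_logabsdet_jacobian` defined here is a local stand-in written only for this example (real code would extend the function from ChangesOfVariables instead). It shows that the result is read by destructuring or by position, since the `rv` and `logabsdetjac` field names no longer exist.

```julia
# Hypothetical toy bijector x ↦ a * x; not part of Bijectors.jl.
struct Scale
    a::Float64
end
(b::Scale)(x::Real) = b.a * x

# Local stand-in for `ChangesOfVariables.with_logabsdet_jacobian`, defined here
# so the snippet runs without any packages. It returns a plain tuple.
with_logabsdet_jacobian(b::Scale, x::Real) = (b(x), log(abs(b.a)))

b = Scale(2.0)

# Destructure the tuple directly...
y, logjac = with_logabsdet_jacobian(b, 0.5)

# ...or index it positionally, where older code used `res.rv` and `res.logabsdetjac`.
res = with_logabsdet_jacobian(b, 0.5)
y_again, logjac_again = res[1], res[2]
```

Callers that still access fields by name would need the same positional rewrite that this commit applies to `Composed` and `NamedComposition`.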

#### Sampling from `TransformedDistribution`
At this point we've only shown that we can replicate the existing functionality. But we said `TransformedDistribution isa Distribution`, so we also have `rand`:
@@ -481,7 +481,7 @@ julia> Flux.params(flow)
Params([[-1.05099; 0.502079] (tracked), [-0.216248; -0.706424] (tracked), [-4.33747] (tracked)])
```

Another useful function is the `forward(d::Distribution)` method. It is similar to `forward(b::Bijector)` in the sense that it does a forward pass of the entire process "sample then transform" and returns all the most useful quantities in the process using the most efficient computation path.
Another useful function is the `forward(d::Distribution)` method. It is similar to `with_logabsdet_jacobian(b::Bijector, x)` in the sense that it does a forward pass of the entire process "sample then transform" and returns all the most useful quantities in the process using the most efficient computation path.

```julia
julia> x, y, logjac, logpdf_y = forward(flow) # sample + transform and returns all the useful quantities in one pass
@@ -555,28 +555,29 @@ Tracked 2-element Array{Float64,1}:
-1.546158373866469
-1.6098711387913573

julia> forward(b, 0.6) # defaults to `(rv=b(x), logabsdetjac=logabsdetjac(b, x))`
(rv = 0.4054651081081642, logabsdetjac = 1.4271163556401458)
julia> with_logabsdet_jacobian(b, 0.6) # defaults to `(b(x), logabsdetjac(b, x))`
(0.4054651081081642, 1.4271163556401458)
```

For further efficiency, one could manually implement `forward(b::Logit, x)`:
For further efficiency, one could manually implement `with_logabsdet_jacobian(b::Logit, x)`:

```julia
julia> import Bijectors: forward, Logit
julia> import ChangesOfVariables: with_logabsdet_jacobian

julia> function forward(b::Logit{<:Real}, x)
julia> function with_logabsdet_jacobian(b::Logit{<:Real}, x)
totally_worth_saving = @. (x - b.a) / (b.b - b.a) # spoiler: it's probably not
y = logit.(totally_worth_saving)
logjac = @. - log((b.b - x) * totally_worth_saving)
return (rv=y, logabsdetjac = logjac)
return (y, logjac)
end
forward (generic function with 16 methods)

julia> forward(b, 0.6)
(rv = 0.4054651081081642, logabsdetjac = 1.4271163556401458)
julia> with_logabsdet_jacobian(b, 0.6)
(0.4054651081081642, 1.4271163556401458)

julia> @which forward(b, 0.6)
forward(b::Logit{#s4} where #s4<:Real, x) in Main at REPL[43]:2
julia> @which with_logabsdet_jacobian(b, 0.6)
with_logabsdet_jacobian(b::Logit{#s4} where #s4<:Real, x) in Main at REPL[43]:2
```
As you can see it's a very contrived example, but you get the idea.
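
To make the point of a hand-written method concrete, the following package-free sketch compares the default pairing with a fused version that computes the shared term once. `LogitSketch`, `logit_sketch`, `logabsdetjac_sketch`, `default_pair` and `fused_pair` are hypothetical names introduced for this illustration; they are assumed to mirror the `Logit` example above but are not the library's implementation.

```julia
# Hypothetical stand-in for a logit bijector on the interval (a, b).
struct LogitSketch
    a::Float64
    b::Float64
end

logit_sketch(p) = log(p / (1 - p))

# The transform and its log-abs-det-Jacobian, written as two separate functions.
(t::LogitSketch)(x::Real) = logit_sketch((x - t.a) / (t.b - t.a))
logabsdetjac_sketch(t::LogitSketch, x::Real) =
    -log((t.b - x) * (x - t.a) / (t.b - t.a))

# Default pairing: the shared term (x - a) / (b - a) is evaluated twice.
default_pair(t::LogitSketch, x::Real) = (t(x), logabsdetjac_sketch(t, x))

# Fused version: the shared term `u` is computed once and reused.
function fused_pair(t::LogitSketch, x::Real)
    u = (x - t.a) / (t.b - t.a)
    return (logit_sketch(u), -log((t.b - x) * u))
end

t = LogitSketch(0.0, 1.0)
y_default, logjac_default = default_pair(t, 0.6)
y_fused, logjac_fused = fused_pair(t, 0.6)
```

Both versions should agree up to floating-point rounding, so a specialised method changes only the cost, not the result.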
@@ -715,7 +716,7 @@ The following methods are implemented by all subtypes of `Bijector`, this also i
- `(b::Bijector)(x)`: implements the transform of the `Bijector`
- `inv(b::Bijector)`: returns the inverse of `b`, i.e. `ib::Bijector` s.t. `(ib ∘ b)(x) ≈ x`. In most cases this is `Inverse{<:Bijector}`.
- `logabsdetjac(b::Bijector, x)`: computes log(abs(det(jacobian(b, x)))).
- `forward(b::Bijector, x)`: returns named tuple `(rv=b(x), logabsdetjac=logabsdetjac(b, x))` in the most efficient manner.
- `with_logabsdet_jacobian(b::Bijector, x)`: returns the tuple `(b(x), logabsdetjac(b, x))` in the most efficient manner.
- `∘`, `composel`, `composer`: convenient and type-safe constructors for `Composed`. `composel(bs...)` composes s.t. the resulting composition is evaluated left-to-right, while `composer(bs...)` is evaluated right-to-left. `∘` is right-to-left, as expected from standard mathematical notation.
- `jacobian(b::Bijector, x)` [OPTIONAL]: returns the Jacobian of the transformation. In some cases the analytical Jacobian has been implemented for efficiency.
- `dimension(b::Bijector)`: returns the dimensionality of `b`.
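
The way a composition combines the per-step results can be sketched without any packages. In this hedged example, `compose_with_logjac`, `scale2` and `exp_pair` are hypothetical helpers: each step is assumed to be a function returning a `(y, logjac)` tuple, and the steps are applied left to right while the log-Jacobian terms are summed.

```julia
# Hypothetical helper: apply each step in `ts` left to right, where every step
# returns a `(value, logabsdetjac)` tuple, and accumulate the log-Jacobians.
function compose_with_logjac(ts, x)
    y, logjac = first(ts)(x)
    for t in Iterators.drop(ts, 1)
        y, l = t(y)      # transformed value and this step's log|det J|
        logjac += l
    end
    return (y, logjac)
end

# Two toy steps, written for this sketch only.
scale2(x) = (2x, log(2.0))     # d/dx (2x) = 2
exp_pair(x) = (exp(x), x)      # d/dx exp(x) = exp(x), so log|det J| = x

y, logjac = compose_with_logjac((scale2, exp_pair), 0.5)
```

By the chain rule, the total is the sum of each step's term evaluated at that step's input, which is why the loop feeds `y` forward before adding `l`.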
5 changes: 4 additions & 1 deletion src/Bijectors.jl
@@ -35,8 +35,9 @@ using MappedArrays
using Base.Iterators: drop
using LinearAlgebra: AbstractTriangular

import ChangesOfVariables: with_logabsdet_jacobian

import ChainRulesCore
import ChangesOfVariables
import Functors
import InverseFunctions
import IrrationalConstants
@@ -251,6 +252,8 @@ include("utils.jl")
include("interface.jl")
include("chainrules.jl")

Base.@deprecate forward(b::AbstractBijector, x) with_logabsdet_jacobian(b, x)

# Broadcasting here breaks Tracker for some reason
maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...)
maporbroadcast(f, x::AbstractArray...) = f.(x...)
20 changes: 10 additions & 10 deletions src/bijectors/composed.jl
@@ -179,8 +179,8 @@ function logabsdetjac(cb::Composed, x)
y, logjac = forward(cb.ts[1], x)
for i = 2:length(cb.ts)
res = forward(cb.ts[i], y)
y = res.rv
logjac += res.logabsdetjac
y = res[1]
logjac += res[2]
end

return logjac
@@ -195,8 +195,8 @@ end
for i = 2:N - 1
temp = gensym(:res)
push!(expr.args, :($temp = forward(cb.ts[$i], y)))
push!(expr.args, :(y = $temp.rv))
push!(expr.args, :(logjac += $temp.logabsdetjac))
push!(expr.args, :(y = $temp[1]))
push!(expr.args, :(logjac += $temp[2]))
end
# don't need to evaluate the last bijector, only its `logabsdetjac`
push!(expr.args, :(logjac += logabsdetjac(cb.ts[$N], y)))
@@ -212,10 +212,10 @@ function forward(cb::Composed, x)

for t in cb.ts[2:end]
res = forward(t, rv)
rv = res.rv
logjac = res.logabsdetjac + logjac
rv = res[1]
logjac = res[2] + logjac
end
return (rv=rv, logabsdetjac=logjac)
return (rv, logjac)
end


@@ -225,10 +225,10 @@ end
for i = 2:length(T.parameters)
temp = gensym(:temp)
push!(expr.args, :($temp = forward(cb.ts[$i], y)))
push!(expr.args, :(y = $temp.rv))
push!(expr.args, :(logjac += $temp.logabsdetjac))
push!(expr.args, :(y = $temp[1]))
push!(expr.args, :(logjac += $temp[2]))
end
push!(expr.args, :(return (rv = y, logabsdetjac = logjac)))
push!(expr.args, :(return (y, logjac)))

return expr
end
14 changes: 7 additions & 7 deletions src/bijectors/leaky_relu.jl
@@ -44,21 +44,21 @@ end
logabsdetjac(b::LeakyReLU{<:Real, 0}, x::AbstractVector{<:Real}) = map(x -> logabsdetjac(b, x), x)


# We implement `forward` by hand since we can re-use the computation of
# We implement `with_logabsdet_jacobian` by hand since we can re-use the computation of
# the Jacobian of the transformation. This will lead to faster sampling
# when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`.
function forward(b::LeakyReLU{<:Any, 0}, x::Real)
function with_logabsdet_jacobian(b::LeakyReLU{<:Any, 0}, x::Real)
mask = x < zero(x)
J = mask * b.α + !mask * one(x)
return (rv=J * x, logabsdetjac=log(abs(J)))
return (J * x, log(abs(J)))
end

# Batched version
function forward(b::LeakyReLU{<:Any, 0}, x::AbstractVector)
function with_logabsdet_jacobian(b::LeakyReLU{<:Any, 0}, x::AbstractVector)
J = let T = eltype(x), z = zero(T), o = one(T)
@. (x < z) * b.α + (x > z) * o
end
return (rv=J .* x, logabsdetjac=log.(abs.(J)))
return (J .* x, log.(abs.(J)))
end

# (N=1) Multivariate case
@@ -84,7 +84,7 @@ end
# We implement `forward` by hand since we can re-use the computation of
# the Jacobian of the transformation. This will lead to faster sampling
# when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`.
function forward(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat)
function with_logabsdet_jacobian(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat)
# Is really diagonal of jacobian
J = let T = eltype(x), z = zero(T), o = one(T)
@. (x < z) * b.α + (x > z) * o
@@ -97,5 +97,5 @@ function forward(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat)
end

y = J .* x
return (rv=y, logabsdetjac=logjac)
return (y, logjac)
end
22 changes: 11 additions & 11 deletions src/bijectors/named_bijector.jl
@@ -1,6 +1,6 @@
abstract type AbstractNamedBijector <: AbstractBijector end

forward(b::AbstractNamedBijector, x) = (rv = b(x), logabsdetjac = logabsdetjac(b, x))
with_logabsdet_jacobian(b::AbstractNamedBijector, x) = (b(x), logabsdetjac(b, x))

#######################
### `NamedBijector` ###
@@ -125,8 +125,8 @@ function logabsdetjac(cb::NamedComposition, x)
y, logjac = forward(cb.bs[1], x)
for i = 2:length(cb.bs)
res = forward(cb.bs[i], y)
y = res.rv
logjac += res.logabsdetjac
y = res[1]
logjac += res[2]
end

return logjac
@@ -141,8 +141,8 @@ end
for i = 2:N - 1
temp = gensym(:res)
push!(expr.args, :($temp = forward(cb.bs[$i], y)))
push!(expr.args, :(y = $temp.rv))
push!(expr.args, :(logjac += $temp.logabsdetjac))
push!(expr.args, :(y = $temp[1]))
push!(expr.args, :(logjac += $temp[2]))
end
# don't need to evaluate the last bijector, only its `logabsdetjac`
push!(expr.args, :(logjac += logabsdetjac(cb.bs[$N], y)))
@@ -158,10 +158,10 @@ function forward(cb::NamedComposition, x)

for t in cb.bs[2:end]
res = forward(t, rv)
rv = res.rv
logjac = res.logabsdetjac + logjac
rv = res[1]
logjac = res[2] + logjac
end
return (rv=rv, logabsdetjac=logjac)
return (rv, logjac)
end


@@ -171,10 +171,10 @@ end
for i = 2:length(T.parameters)
temp = gensym(:temp)
push!(expr.args, :($temp = forward(cb.bs[$i], y)))
push!(expr.args, :(y = $temp.rv))
push!(expr.args, :(logjac += $temp.logabsdetjac))
push!(expr.args, :(y = $temp[1]))
push!(expr.args, :(logjac += $temp[2]))
end
push!(expr.args, :(return (rv = y, logabsdetjac = logjac)))
push!(expr.args, :(return (y, logjac)))

return expr
end
12 changes: 6 additions & 6 deletions src/bijectors/normalise.jl
@@ -48,7 +48,7 @@ function Functors.functor(::Type{<:InvertibleBatchNorm}, x)
return (b = x.b, logs = x.logs), reconstruct_invertiblebatchnorm
end

function forward(bn::InvertibleBatchNorm, x)
function with_logabsdet_jacobian(bn::InvertibleBatchNorm, x)
dims = ndims(x)
size(x, dims - 1) == length(bn.b) ||
error("InvertibleBatchNorm expected $(length(bn.b)) channels, got $(size(x, dims - 1))")
@@ -76,12 +76,12 @@ function forward(bn::InvertibleBatchNorm, x)
logabsdetjac = (
fill(sum(logs - log.(v .+ bn.eps) / 2), size(x, dims))
)
return (rv=rv, logabsdetjac=logabsdetjac)
return (rv, logabsdetjac)
end

logabsdetjac(bn::InvertibleBatchNorm, x) = forward(bn, x).logabsdetjac
logabsdetjac(bn::InvertibleBatchNorm, x) = with_logabsdet_jacobian(bn, x)[2]

(bn::InvertibleBatchNorm)(x) = forward(bn, x).rv
(bn::InvertibleBatchNorm)(x) = with_logabsdet_jacobian(bn, x)[1]

function forward(invbn::Inverse{<:InvertibleBatchNorm}, y)
@assert !istraining() "`forward(::Inverse{InvertibleBatchNorm})` is only available in test mode."
@@ -94,10 +94,10 @@ function forward(invbn::Inverse{<:InvertibleBatchNorm}, y)
v = reshape(bn.v, as...)

x = (y .- b) ./ s .* sqrt.(v .+ bn.eps) .+ m
return (rv=x, logabsdetjac=-logabsdetjac(bn, x))
return (x, -logabsdetjac(bn, x))
end

(bn::Inverse{<:InvertibleBatchNorm})(y) = forward(bn, y).rv
(bn::Inverse{<:InvertibleBatchNorm})(y) = with_logabsdet_jacobian(bn, y)[1]

function Base.show(io::IO, l::InvertibleBatchNorm)
print(io, "InvertibleBatchNorm($(join(size(l.b), ", ")))")
4 changes: 2 additions & 2 deletions src/bijectors/planar_layer.jl
@@ -101,7 +101,7 @@ function forward(flow::PlanarLayer, z::AbstractVecOrMat{<:Real})
b = first(flow.b)
log_det_jacobian = log1p.(wT_û .* abs2.(sech.(_vec(wT_z) .+ b)))

return (rv = transformed, logabsdetjac = log_det_jacobian)
return (transformed, log_det_jacobian)
end

function (ib::Inverse{<:PlanarLayer})(y::AbstractVecOrMat{<:Real})
@@ -175,5 +175,5 @@ function find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:Real}
return α0
end

logabsdetjac(flow::PlanarLayer, x) = forward(flow, x).logabsdetjac
logabsdetjac(flow::PlanarLayer, x) = forward(flow, x)[2]
isclosedform(b::Inverse{<:PlanarLayer}) = false
4 changes: 2 additions & 2 deletions src/bijectors/radial_layer.jl
@@ -63,7 +63,7 @@ function forward(flow::RadialLayer, z::AbstractVecOrMat)
(d - 1) * log(1 + β_hat * h_)
+ log(1 + β_hat * h_ + β_hat * (- h_ ^ 2) * r)
) # from eq(14)
return (rv = transformed, logabsdetjac = log_det_jacobian)
return (transformed, log_det_jacobian)
end

function (ib::Inverse{<:RadialLayer})(y::AbstractVector{<:Real})
@@ -123,4 +123,4 @@ function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat)
return r
end

logabsdetjac(flow::RadialLayer, x::AbstractVecOrMat) = forward(flow, x).logabsdetjac
logabsdetjac(flow::RadialLayer, x::AbstractVecOrMat) = forward(flow, x)[2]
6 changes: 3 additions & 3 deletions src/bijectors/rational_quadratic_spline.jl
@@ -343,7 +343,7 @@ function rqs_forward(
T = promote_type(eltype(widths), eltype(heights), eltype(derivatives), eltype(x))

if (x ≤ -widths[end]) || (x ≥ widths[end])
return (rv = one(T) * x, logabsdetjac = zero(T) * x)
return (one(T) * x, zero(T) * x)
end

# Find which bin `x` is in
@@ -376,9 +376,9 @@ function rqs_forward(
numerator_y = Δy * (s * ξ^2 + d_k * ξ * (1 - ξ))
y = h_k + numerator_y / denominator

return (rv = y, logabsdetjac = logjac)
return (y, logjac)
end

function forward(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::Real)
function with_logabsdet_jacobian(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::Real)
return rqs_forward(b.widths, b.heights, b.derivatives, x)
end
6 changes: 3 additions & 3 deletions src/bijectors/stacked.jl
@@ -136,7 +136,7 @@ end
# logjac = sum(_logjac)
# (y_2, _logjac) = forward(b.bs[2], x[b.ranges[2]])
# logjac += sum(_logjac)
# return (rv = vcat(y_1, y_2), logabsdetjac = logjac)
# return (vcat(y_1, y_2), logjac)
# end
@generated function forward(b::Stacked{<:Tuple{Vararg{<:Any, N}}, <:Tuple{Vararg{<:Any, N}}}, x::AbstractVector) where {N}
expr = Expr(:block)
@@ -156,7 +156,7 @@ end
push!(y_names, y_name)
end

push!(expr.args, :(return (rv = vcat($(y_names...)), logabsdetjac = logjac)))
push!(expr.args, :(return (vcat($(y_names...)), logjac)))
return expr
end

@@ -169,5 +169,5 @@ function forward(sb::Stacked, x::AbstractVector)
logjac += sum(l)
y
end
return (rv = vcat(yinit, ys), logabsdetjac = logjac)
return (vcat(yinit, ys), logjac)
end
6 changes: 3 additions & 3 deletions src/interface.jl
@@ -89,17 +89,17 @@ Default implementation for `Inverse{<:Bijector}` is implemented as
logabsdetjac(ib::Inverse{<:Bijector}, y) = - logabsdetjac(ib.orig, ib(y))

"""
forward(b::Bijector, x)
with_logabsdet_jacobian(b::Bijector, x)
Computes both `transform` and `logabsdetjac` in one forward pass, and
returns a named tuple `(rv=b(x), logabsdetjac=logabsdetjac(b, x))`.
returns a tuple `(b(x), logabsdetjac(b, x))`.
This defaults to the call above, but often one can re-use computation
in the computation of the forward pass and the computation of the
`logabsdetjac`. `forward` allows the user to take advantage of such
efficiencies, if they exist.
"""
forward(b::Bijector, x) = (rv=b(x), logabsdetjac=logabsdetjac(b, x))
with_logabsdet_jacobian(b::Bijector, x) = (b(x), logabsdetjac(b, x))

"""
logabsdetjacinv(b::Bijector, y)

0 comments on commit 000c83f