Skip to content

Commit

Permalink
Replace Base.inv with InverseFunctions.inverse
Browse files Browse the repository at this point in the history
  • Loading branch information
oschulz committed Dec 10, 2021
1 parent 000c83f commit 0c1bf48
Show file tree
Hide file tree
Showing 25 changed files with 168 additions and 165 deletions.
38 changes: 19 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ The following table lists mathematical operations for a bijector and the corresp

| Operation | Method | Automatic |
|:------------------------------------:|:-----------------:|:-----------:|
| `b ↦ b⁻¹` | `inv(b)` ||
| `b ↦ b⁻¹` | `inverse(b)` ||
| `(b₁, b₂) ↦ (b₁ ∘ b₂)` | `b₁ ∘ b₂` ||
| `(b₁, b₂) ↦ [b₁, b₂]` | `stack(b₁, b₂)` ||
| `x ↦ b(x)` | `b(x)` | × |
| `y ↦ b⁻¹(y)` | `inv(b)(y)` | × |
| `y ↦ b⁻¹(y)` | `inverse(b)(y)` | × |
| `x ↦ log|det J(b, x)|` | `logabsdetjac(b, x)` | AD |
| `x ↦ b(x), log|det J(b, x)|` | `with_logabsdet_jacobian(b, x)` ||
| `p ↦ q := b_* p` | `q = transformed(p, b)` ||
Expand Down Expand Up @@ -123,7 +123,7 @@ true
What about `invlink`?

```julia
julia> b⁻¹ = inv(b)
julia> b⁻¹ = inverse(b)
Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))

julia> b⁻¹(y)
Expand All @@ -133,7 +133,7 @@ julia> b⁻¹(y) ≈ invlink(dist, y)
true
```

Pretty neat, huh? `Inverse{Logit}` is also a `Bijector` where we've defined `(ib::Inverse{<:Logit})(y)` as the inverse transformation of `(b::Logit)(x)`. Note that it's not always the case that `inv(b) isa Inverse`, e.g. the inverse of `Exp` is simply `Log` so `inv(Exp()) isa Log` is true.
Pretty neat, huh? `Inverse{Logit}` is also a `Bijector` where we've defined `(ib::Inverse{<:Logit})(y)` as the inverse transformation of `(b::Logit)(x)`. Note that it's not always the case that `inverse(b) isa Inverse`, e.g. the inverse of `Exp` is simply `Log` so `inverse(Exp()) isa Log` is true.

#### Dimensionality
One more thing. See the `0` in `Inverse{Logit{Float64}, 0}`? It represents the *dimensionality* of the bijector, in the same sense as for an `AbstractArray` with the exception of `0` which means it expects 0-dim input and output, i.e. `<:Real`. This can also be accessed through `dimension(b)`:
Expand Down Expand Up @@ -162,7 +162,7 @@ true
And since `Composed isa Bijector`:

```julia
julia> id_x = inv(id_y)
julia> id_x = inverse(id_y)
Composed{Tuple{Inverse{Logit{Float64},0},Logit{Float64}},0}((Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0)), Logit{Float64}(0.0, 1.0)))

julia> id_x(x) x
Expand Down Expand Up @@ -201,7 +201,7 @@ julia> logpdf_forward(td, x)

#### `logabsdetjac` and `forward`

In the computation of both `logpdf` and `logpdf_forward` we need to compute `log(abs(det(jacobian(inv(b), y))))` and `log(abs(det(jacobian(b, x))))`, respectively. This computation is available using the `logabsdetjac` method
In the computation of both `logpdf` and `logpdf_forward` we need to compute `log(abs(det(jacobian(inverse(b), y))))` and `log(abs(det(jacobian(b, x))))`, respectively. This computation is available using the `logabsdetjac` method

```julia
julia> logabsdetjac(b⁻¹, y)
Expand All @@ -228,7 +228,7 @@ julia> with_logabsdet_jacobian(b, x)
Similarily

```julia
julia> forward(inv(b), y)
julia> forward(inverse(b), y)
(0.3688868996596376, -1.4575353795716655)
```

Expand All @@ -241,7 +241,7 @@ At this point we've only shown that we can replicate the existing functionality.
julia> y = rand(td) # ∈ ℝ
0.999166054552483

julia> x = inv(td.transform)(y) # transform back to interval [0, 1]
julia> x = inverse(td.transform)(y) # transform back to interval [0, 1]
0.7308945834125756
```

Expand All @@ -261,7 +261,7 @@ Beta{Float64}(α=2.0, β=2.0)
julia> b = bijector(dist) # (0, 1) → ℝ
Logit{Float64}(0.0, 1.0)

julia> b⁻¹ = inv(b) # ℝ → (0, 1)
julia> b⁻¹ = inverse(b) # ℝ → (0, 1)
Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))

julia> td = transformed(Normal(), b⁻¹) # x ∼ 𝓝(0, 1) then b(x) ∈ (0, 1)
Expand All @@ -280,7 +280,7 @@ It's worth noting that `support(Beta)` is the _closed_ interval `[0, 1]`, while
```julia
td = transformed(Beta())

inv(td.transform)(rand(td))
inverse(td.transform)(rand(td))
```

will never result in `0` or `1` though any sample arbitrarily close to either `0` or `1` is possible. _Disclaimer: numerical accuracy is limited, so you might still see `0` and `1` if you're lucky._
Expand Down Expand Up @@ -335,7 +335,7 @@ julia> # Construct the transform
bs = bijector.(dists) # constrained-to-unconstrained bijectors for dists
(Logit{Float64}(0.0, 1.0), Log{0}(), SimplexBijector{true}())

julia> ibs = inv.(bs) # invert, so we get unconstrained-to-constrained
julia> ibs = inverse.(bs) # invert, so we get unconstrained-to-constrained
(Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0)), Exp{0}(), Inverse{SimplexBijector{true},1}(SimplexBijector{true}()))

julia> sb = Stacked(ibs, ranges) # => Stacked <: Bijector
Expand Down Expand Up @@ -411,7 +411,7 @@ Similarily to the multivariate ADVI example, we could use `Stacked` to get a _bo
```julia
julia> d = MvNormal(zeros(2), ones(2));

julia> ibs = inv.(bijector.((InverseGamma(2, 3), Beta())));
julia> ibs = inverse.(bijector.((InverseGamma(2, 3), Beta())));

julia> sb = stack(ibs...) # == Stacked(ibs) == Stacked(ibs, [i:i for i = 1:length(ibs)]
Stacked{Tuple{Exp{0},Inverse{Logit{Float64},0}},2}((Exp{0}(), Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))), (1:1, 2:2))
Expand Down Expand Up @@ -542,15 +542,15 @@ Logit{Float64}(0.0, 1.0)
julia> b(0.6)
0.4054651081081642

julia> inv(b)(y)
julia> inverse(b)(y)
Tracked 2-element Array{Float64,1}:
0.3078149833748082
0.72380041667891

julia> logabsdetjac(b, 0.6)
1.4271163556401458

julia> logabsdetjac(inv(b), y) # defaults to `- logabsdetjac(b, inv(b)(x))`
julia> logabsdetjac(inverse(b), y) # defaults to `- logabsdetjac(b, inverse(b)(x))`
Tracked 2-element Array{Float64,1}:
-1.546158373866469
-1.6098711387913573
Expand Down Expand Up @@ -614,10 +614,10 @@ julia> logabsdetjac(b_ad, 0.6)
julia> y = b_ad(0.6)
0.4054651081081642

julia> inv(b_ad)(y)
julia> inverse(b_ad)(y)
0.6

julia> logabsdetjac(inv(b_ad), y)
julia> logabsdetjac(inverse(b_ad), y)
-1.4271163556401458
```
Expand Down Expand Up @@ -666,7 +666,7 @@ help?> Bijectors.Composed

A Bijector representing composition of bijectors. composel and composer results in a Composed for which application occurs from left-to-right and right-to-left, respectively.

Note that all the alternative ways of constructing a Composed returns a Tuple of bijectors. This ensures type-stability of implementations of all relating methdos, e.g. inv.
Note that all the alternative ways of constructing a Composed returns a Tuple of bijectors. This ensures type-stability of implementations of all relating methods, e.g. inverse.

If you want to use an Array as the container instead you can do

Expand Down Expand Up @@ -714,7 +714,7 @@ The distribution interface consists of:
#### Methods
The following methods are implemented by all subtypes of `Bijector`, this also includes bijectors such as `Composed`.
- `(b::Bijector)(x)`: implements the transform of the `Bijector`
- `inv(b::Bijector)`: returns the inverse of `b`, i.e. `ib::Bijector` s.t. `(ib ∘ b)(x) ≈ x`. In most cases this is `Inverse{<:Bijector}`.
- `inverse(b::Bijector)`: returns the inverse of `b`, i.e. `ib::Bijector` s.t. `(ib ∘ b)(x) ≈ x`. In most cases this is `Inverse{<:Bijector}`.
- `logabsdetjac(b::Bijector, x)`: computes log(abs(det(jacobian(b, x)))).
- `with_logabsdet_jacobian(b::Bijector, x)`: returns named tuple `(b(x), logabsdetjac(b, x))` in the most efficient manner.
- `∘`, `composel`, `composer`: convenient and type-safe constructors for `Composed`. `composel(bs...)` composes s.t. the resulting composition is evaluated left-to-right, while `composer(bs...)` is evaluated right-to-left. `∘` is right-to-left, as excepted from standard mathematical notation.
Expand All @@ -726,7 +726,7 @@ For `TransformedDistribution`, together with default implementations for `Distri
- `bijector(d::Distribution)`: returns the default constrained-to-unconstrained bijector for `d`
- `transformed(d::Distribution)`, `transformed(d::Distribution, b::Bijector)`: constructs a `TransformedDistribution` from `d` and `b`.
- `logpdf_forward(d::Distribution, x)`, `logpdf_forward(d::Distribution, x, logjac)`: computes the `logpdf(td, td.transform(x))` using the forward pass, which is potentially faster depending on the transform at hand.
- `forward(d::Distribution)`: returns `(x = rand(dist), y = b(x), logabsdetjac = logabsdetjac(b, x), logpdf = logpdf_forward(td, x))` where `b = td.transform`. This combines sampling from base distribution and transforming into one function. The intention is that this entire process should be performed in the most efficient manner, e.g. the `logabsdetjac(b, x)` call might instead be implemented as `- logabsdetjac(inv(b), b(x))` depending on which is most efficient.
- `forward(d::Distribution)`: returns `(x = rand(dist), y = b(x), logabsdetjac = logabsdetjac(b, x), logpdf = logpdf_forward(td, x))` where `b = td.transform`. This combines sampling from base distribution and transforming into one function. The intention is that this entire process should be performed in the most efficient manner, e.g. the `logabsdetjac(b, x)` call might instead be implemented as `- logabsdetjac(inverse(b), b(x))` depending on which is most efficient.
# Bibliography
1. Rezende, D. J., & Mohamed, S. (2015). Variational Inference With Normalizing Flows. [arXiv:1505.05770](https://arxiv.org/abs/1505.05770v6).
Expand Down
11 changes: 7 additions & 4 deletions src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ using Base.Iterators: drop
using LinearAlgebra: AbstractTriangular

import ChangesOfVariables: with_logabsdet_jacobian
import InverseFunctions: inverse

import ChainRulesCore
import Functors
import InverseFunctions
import IrrationalConstants
import LogExpFunctions
import Roots
Expand Down Expand Up @@ -124,7 +124,7 @@ end
# Distributions

link(d::Distribution, x) = bijector(d)(x)
invlink(d::Distribution, y) = inv(bijector(d))(y)
invlink(d::Distribution, y) = inverse(bijector(d))(y)
function logpdf_with_trans(d::Distribution, x, transform::Bool)
if ispd(d)
return pd_logpdf_with_trans(d, x, transform)
Expand Down Expand Up @@ -191,14 +191,14 @@ function invlink(
y::AbstractVecOrMat{<:Real},
::Val{proj}=Val(true),
) where {proj}
return inv(SimplexBijector{proj}())(y)
return inverse(SimplexBijector{proj}())(y)
end
function invlink_jacobian(
d::Dirichlet,
y::AbstractVector{<:Real},
::Val{proj}=Val(true),
) where {proj}
return jacobian(inv(SimplexBijector{proj}()), y)
return jacobian(inverse(SimplexBijector{proj}()), y)
end

## Matrix
Expand Down Expand Up @@ -254,6 +254,9 @@ include("chainrules.jl")

Base.@deprecate forward(b::AbstractBijector, x) with_logabsdet_jacobian(b, x)

import Base.inv
Base.@deprecate inv(b::AbstractBijector) inverse(b)

# Broadcasting here breaks Tracker for some reason
maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...)
maporbroadcast(f, x::AbstractArray...) = f.(x...)
Expand Down
6 changes: 3 additions & 3 deletions src/bijectors/composed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ A `Bijector` representing composition of bijectors. `composel` and `composer` re
`Composed` for which application occurs from left-to-right and right-to-left, respectively.
Note that all the alternative ways of constructing a `Composed` returns a `Tuple` of bijectors.
This ensures type-stability of implementations of all relating methdos, e.g. `inv`.
This ensures type-stability of implementations of all relating methdos, e.g. `inverse`.
If you want to use an `Array` as the container instead you can do
Expand All @@ -41,7 +41,7 @@ Composed{Tuple{Exp{0},Exp{0}},0}((Exp{0}(), Exp{0}()))
julia> (b ∘ b)(1.0) == exp(exp(1.0)) # evaluation
true
julia> inv(b ∘ b)(exp(exp(1.0))) == 1.0 # inversion
julia> inverse(b ∘ b)(exp(exp(1.0))) == 1.0 # inversion
true
julia> logabsdetjac(b ∘ b, 1.0) # determinant of jacobian
Expand Down Expand Up @@ -153,7 +153,7 @@ end
(::Identity{N}, b::Bijector{N}) where {N} = b
(b::Bijector{N}, ::Identity{N}) where {N} = b

inv(ct::Composed) = Composed(reverse(map(inv, ct.ts)))
inverse(ct::Composed) = Composed(reverse(map(inv, ct.ts)))

# # TODO: should arrays also be using recursive implementation instead?
function (cb::Composed{<:AbstractArray{<:Bijector}})(x)
Expand Down
2 changes: 1 addition & 1 deletion src/bijectors/corr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real})
`logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})`
if possible.
=#
return -logabsdetjac(inv(b), (b(X)))
return -logabsdetjac(inverse(b), (b(X)))
end
function logabsdetjac(b::CorrBijector, X::AbstractArray{<:AbstractMatrix{<:Real}})
return mapvcat(X) do x
Expand Down
4 changes: 2 additions & 2 deletions src/bijectors/coupling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ julia> cl(x)
2.0
3.0
julia> inv(cl)(cl(x))
julia> inverse(cl)(cl(x))
3-element Array{Float64,1}:
1.0
2.0
Expand Down Expand Up @@ -214,7 +214,7 @@ function (icl::Inverse{<:Coupling})(y::AbstractVector)
y_1, y_2, y_3 = partition(cl.mask, y)

b = cl.θ(y_2)
ib = inv(b)
ib = inverse(b)

return combine(cl.mask, ib(y_1), y_2, y_3)
end
Expand Down
4 changes: 2 additions & 2 deletions src/bijectors/exp_log.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ Log() = Log{0}()
(b::Exp{2})(y::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, y)
(b::Log{2})(x::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, x)

inv(b::Exp{N}) where {N} = Log{N}()
inv(b::Log{N}) where {N} = Exp{N}()
inverse(b::Exp{N}) where {N} = Log{N}()
inverse(b::Log{N}) where {N} = Exp{N}()

logabsdetjac(b::Exp{0}, x::Real) = x
logabsdetjac(b::Exp{0}, x::AbstractVector) = x
Expand Down
2 changes: 1 addition & 1 deletion src/bijectors/leaky_relu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ function (b::LeakyReLU{<:Any, 0})(x::Real)
end
(b::LeakyReLU{<:Any, 0})(x::AbstractVector{<:Real}) = map(b, x)

function Base.inv(b::LeakyReLU{<:Any,N}) where N
function inverse(b::LeakyReLU{<:Any,N}) where N
invα = inv.(b.α)
return LeakyReLU{typeof(invα),N}(invα)
end
Expand Down
14 changes: 7 additions & 7 deletions src/bijectors/named_bijector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ names_to_bijectors(b::NamedBijector) = b.bs
return :($(exprs...), )
end

@generated function Base.inv(b::NamedBijector{names}) where {names}
return :(NamedBijector(($([:($n = inv(b.bs.$n)) for n in names]...), )))
@generated function inverse(b::NamedBijector{names}) where {names}
return :(NamedBijector(($([:($n = inverse(b.bs.$n)) for n in names]...), )))
end

@generated function logabsdetjac(b::NamedBijector{names}, x::NamedTuple) where {names}
Expand All @@ -78,10 +78,10 @@ See also: [`Inverse`](@ref)
struct NamedInverse{B<:AbstractNamedBijector} <: AbstractNamedBijector
orig::B
end
Base.inv(nb::AbstractNamedBijector) = NamedInverse(nb)
Base.inv(ni::NamedInverse) = ni.orig
inverse(nb::AbstractNamedBijector) = NamedInverse(nb)
inverse(ni::NamedInverse) = ni.orig

logabsdetjac(ni::NamedInverse, y::NamedTuple) = -logabsdetjac(inv(ni), ni(y))
logabsdetjac(ni::NamedInverse, y::NamedTuple) = -logabsdetjac(inverse(ni), ni(y))

##########################
### `NamedComposition` ###
Expand All @@ -107,7 +107,7 @@ composel(bs::AbstractNamedBijector...) = NamedComposition(bs)
composer(bs::AbstractNamedBijector...) = NamedComposition(reverse(bs))
(b1::AbstractNamedBijector, b2::AbstractNamedBijector) = composel(b2, b1)

inv(ct::NamedComposition) = NamedComposition(reverse(map(inv, ct.bs)))
inverse(ct::NamedComposition) = NamedComposition(reverse(map(inv, ct.bs)))

function (cb::NamedComposition{<:AbstractArray{<:AbstractNamedBijector}})(x)
@assert length(cb.bs) > 0
Expand Down Expand Up @@ -232,7 +232,7 @@ end
) where {target, deps, F}
return quote
b = ni.orig.f($([:(x.$d) for d in deps]...))
return merge(x, ($target = inv(b)(x.$target), ))
return merge(x, ($target = inverse(b)(x.$target), ))
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/bijectors/normalise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ function forward(invbn::Inverse{<:InvertibleBatchNorm}, y)
@assert !istraining() "`forward(::Inverse{InvertibleBatchNorm})` is only available in test mode."
dims = ndims(y)
as = ntuple(i -> i == ndims(y) - 1 ? size(y, i) : 1, dims)
bn = inv(invbn)
bn = inverse(invbn)
s = reshape(exp.(bn.logs), as...)
b = reshape(bn.b, as...)
m = reshape(bn.m, as...)
Expand Down
6 changes: 3 additions & 3 deletions src/bijectors/permute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ julia> b4([1., 2., 3.])
1.0
3.0
julia> inv(b1)
julia> inverse(b1)
Permute{LinearAlgebra.Transpose{Int64,Array{Int64,2}}}([0 1 0; 1 0 0; 0 0 1])
julia> inv(b1)(b1([1., 2., 3.]))
julia> inverse(b1)(b1([1., 2., 3.]))
3-element Array{Float64,1}:
1.0
2.0
Expand Down Expand Up @@ -151,7 +151,7 @@ end


@inline (b::Permute)(x::AbstractVecOrMat) = b.A * x
@inline inv(b::Permute) = Permute(transpose(b.A))
@inline inverse(b::Permute) = Permute(transpose(b.A))

logabsdetjac(b::Permute, x::AbstractVector) = zero(eltype(x))
logabsdetjac(b::Permute, x::AbstractMatrix) = zero(eltype(x), size(x, 2))
2 changes: 1 addition & 1 deletion src/bijectors/shift.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ up1(b::Shift{T, N}) where {T, N} = Shift{T, N + 1}(b.a)
(b::Shift)(x) = b.a .+ x
(b::Shift{<:Any, 2})(x::AbstractArray{<:AbstractMatrix}) = map(b, x)

inv(b::Shift{T, N}) where {T, N} = Shift{T, N}(-b.a)
inverse(b::Shift{T, N}) where {T, N} = Shift{T, N}(-b.a)

# FIXME: implement custom adjoint to ensure we don't get tracking
logabsdetjac(b::Shift{T, N}, x) where {T, N} = _logabsdetjac_shift(b.a, x, Val(N))
Expand Down
4 changes: 2 additions & 2 deletions src/bijectors/simplex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,10 @@ function (ib::Inverse{<:SimplexBijector{1}})(
_simplex_inv_bijector!(X, Y, ib.orig)
end
function (ib::Inverse{<:SimplexBijector{2, proj}})(Y::AbstractMatrix) where {proj}
inv(SimplexBijector{1, proj}())(Y)
inverse(SimplexBijector{1, proj}())(Y)
end
function (ib::Inverse{<:SimplexBijector{2, proj}})(X::AbstractMatrix, Y::AbstractMatrix) where {proj}
inv(SimplexBijector{1, proj}())(X, Y)
inverse(SimplexBijector{1, proj}())(X, Y)
end
(ib::Inverse{<:SimplexBijector{2}})(Y::AbstractArray{<:AbstractMatrix}) = map(ib, Y)
function _simplex_inv_bijector(Y::AbstractMatrix, b::SimplexBijector{1})
Expand Down
8 changes: 4 additions & 4 deletions src/bijectors/stacked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ isclosedform(b::Stacked) = all(isclosedform, b.bs)

stack(bs::Bijector{0}...) = Stacked(bs)

# For some reason `inv.(sb.bs)` was unstable... This works though.
inv(sb::Stacked) = Stacked(map(inv, sb.bs), sb.ranges)
# For some reason `inverse.(sb.bs)` was unstable... This works though.
inverse(sb::Stacked) = Stacked(map(inverse, sb.bs), sb.ranges)
# map is not type stable for many stacked bijectors as a large tuple
# hence the generated function
@generated function inv(sb::Stacked{A}) where {A <: Tuple}
@generated function inverse(sb::Stacked{A}) where {A <: Tuple}
exprs = []
for i = 1:length(A.parameters)
push!(exprs, :(inv(sb.bs[$i])))
push!(exprs, :(inverse(sb.bs[$i])))
end
:(Stacked(($(exprs...), ), sb.ranges))
end
Expand Down
Loading

0 comments on commit 0c1bf48

Please sign in to comment.