Releases: google-research/ott
Introduce ability to force a jax.lax.scan in the sinkhorn iterations
This is proposed to handle bug #7. When setting min_iterations=max_iterations in the call to sinkhorn, the loop is now, by default, a jax.lax.scan.
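A minimal sketch of the behavior, assuming the sinkhorn API of this period (argument names may differ in later versions):

```python
import jax
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.core import sinkhorn

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (13, 2))
y = jax.random.normal(key_y, (17, 2))
geom = pointcloud.PointCloud(x, y, epsilon=0.1)

# With min_iterations == max_iterations, the number of iterations is fixed,
# so the loop can be a jax.lax.scan (fixed length, reverse-mode
# differentiable by unrolling) instead of a jax.lax.while_loop.
out = sinkhorn.sinkhorn(geom,
                        a=jnp.ones(13) / 13,
                        b=jnp.ones(17) / 17,
                        min_iterations=100,
                        max_iterations=100)
print(out.reg_ot_cost)
```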
Fine-grained tuning of implicit differentiation + tests for Hessians.
In this release:
Following issue #7, a few clarifications on computing the Hessian of reg_ot_cost, comparing its value between unrolling and implicit differentiation. This triggered more fine-tuning of the sinkhorn function, in particular:
- The reg-OT problem computes min_x F(x, y), where y stands for all inputs (geom, a, b) and x parameterizes the optimal transport (e.g. dual variables f, g). The solution can be written x*(y) = argmin_x F(x, y), and the reg-OT cost is therefore equal to G(y) := F(x*(y), y). The Danskin approximation consists in considering the local function H(y) := F(x*, y), where x* = x*(y) but is now seen as a constant, killing the dependency of x*(y) on y. The Danskin approximation is useful to compute a first-order derivative, since the envelope theorem states that, locally, ∇G = ∇H. However, that approximation is no longer valid to compute second-order derivatives. Since this is the goal of this new release, we now let the user decide whether Danskin is used or not (see the sketch after this list).
- To differentiate the vjp of sinkhorn further, some cleanup was needed in the way the implicit function theorem carries out the numerical stabilization of the linear solve.
- Hessians computed with both implicit differentiation and backprop are now tested.
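A minimal sketch of computing such a Hessian; the implicit_differentiation and use_danskin flags are assumptions based on the description above and may be named differently in the released API:

```python
import jax
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.core import sinkhorn

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (13, 2))
y = jax.random.normal(key_y, (17, 2))
b = jnp.ones(17) / 17

def reg_ot_cost(a):
  geom = pointcloud.PointCloud(x, y, epsilon=0.1)
  # use_danskin must be False for second-order derivatives: the envelope
  # theorem only justifies dropping d(x*)/dy for first-order gradients.
  out = sinkhorn.sinkhorn(geom, a=a, b=b,
                          implicit_differentiation=True,
                          use_danskin=False)
  return out.reg_ot_cost

hess = jax.hessian(reg_ot_cost)(jnp.ones(13) / 13)  # (13, 13) matrix
```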
Warm starts in GW
In this new release:
- option to use warm starts in GW. This is enabled by default and allows for faster convergence.
Fixes issues in evaluation of reg_ot_cost, implicit diff in inner iterations of GW; sort/ranks.
Fixes a major bug:
- In previous releases, the computation of the objective was done in kernel space (i.e. using u'Kv), regardless of whether the Sinkhorn iterations were carried out using logsumexp (lse_mode=True) or kernels. This was done in order to allow padding in weights without blowups when differentiating. The hope was that, since this evaluation happened at the very end, underflows would not prove problematic even with small epsilon. What was not anticipated were overflows.
- In the current release, the objective is now properly computed using logsumexp or kernel expressions, consistently up until the end. To remove the problem related to differentiation (more precisely, related to differentiating exp(log-of-sum-of-exp(log(0)))), a new implementation of lse (logsumexp) with custom differentiation rules has been provided; it is robust to all input entries being -inf and still returns a 0 gradient (a sketch is given after this list).
- added a num_targets option in sort and ranks.
- implicit differentiation of GW is now supported. This is a mixed implicit/backprop mode: backprop on the outer loop, implicit differentiation of the Sinkhorn iterations.
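A minimal sketch of an -inf robust logsumexp with a custom vjp, in the spirit of the fix described above (an illustration, not the exact ott implementation):

```python
import jax
import jax.numpy as jnp

def _lse(x):
  m = jnp.max(x)
  m = jnp.where(jnp.isfinite(m), m, 0.0)  # guard the all -inf case
  return m + jnp.log(jnp.sum(jnp.exp(x - m)))

@jax.custom_vjp
def logsumexp(x):
  return _lse(x)

def logsumexp_fwd(x):
  out = _lse(x)
  return out, (x, out)

def logsumexp_bwd(res, g):
  x, out = res
  # Softmax weights: -inf entries get a 0 weight, hence a 0 gradient, even
  # when *all* entries are -inf (where exp(x - out) would produce nan).
  w = jnp.where(jnp.isfinite(x), jnp.exp(x - out), 0.0)
  return (g * w,)

logsumexp.defvjp(logsumexp_fwd, logsumexp_bwd)

# Value is -inf, but the gradient is [0., 0.] rather than [nan, nan].
print(jax.grad(logsumexp)(jnp.array([-jnp.inf, -jnp.inf])))
```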
Stability of implicit diff across accelerators + Gromov Wasserstein
In this release:
- Gromov Wasserstein implementation, still in beta.
- additional regularizers to stabilize the implicit function differentiation and handle degeneracies, notably when running on GPU/TPU (a generic sketch follows).
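A generic illustration of the idea, not the ott internals: the vjp obtained through the implicit function theorem requires solving a linear system whose operator can be degenerate (dual potentials are only defined up to an additive constant), and a small ridge term restores definiteness, which matters in float32 on GPU/TPU:

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

def stabilized_solve(matvec, rhs, ridge=1e-6):
  # Solve (A + ridge * I) z = rhs with conjugate gradients, where A is only
  # available through matrix-vector products and may be singular.
  def reg_matvec(z):
    return matvec(z) + ridge * z
  z, _ = cg(reg_matvec, rhs)
  return z
```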
Faster implicit differentiation
- Faster implicit differentiation of optimal potentials (and therefore of OT outputs) using a more detailed computation of the Hessian of the objective, solving the linear system using a Schur complement (a generic sketch follows this list).
- Unbalanced, regularized, Bures distance.
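A generic sketch of the linear algebra, simplified and not the exact ott code: the system arising when differentiating optimal potentials implicitly has a 2x2 block structure whose diagonal blocks are cheap to invert, so a Schur complement reduces it to a single smaller solve:

```python
import jax.numpy as jnp

def schur_solve(d1, P, d2, p, q):
  # Solves [[diag(d1), P], [P.T, diag(d2)]] [u; v] = [p; q]
  # via the Schur complement of the diag(d1) block.
  Pd1 = P / d1[:, None]           # diag(d1)^{-1} P
  S = jnp.diag(d2) - P.T @ Pd1    # Schur complement, shape (m, m)
  v = jnp.linalg.solve(S, q - Pd1.T @ p)
  u = (p - P @ v) / d1            # back-substitution in the first block row
  return u, v
```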
Envelope theorem differentiation for reg_ot_cost, better 0 handling for weights a, b in gradients, new operators in soft_sort.
In this release:
- computation of gradients of reg_ot_cost simplified. By default, reg_ot_cost was evaluated as a function of the optimal dual potentials f, g. Differentiating reg_ot_cost w.r.t. geom, a or b therefore required a call to the custom vjp's of f and g (implicit_differentiation=True) or using backprop (implicit_differentiation=False). However, by virtue of the envelope theorem, reg_ot_cost (the quantity minimized by the sinkhorn algorithm) does not require differentiating w.r.t. the arg-minimizers (here the dual potentials f, g). As a result, computing gradients of reg_ot_cost w.r.t. parameters in geom, a and b should be much faster. Note that this result only makes sense numerically if the threshold used to stop the sinkhorn algorithm is small (see the first sketch after this list).
- handling of 0 values in the weight vectors a, b is now improved. When computing gradients of reg_ot_cost w.r.t., for instance, the locations x, y, zero weights in a and b no longer result in nan gradients.
- various novel soft sort/ranks/quantiles operators added in soft_sort, notably topk and the average of the topk rows of a matrix when sorted against a given criterion (a usage sketch closes this list).
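A minimal sketch combining the first two points: an envelope-theorem gradient of reg_ot_cost w.r.t. locations, with zero weights present in a (argument names follow the earlier releases and are assumptions):

```python
import jax
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.core import sinkhorn

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (13, 2))
y = jax.random.normal(key_y, (17, 2))
a = jnp.ones(13).at[:2].set(0.0)  # zero weights, e.g. to emulate padding
a = a / jnp.sum(a)
b = jnp.ones(17) / 17

def reg_ot_cost(x):
  geom = pointcloud.PointCloud(x, y, epsilon=0.1)
  # a small stopping threshold keeps the envelope-theorem gradient
  # numerically sound
  out = sinkhorn.sinkhorn(geom, a=a, b=b, threshold=1e-4)
  return out.reg_ot_cost

grad_x = jax.grad(reg_ot_cost)(x)  # finite, no nan from the zero weights
```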
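And a usage sketch for the soft operators; the function names in ott.tools.soft_sort below are assumptions and may differ:

```python
import jax.numpy as jnp
from ott.tools import soft_sort

z = jnp.array([3.0, 1.0, 4.0, 1.5, 2.0])
values = soft_sort.sort(z)   # differentiable proxy for the sorted values
r = soft_sort.ranks(z)       # differentiable proxy for ranks, close to [3, 0, 4, 1, 2]
```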
Changes in discrete_barycenter and other updates.
In this new release:
- apply_cost function, to apply a cost matrix to a vector/matrix; this is done on a budget whenever possible, notably for grids (a usage sketch follows this list).
- possibility to initialize the Wasserstein barycenter solver by passing dual variables.
- heuristic used by default to initialize Wasserstein barycenter problems.
- change in CostFn naming: dotprod -> pairwise.
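A minimal usage sketch for apply_cost, here on a point cloud for simplicity (the axis argument is an assumption):

```python
import jax
import jax.numpy as jnp
from ott.geometry import pointcloud

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (13, 2))
y = jax.random.normal(key_y, (17, 2))
geom = pointcloud.PointCloud(x, y)

vec = jnp.ones(13) / 13
# Applies the transpose of the 13 x 17 cost matrix to vec, i.e. C^T vec,
# without necessarily materializing C.
out = geom.apply_cost(vec, axis=0)
```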
1st stable release
First stable release of the toolbox, with docs.