This repository has been archived by the owner on Aug 31, 2022. It is now read-only.

Add various options to set epsilon in a more sensible range.
When no epsilon is passed, the default behaviour is to compute the mean of the costs contained in cost_matrix (this also works for grid and pointcloud geometries) and to use 1/20th of that value as the regularizer.

When epsilon is passed, Geometry uses that value directly, unless the user also sets relative_epsilon=True, in which case epsilon is interpreted as a fraction of the mean value observed in the cost matrix.
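The rule described above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions, not the code in `Geometry`: the helper name `resolve_epsilon` is hypothetical, and the cost matrix is a plain list of lists rather than a JAX array.

```python
from statistics import fmean


def resolve_epsilon(cost_matrix, epsilon=None, relative_epsilon=False):
    """Hypothetical helper mirroring the rule in the commit message."""
    # Mean of all entries of the cost matrix.
    mean_cost = fmean(c for row in cost_matrix for c in row)
    if epsilon is None:
        # No epsilon given: default to 1/20th of the mean cost.
        return mean_cost / 20.0
    if relative_epsilon:
        # Epsilon is understood as a fraction of the mean cost.
        return epsilon * mean_cost
    # Epsilon is used directly.
    return epsilon


cost = [[0.0, 2.0], [4.0, 2.0]]  # mean cost is 2.0
default_eps = resolve_epsilon(cost)
relative_eps = resolve_epsilon(cost, epsilon=0.5, relative_epsilon=True)
absolute_eps = resolve_epsilon(cost, epsilon=0.5)
```

With this toy matrix the three calls give a twentieth of the mean, half of the mean, and the value 0.5 unchanged, respectively.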

In the Sinkhorn divergence, this behaviour could lead to up to three different epsilon values when none is specified, because each of the three geometries involved has its own mean cost. To account for this, this change introduces a flag, True by default, that ensures the same epsilon scheduler is applied all three times.
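To see why three values can appear, here is a small sketch in plain Python. It assumes 1-D points with a squared Euclidean cost, and the flag name `share_epsilon` as well as the choice of the x-to-y value as the shared one are assumptions made for illustration; the commit message names neither.

```python
from statistics import fmean


def mean_cost(points_a, points_b):
    """Mean squared Euclidean cost between two lists of 1-D points."""
    return fmean((a - b) ** 2 for a in points_a for b in points_b)


def divergence_epsilons(x, y, share_epsilon=True):
    """Default epsilons for the three terms of a Sinkhorn divergence.

    `share_epsilon` is a hypothetical flag name. Reusing the x-to-y
    epsilon for all three terms is an assumption of this sketch.
    """
    pairs = {"xy": (x, y), "xx": (x, x), "yy": (y, y)}
    # Each term gets 1/20th of the mean of its own cost matrix.
    defaults = {k: mean_cost(p, q) / 20.0 for k, (p, q) in pairs.items()}
    if share_epsilon:
        return {k: defaults["xy"] for k in pairs}
    return defaults


x, y = [0.0, 1.0], [2.0, 4.0]
separate = divergence_epsilons(x, y, share_epsilon=False)
shared = divergence_epsilons(x, y)
```

Here `separate` holds three distinct values, one per cost matrix, while `shared` holds the same value three times.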

Issues raised by this CL:
- The case where epsilon is itself a scheduler might be mishandled in corner cases.
- In GW in particular, there might be issues in ensuring that more than just the epsilon value (i.e. the geometric decay parameters) is passed on.

PiperOrigin-RevId: 380205975
marcocuturi committed Jun 18, 2021
1 parent 6e2fa55 commit 152af17
Showing 12 changed files with 261 additions and 88 deletions.
4 changes: 2 additions & 2 deletions ott/core/gromov_wasserstein.py
@@ -364,8 +364,8 @@ def is_online(geom):
marginal_x, marginal_y, geom_x, geom_y, loss)
cost_matrix = marginal_dep_term - apply_cost_fn(geom_y)(
tmp.T, axis=1, fn=loss.right_y).T
- return geometry.Geometry(cost_matrix=cost_matrix,
- epsilon=geom._epsilon, **kwargs)
+ return geometry.Geometry(
+ cost_matrix=cost_matrix, epsilon=geom._epsilon_init, **kwargs)


def _marginal_dependent_cost(marginal_x, marginal_y, geom_x, geom_y, loss):
5 changes: 2 additions & 3 deletions ott/core/sinkhorn.py
@@ -95,10 +95,9 @@ def sinkhorn(
unbalanced optimal transport problem with a variable matrix `P` of size ``n``
x ``m``:
- :math:`\arg\min_{P>0} <P,C> -\epsilon H(P) + \rho_a \text{KL}(P1 | a) + \rho_b \text{KL}(P^T1 | b)`
+ :math:`\arg\min_{P>0} <P,C> -\epsilon \text{KL}(P | ab^T) + \rho_a \text{KL}(P1 | a) + \rho_b \text{KL}(P^T1 | b)`
- where :math:`H` is the Shannon entropy, and :math:`KL` is the generalized
- Kullback-Leibler divergence.
+ where :math:`KL` is the generalized Kullback-Leibler divergence.
The very same primal problem can also be written using a kernel :math:`K`
instead of a cost :math:`C` as well:
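The docstring change above swaps an entropy regularizer for a KL regularizer. The link between the two can be worked out as follows, assuming the conventions $H(P) = -\sum_{ij} P_{ij}(\log P_{ij} - 1)$ and $\text{KL}(P | Q) = \sum_{ij} P_{ij}\log(P_{ij}/Q_{ij}) - \sum_{ij} P_{ij} + \sum_{ij} Q_{ij}$. The commit does not state these conventions, so they are assumptions here.

```latex
\begin{aligned}
\text{KL}(P \,|\, ab^T)
  &= \sum_{ij} P_{ij}\log P_{ij} - \sum_{ij} P_{ij}\log(a_i b_j)
     - \sum_{ij} P_{ij} + \sum_{ij} a_i b_j \\
  &= -H(P) - \langle P\mathbf{1}, \log a\rangle
     - \langle P^T\mathbf{1}, \log b\rangle
     + \Big(\sum_i a_i\Big)\Big(\sum_j b_j\Big).
\end{aligned}
```

Under these conventions the term $-\epsilon H(P)$ matches $+\epsilon\,\text{KL}(P | ab^T)$ up to terms that depend on $P$ only through its marginals. Read this way, the minus sign in front of the KL term in the added docstring line looks like it may be a sign slip, although the commit itself does not say so.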
37 changes: 29 additions & 8 deletions ott/geometry/epsilon_scheduler.py
@@ -26,17 +26,38 @@ class Epsilon:
"""Scheduler class for the regularization parameter epsilon."""

def __init__(self,
- target: float = 1e-2,
- init: float = 1.0,
- decay: float = 1.0):
- self.target = target
- self._init = init
- self._decay = decay
+ target: Optional[float] = None,
+ scale: Optional[float] = None,
+ init: Optional[float] = None,
+ decay: Optional[float] = None):
+ """Initializes a scheduler using possibly geometric decay.
+ The entropic regularization value is given either directly or relative to
+ a scale. In that case, the initial ``target`` value is understood to be a
+ proportion of the ``scale``. Both are recorded and merged in the ``target``
+ field which is built from those two parameters.
+ Args:
+ target: the epsilon regularizing value that is targeted, understood
+ as a multiple of scale.
+ scale: scale to be used with target_init to define a target epsilon.
+ init: initial value when using epsilon scheduling, understood as a
+ a fraction of scale as well.
+ decay: geometric decay factor, smaller than 1.
+ """
+ self._target_init = .01 if target is None else target
+ self._scale = 1.0 if scale is None else scale
+ self._init = 1.0 if init is None else init
+ self._decay = 1.0 if decay is None else decay

+ @property
+ def target(self):
+ return self._target_init * self._scale

def at(self, iteration: Optional[int] = 1) -> float:
if iteration is None:
return self.target
- init = jnp.where(self._decay < 1.0, self._init, self.target)
+ init = jnp.where(self._decay < 1.0, self._init * self._scale, self.target)
decay = jnp.where(self._decay < 1.0, self._decay, 1.0)
return jnp.maximum(init * decay**iteration, self.target)

@@ -47,7 +68,7 @@ def done_at(self, iteration):
return self.done(self.at(iteration))

def tree_flatten(self):
- return (self.target, self._init, self._decay), None
+ return (self._target_init, self._scale, self._init, self._decay), None

@classmethod
def tree_unflatten(cls, aux_data, children):
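The effect of the new `scale` parameter on the schedule can be illustrated with a plain-Python sketch. The class name `EpsilonSketch` is hypothetical, and Python conditionals and `max` stand in for `jnp.where` and `jnp.maximum`, so this only approximates the behaviour of the real class and is not traceable by JAX.

```python
class EpsilonSketch:
    """Hypothetical plain-Python stand-in for the scheduler in the diff."""

    def __init__(self, target=None, scale=None, init=None, decay=None):
        self._target_init = 0.01 if target is None else target
        self._scale = 1.0 if scale is None else scale
        self._init = 1.0 if init is None else init
        self._decay = 1.0 if decay is None else decay

    @property
    def target(self):
        # The target is expressed as a multiple of the scale.
        return self._target_init * self._scale

    def at(self, iteration=1):
        if iteration is None:
            return self.target
        if self._decay < 1.0:
            # Geometric decay starts from init * scale.
            init, decay = self._init * self._scale, self._decay
        else:
            # No decay: stay at the target value.
            init, decay = self.target, 1.0
        return max(init * decay ** iteration, self.target)


sched = EpsilonSketch(target=0.05, scale=2.0, init=1.0, decay=0.5)
values = [sched.at(i) for i in range(1, 6)]
constant = EpsilonSketch(target=0.05, scale=2.0).at(10)
```

With these inputs the schedule starts from `init * scale`, halves at each iteration, and is clipped below at `target * scale`. Without decay, `at` returns the scaled target at every iteration.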