This repository has been archived by the owner on Aug 31, 2022. It is now read-only.

Specifying Geometry with kernel matrix and without epsilon causes RecursionError #18

Closed
michalk8 opened this issue Nov 18, 2021 · 1 comment


@michalk8

Code to reproduce:

from ott.geometry.geometry import Geometry
import jax.numpy as jnp

geom_e = Geometry(kernel_matrix=jnp.ones((10, 10)), epsilon=1e-2)
print(geom_e.cost_matrix)  # ok
geom = Geometry(kernel_matrix=jnp.ones((10, 10)))
print(geom.cost_matrix)  # raises the error below
---------------------------------------------------------------------------
RecursionError                            Traceback (most recent call last)
/tmp/ipykernel_246820/705571593.py in <module>
      2 print(geom_e.cost_matrix)
      3 geom = Geometry(kernel_matrix=jnp.ones((10, 10)))
----> 4 print(geom.cost_matrix)

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in cost_matrix(self)
    107   def cost_matrix(self):
    108     if self._cost_matrix is None:
--> 109       return -self.epsilon * jnp.log(self._kernel_matrix)
    110     return self._cost_matrix
    111 

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in epsilon(self)
    130   @property
    131   def epsilon(self):
--> 132     return self._epsilon.target
    133 
    134   @property

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in _epsilon(self)
    102       return self._epsilon_init
    103     eps = 5e-2 if self._epsilon_init is None else self._epsilon_init
--> 104     return epsilon_scheduler.Epsilon.make(eps, scale=self.scale, **self._kwargs)
    105 
    106   @property

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in scale(self)
     92     if (self._scale is None) and (trigger is not None):  # for dry run
     93       return jnp.where(
---> 94           trigger, jax.lax.stop_gradient(self.mean_cost_matrix), 1.0)
     95     else:
     96       return self._scale

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in mean_cost_matrix(self)
    116   @property
    117   def mean_cost_matrix(self):
--> 118     if isinstance(self.shape[0], int) and (self.shape[0] > 0):
    119       return jnp.sum(self.apply_cost(jnp.ones((self.shape[0],)))) / (
    120           self.shape[0] * self.shape[1])

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in shape(self)
    134   @property
    135   def shape(self):
--> 136     mat = self.kernel_matrix if self.cost_matrix is None else self.cost_matrix
    137     if mat is not None:
    138       return mat.shape

... last 6 frames repeated, from the frame below ...

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in cost_matrix(self)
    107   def cost_matrix(self):
    108     if self._cost_matrix is None:
--> 109       return -self.epsilon * jnp.log(self._kernel_matrix)
    110     return self._cost_matrix
    111 

RecursionError: maximum recursion depth exceeded while calling a Python object
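The six repeated frames form a cycle. `cost_matrix` needs `epsilon`. When no epsilon was passed, `epsilon` is derived from the automatic `scale`, which needs `mean_cost_matrix`, which reads `shape`, which evaluates `cost_matrix` again. The snippet below is a minimal sketch of that loop. `ToyGeometry` is a hypothetical, simplified class (nested lists instead of `jnp` arrays) that models only the no-epsilon path. It is not OTT's code.

```python
import math


class ToyGeometry:
    """Hypothetical, simplified stand-in for the property chain in the traceback.

    Only the path taken when `epsilon` is omitted is modelled; this is not OTT code.
    """

    def __init__(self, kernel_matrix, epsilon=None):
        self._kernel_matrix = kernel_matrix  # nested lists instead of jnp arrays
        self._cost_matrix = None             # only a kernel is given
        self._epsilon_init = epsilon

    @property
    def cost_matrix(self):
        # Like geometry.py line 109: cost = -epsilon * log(kernel).
        if self._cost_matrix is None:
            eps = self.epsilon
            return [[-eps * math.log(k) for k in row] for row in self._kernel_matrix]
        return self._cost_matrix

    @property
    def epsilon(self):
        # Simplification of lines 103-104: a user epsilon is returned as-is,
        # otherwise the 5e-2 default is multiplied by the automatic scale.
        if self._epsilon_init is not None:
            return self._epsilon_init
        return 5e-2 * self.scale

    @property
    def scale(self):
        # Like line 94: the automatic scale is the mean of the cost matrix.
        return self.mean_cost_matrix

    @property
    def mean_cost_matrix(self):
        # Like line 118: the shape is read first.
        n, m = self.shape
        return sum(sum(row) for row in self.cost_matrix) / (n * m)

    @property
    def shape(self):
        # Like line 136: testing `self.cost_matrix is None` calls cost_matrix
        # again, which closes the loop when no epsilon was given.
        mat = self._kernel_matrix if self.cost_matrix is None else self.cost_matrix
        return len(mat), len(mat[0])


ones = [[1.0] * 3 for _ in range(3)]
print(ToyGeometry(ones, epsilon=1e-2).cost_matrix)  # 3x3 zeros, printed as -0.0
try:
    ToyGeometry(ones).cost_matrix
except RecursionError as err:
    print("RecursionError:", err)
```

With an explicit epsilon, `cost_matrix` returns without touching `scale`, which matches the `geom_e` line working in the report.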

Version: 0.1.17

marcocuturi added a commit that referenced this issue Nov 19, 2021
…h a kernel matrix. In that case default to 1 by testing whether epsilon_init was None and avoid "autoepsilon" computations that would cause an infinite recursion.
marcocuturi added a commit that referenced this issue Nov 19, 2021
Solving Issue #18 when no epsilon is passed to a geometry defined with a kernel matrix
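The failing line 109 recovers the cost from the kernel by inverting the Gibbs relation. A kernel on its own therefore determines the cost only up to the factor ε:

$$K_{ij} = \exp\!\left(-\frac{C_{ij}}{\varepsilon}\right) \quad\Longleftrightarrow\quad C_{ij} = -\varepsilon \log K_{ij}$$

Reading the commit message as falling back to ε = 1 when `epsilon_init` is None, the cost becomes $C_{ij} = -\log K_{ij}$. For the all-ones kernel in the reproducer, $\log 1 = 0$, so $C = 0$ whatever ε is.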
@marcocuturi
Contributor

Michal, thanks again for kindly reporting this issue. My PR #20 should solve it, but I will be a bit more rigorous and add a test. Thanks!
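The approach in the commit message ("default to 1 by testing whether epsilon_init was None") can be sketched as follows. `FixedToyGeometry` is a hypothetical, simplified class using nested lists. It is not the code from PR #20, and its automatic epsilon (5e-2 times the mean cost) is an assumed simplification. The only guard that matters is in `cost_matrix`: a kernel-only geometry without a user epsilon never calls back into `epsilon`, so the cycle through `scale`, `mean_cost_matrix` and `shape` cannot start.

```python
import math


class FixedToyGeometry:
    """Hypothetical sketch of the fix described in the commit message for #18.

    Not the code from PR #20; nested lists stand in for jnp arrays.
    """

    def __init__(self, kernel_matrix=None, cost_matrix=None, epsilon=None):
        self._kernel_matrix = kernel_matrix
        self._cost_matrix = cost_matrix
        self._epsilon_init = epsilon

    @property
    def cost_matrix(self):
        if self._cost_matrix is None:
            # The fix: with a kernel and no user epsilon, use 1 instead of
            # the automatic epsilon, which would need the cost matrix itself.
            eps = 1.0 if self._epsilon_init is None else self.epsilon
            return [[-eps * math.log(k) for k in row] for row in self._kernel_matrix]
        return self._cost_matrix

    @property
    def epsilon(self):
        # Simplified automatic epsilon (assumption): 5e-2 times the mean cost.
        if self._epsilon_init is not None:
            return self._epsilon_init
        return 5e-2 * self.mean_cost_matrix

    @property
    def mean_cost_matrix(self):
        n, m = self.shape
        return sum(sum(row) for row in self.cost_matrix) / (n * m)

    @property
    def shape(self):
        # Same pattern as line 136; safe now that cost_matrix no longer
        # reaches back into epsilon for kernel-only geometries.
        mat = self._kernel_matrix if self.cost_matrix is None else self.cost_matrix
        return len(mat), len(mat[0])


geom = FixedToyGeometry(kernel_matrix=[[1.0] * 3 for _ in range(3)])
print(geom.shape)        # (3, 3)
print(geom.cost_matrix)  # 3x3 zeros, printed as -0.0
print(geom.epsilon)      # 0.0, computed without recursion
```

A regression test along the lines mentioned above would simply build the geometry from a kernel alone and check that reading `cost_matrix` returns a matrix of the expected shape instead of raising.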
