Add Fisher information
NeilGirdhar committed Dec 29, 2020
1 parent c7b7828 commit fe4eb9d
Showing 10 changed files with 257 additions and 135 deletions.
143 changes: 50 additions & 93 deletions efax/exponential_family.py
@@ -1,107 +1,21 @@
from __future__ import annotations

from functools import partial, reduce
-from typing import Any, Generic, Iterable, Tuple, Type, TypeVar, final, get_type_hints
+from typing import Any, Callable, Generic, Iterable, Type, TypeVar, final, get_type_hints

import numpy as np
from chex import Array
from jax import grad, jacfwd
from jax import numpy as jnp
from jax.tree_util import tree_map, tree_reduce
-from tjax import RealArray, Shape, custom_jvp, field_values, jit
+from jax import vmap
+from tjax import RealArray, field_values, jit

-from .parameter import parameter_names_axes, parameter_names_values_axes
+from .parameter import parameter_names_values_axes
+from .parametrization import Parametrization

__all__ = ['NaturalParametrization', 'ExpectationParametrization']


EP = TypeVar('EP', bound='ExpectationParametrization[Any]')
T = TypeVar('T', bound='Parametrization')


class Parametrization:
# Magic methods --------------------------------------------------------------------------------
def __init_subclass__(cls) -> None:
super().__init_subclass__()
if cls.__name__ in ['VonMisesFisher']:
return

# Apply jit.
for name in ['log_normalizer',
'nat_to_exp',
'sufficient_statistics',
'cross_entropy',
'entropy',
'carrier_measure',
'expected_carrier_measure',
'pdf']:
super_cls = super(cls, cls)
if not hasattr(cls, name):
continue
original_method = getattr(cls, name)
if hasattr(super_cls, name) and getattr(super_cls, name) is original_method:
continue # We only need to jit new methods.
method = jit(original_method)
setattr(cls, f'_original_{name}', method)

if name != 'log_normalizer':
setattr(cls, name, method)
continue

method_jvp: Any = custom_jvp(method)

def ln_jvp(primals: Tuple[NaturalParametrization[Any]],
tangents: Tuple[NaturalParametrization[Any]]) -> Tuple[RealArray, RealArray]:
q, = primals
q_dot, = tangents
y = q.log_normalizer()
p = q.to_exp()
y_dot = tree_dot_final(q_dot, p)
return y, y_dot

method_jvp.defjvp(ln_jvp)

setattr(cls, name, method_jvp)

# New methods ----------------------------------------------------------------------------------
def flattened(self) -> Array:
def flatten_parameter(x: Array) -> Array:
return jnp.reshape(x, (*self.shape(), -1))
return tree_reduce(partial(jnp.append, axis=-1), tree_map(flatten_parameter, self))

@classmethod
def unflattened(cls: Type[T], flattened: Array, **kwargs: Any) -> T:
        # Count the fields with 0, 1, and 2 axes. Subtract flattened's final dimension from
        # the 0 count.
totals = np.zeros(3, dtype=np.int_)
totals[0] -= flattened.shape[-1]
for _, n_axes in parameter_names_axes(cls):
if not (0 <= n_axes <= 2):
raise ValueError
totals[n_axes] += 1

# Solve the quadratic equation and select the largest positive root.
roots = np.roots(list(reversed(totals)))
roots = list(roots)
if not roots:
root = 1
else:
root = int(max(roots))
if root < 0:
raise ValueError

# Unflatten.
shape = flattened.shape[:-1]
consumed = 0
for name, n_axes in parameter_names_axes(cls):
k = root ** n_axes
kwargs[name] = np.reshape(flattened[..., consumed: consumed + k],
shape + (root,) * n_axes)
consumed += k

return cls(**kwargs) # type: ignore

# Abstract methods -----------------------------------------------------------------------------
def shape(self) -> Shape:
raise NotImplementedError


def dot_final(x: Array, y: Array, n_axes: int) -> RealArray:
@@ -115,7 +29,7 @@ def dot_final(x: Array, y: Array, n_axes: int) -> RealArray:
def tree_dot_final(x: NaturalParametrization[Any], y: Any) -> RealArray:
def dotted_fields() -> Iterable[Array]:
for (_, xf, n_axes), yf in zip(parameter_names_values_axes(x),
                                   field_values(y, static=False)):
yield dot_final(xf, yf, n_axes)
return reduce(jnp.add, dotted_fields())
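
(Aside: the body of dot_final is collapsed in this view. Conceptually, tree_dot_final computes an inner product between two parameter pytrees, contracting each field's trailing parameter axes and summing over fields. A minimal sketch of that contraction, assuming dot_final sums over the last n_axes axes:

    import jax.numpy as jnp

    def dot_final_sketch(x, y, n_axes):
        # Contract only the trailing parameter axes; the leading axes are the
        # distribution's shape and broadcast through unchanged.
        return jnp.sum(x * y, axis=tuple(range(-n_axes, 0)))

This per-distribution scalar is what the log-normalizer's custom JVP in parametrization.py below uses to pair a tangent in natural parameters with the expectation parameters.)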

@@ -124,6 +38,8 @@ class NaturalParametrization(Parametrization, Generic[EP]):
"""
The natural parametrization of an exponential family distribution.
"""
T = TypeVar('T', bound='NaturalParametrization[EP]')

# Abstract methods -----------------------------------------------------------------------------
def log_normalizer(self) -> RealArray:
"""
@@ -180,10 +96,51 @@ def pdf(self, x: Array) -> RealArray:
- self.log_normalizer()
+ self.carrier_measure(x))

@final
def fisher_information(self: T, diagonal: bool = False, trace: bool = False) -> T:
"""
Args:
diagonal: If true, return only the diagonal elements of the Fisher information matrices.
trace: If true, return the trace of the Fisher information matrices.
Returns: The Fisher information stored in a NaturalParametrization object R whose fields
are:
* A scalar if trace is true.
* An array of the same shape as self if diagonal is true.
* Otherwise, a NaturalParametrization object whose fields are arrays.
"""
fisher_information = self._fisher_helper(len(self.shape()))

if not trace and not diagonal:
return fisher_information

kwargs = {}
f = jnp.trace if trace else jnp.diagonal
for name, value, axes in parameter_names_values_axes(fisher_information):
kwargs[name] = transform(f, getattr(value, name), axes)
return fisher_information.replace(**kwargs)

# Private methods ------------------------------------------------------------------------------
@partial(jit, static_argnums=1)
def _fisher_helper(self: T, len_shape: int) -> T:
fisher_info_f = jacfwd(grad(type(self).log_normalizer))
for _ in range(len_shape):
fisher_info_f = vmap(fisher_info_f)
return fisher_info_f(self)
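
(A brief note on why this works: for an exponential family with log-normalizer A, the Fisher information is the covariance of the sufficient statistic, which equals the Hessian of the log-normalizer,

    \mathcal{I}(\theta) = \operatorname{Cov}_\theta[t(x)] = \nabla_\theta^2 A(\theta),

so jacfwd(grad(log_normalizer)) evaluates it exactly, and the vmap loop maps the Hessian over the distribution's shape axes.)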


NP = TypeVar('NP', bound=NaturalParametrization[Any])


def transform(f: Callable[..., Array], array: Array, axes: int) -> Array:
if axes == 0:
return array
if axes == 1:
return f(array, axis1=-2, axis2=-1)
if axes == 2:
return f(f(array, axis1=-3, axis2=-1), axis1=-2, axis2=-1)
raise ValueError
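
(A usage sketch, with hypothetical names; q stands for an instance of any concrete NaturalParametrization subclass:

    # Hypothetical usage; `q` is some concrete NaturalParametrization instance.
    full = q.fisher_information()                # fields are NaturalParametrization
                                                 # objects holding all Hessian blocks
    diag = q.fisher_information(diagonal=True)   # each field: array shaped like that field
    tr = q.fisher_information(trace=True)        # each field: a scalar trace

transform is the reducer applied to each diagonal Hessian block: 0-axis fields pass through, a 1-axis field's block is a matrix reduced over its two trailing axes, and a 2-axis field's block is reduced twice over paired trailing axes.)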


class ExpectationParametrization(Parametrization, Generic[NP]):
"""
The expectation parametrization of an exponential family distribution. This class also doubles
2 changes: 1 addition & 1 deletion efax/parameter.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Iterable, Type, Union, Tuple
from typing import TYPE_CHECKING, Any, Iterable, Tuple, Type, Union

from tjax import Array, field, field_names_values_metadata, fields

107 changes: 107 additions & 0 deletions efax/parametrization.py
@@ -0,0 +1,107 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, Any, Tuple, Type, TypeVar

import numpy as np
from chex import Array
from jax import numpy as jnp
from jax.tree_util import tree_map, tree_reduce
from tjax import RealArray, Shape, custom_jvp, jit

from .parameter import parameter_names_axes

__all__ = ['Parametrization']


T = TypeVar('T', bound='Parametrization')


class Parametrization:
# Magic methods --------------------------------------------------------------------------------
def __init_subclass__(cls) -> None:
super().__init_subclass__()
if cls.__name__ in ['VonMisesFisher']:
return

from .exponential_family import tree_dot_final # pylint: disable=import-outside-toplevel

# Apply jit.
for name in ['log_normalizer',
'to_exp',
'carrier_measure',
'sufficient_statistics',
'cross_entropy',
'expected_carrier_measure']:
super_cls = super(cls, cls)
if not hasattr(cls, name):
continue
original_method = getattr(cls, name)
if hasattr(super_cls, name) and getattr(super_cls, name) is original_method:
continue # We only need to jit new methods.
method = jit(original_method)
setattr(cls, f'_original_{name}', method)

if name != 'log_normalizer':
setattr(cls, name, method)
continue

method_jvp: Any = custom_jvp(method)

def ln_jvp(primals: Tuple[NaturalParametrization[Any]],
tangents: Tuple[NaturalParametrization[Any]]) -> Tuple[RealArray, RealArray]:
q, = primals
q_dot, = tangents
y = q.log_normalizer()
p = q.to_exp()
y_dot = tree_dot_final(q_dot, p)
return y, y_dot

method_jvp.defjvp(ln_jvp)

setattr(cls, name, method_jvp)
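
(The custom JVP above encodes the standard exponential-family identity that the gradient of the log-normalizer with respect to the natural parameters is the expectation parameters,

    \nabla_\theta A(\theta) = \mathbb{E}_\theta[t(x)] = \eta,

so ln_jvp can return tree_dot_final(q_dot, q.to_exp()) directly, which is typically cheaper and numerically cleaner than differentiating through log_normalizer's implementation.)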

# New methods ----------------------------------------------------------------------------------
def flattened(self) -> Array:
def flatten_parameter(x: Array) -> Array:
return jnp.reshape(x, (*self.shape(), -1))
return tree_reduce(partial(jnp.append, axis=-1), tree_map(flatten_parameter, self))

@classmethod
def unflattened(cls: Type[T], flattened: Array, **kwargs: Any) -> T:
        # Count the fields with 0, 1, and 2 axes. Subtract flattened's final dimension from
        # the 0 count.
totals = np.zeros(3, dtype=np.int_)
totals[0] -= flattened.shape[-1]
for _, n_axes in parameter_names_axes(cls):
if not 0 <= n_axes <= 2:
raise ValueError
totals[n_axes] += 1

# Solve the quadratic equation and select the largest positive root.
roots = np.roots(list(reversed(list(totals))))
roots = list(roots)
if not roots:
root = 1
else:
root = int(max(roots))
if root < 0:
raise ValueError

# Unflatten.
shape = flattened.shape[:-1]
consumed = 0
for name, n_axes in parameter_names_axes(cls):
k = root ** n_axes
kwargs[name] = np.reshape(flattened[..., consumed: consumed + k],
shape + (root,) * n_axes)
consumed += k

return cls(**kwargs) # type: ignore
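
(To see the quadratic concretely, suppose, hypothetically, a parametrization has one 2-axis field, one 1-axis field, and one 0-axis field, all built from a common dimension n, and the flattened array has final dimension 7. Then n must satisfy n**2 + n + 1 == 7:

    import numpy as np

    # Hypothetical layout: one matrix (2 axes), one vector (1 axis), one scalar (0 axes).
    totals = np.zeros(3, dtype=np.int_)
    totals[0] -= 7                     # flattened.shape[-1] == 7
    for n_axes in (2, 1, 0):
        totals[n_axes] += 1
    roots = np.roots(list(reversed(list(totals))))   # roots of n**2 + n - 6
    print(int(max(roots)))             # 2: the vector has length 2, the matrix is 2x2

unflattened then slices 4, 2, and 1 entries off the flattened axis for the matrix, vector, and scalar fields respectively.)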

# Abstract methods -----------------------------------------------------------------------------
def shape(self) -> Shape:
raise NotImplementedError


if TYPE_CHECKING:
from .exponential_family import NaturalParametrization
38 changes: 28 additions & 10 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -27,7 +27,7 @@ python = "^3.8"
chex = "^0.0.2"
jax = "^0.2"
jaxlib = "^0.1.55"
-numpy = "^1.18"
+numpy = ">=1.20.0rc2,<1.21"
scipy = "^1.4"
tjax = ">=0.7.5,<1.0"

2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -26,5 +26,5 @@ def configure_numpy() -> Generator[None, None, None]:

@pytest.fixture(scope='session',
params=create_infos())
-def distribution_info(request: Any) -> List[DistributionInfo]:
+def distribution_info(request: Any) -> List[DistributionInfo[Any, Any]]:
return request.param