Cross-framework Python Package for Evaluation of Latent-based Generative Models


Latte

Latte (for LATent Tensor Evaluation) is a cross-framework Python package for evaluation of latent-based generative models. Latte supports calculation of disentanglement and controllability metrics in both PyTorch (via TorchMetrics) and TensorFlow.

Installation

For developers working on a local clone, cd to the repo and replace latte-metrics with . (a single dot) in the commands below. For example, pip install .[tests]

pip install latte-metrics           # core (numpy only)
pip install latte-metrics[pytorch]  # with torchmetrics wrapper
pip install latte-metrics[keras]    # with tensorflow wrapper
pip install latte-metrics[tests]    # for testing

Running tests locally

pip install .[tests]
pytest tests/ --cov=latte

Quick Examples

Functional API

import latte
from latte.functional.disentanglement.mutual_info import mig
import numpy as np

latte.seed(42)

z = np.random.randn(16, 8)
a = np.random.randn(16, 2)

mutual_info_gap = mig(z, a, discrete=False, reg_dim=[4, 3])
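To give a rough sense of what a mutual information gap measures, here is a simplified NumPy sketch of the classic formulation: estimate the mutual information between each latent dimension and an attribute, then take the normalized gap between the two most informative dimensions. This is not Latte's implementation. The helper names `discrete_mutual_info` and `toy_mig`, the equal-width binning, and the restriction to a single discrete attribute are all assumptions made for illustration, and the sketch ignores Latte's `reg_dim` and `discrete` options.

```python
import numpy as np


def discrete_mutual_info(x, y):
    """Plug-in mutual information (in nats) of two non-negative integer arrays."""
    joint = np.zeros((int(x.max()) + 1, int(y.max()) + 1))
    np.add.at(joint, (x, y), 1.0)  # joint histogram of co-occurrences
    joint /= joint.sum()
    # Product of the marginals, same shape as the joint distribution.
    outer = joint.sum(axis=1, keepdims=True) @ joint.sum(axis=0, keepdims=True)
    mask = joint > 0
    return float(np.sum(joint[mask] * np.log(joint[mask] / outer[mask])))


def toy_mig(z, a, bins=10):
    """Toy MIG for continuous latents z of shape (n, d) and one discrete attribute a of shape (n,)."""
    mi = []
    for j in range(z.shape[1]):
        # Discretize latent dimension j into equal-width bins (an assumption of this sketch).
        edges = np.histogram_bin_edges(z[:, j], bins=bins)
        codes = np.digitize(z[:, j], edges[1:-1])  # integer codes in [0, bins - 1]
        mi.append(discrete_mutual_info(codes, a))
    mi = np.sort(mi)[::-1]  # most informative dimension first
    entropy = discrete_mutual_info(a, a)  # I(a; a) equals the entropy H(a)
    return (mi[0] - mi[1]) / entropy


rng = np.random.default_rng(42)
n = 2000
a = rng.integers(0, 4, size=n)  # one discrete attribute with four classes
z = rng.normal(size=(n, 3))     # latents unrelated to the attribute

z_informative = z.copy()
z_informative[:, 0] = a + 0.05 * rng.normal(size=n)  # dimension 0 now encodes a

score_informative = toy_mig(z_informative, a)
score_noise = toy_mig(z, a)
```

When exactly one latent dimension encodes the attribute, the gap approaches 1; when no dimension does, it stays near 0.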

Modular API

import latte
from latte.metrics.core.disentanglement import MutualInformationGap
import numpy as np

latte.seed(42)

mig = MutualInformationGap()

# ... 
# initialize data and model
# ...

for data, attributes in batches:  # batches: iterable of (data, attributes) pairs
  recon, z = model(data)

  mig.update_state(z, attributes)

mig_val = mig.compute()
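The loop above follows an accumulate-then-compute pattern: each batch is passed to the metric object, and the value is evaluated once at the end. The sketch below illustrates that pattern with NumPy only. `BufferedMetric` and `abs_corr` are hypothetical stand-ins written for this example, not Latte classes, and the buffering strategy shown is an assumption about how such a metric could be built.

```python
import numpy as np


class BufferedMetric:
    """Hypothetical stand-in: buffers every batch and evaluates once at the end."""

    def __init__(self, fn):
        self.fn = fn  # function of (z, a) returning a scalar
        self.reset_state()

    def reset_state(self):
        self._z, self._a = [], []

    def update_state(self, z, a):
        # Store the batch; nothing is evaluated yet.
        self._z.append(np.asarray(z))
        self._a.append(np.asarray(a))

    def compute(self):
        # Evaluate on everything seen since the last reset.
        return self.fn(np.concatenate(self._z), np.concatenate(self._a))


def abs_corr(z, a):
    """Placeholder statistic: |Pearson correlation| between latent dim 0 and the attribute."""
    return float(abs(np.corrcoef(z[:, 0], a)[0, 1]))


rng = np.random.default_rng(42)
batches = []
for _ in range(4):
    a = rng.normal(size=16)
    z = np.stack([a + 0.1 * rng.normal(size=16), rng.normal(size=16)], axis=1)
    batches.append((z, a))

metric = BufferedMetric(abs_corr)
for z, a in batches:
    metric.update_state(z, a)

batched_value = metric.compute()
```

Because the statistic is evaluated on the concatenated buffers, the result matches a single call on the full dataset, which matters for metrics that cannot be averaged across batches.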

TorchMetrics API

import latte
from latte.metrics.torch.disentanglement import MutualInformationGap
import torch

latte.seed(42)

mig = MutualInformationGap()

# ... 
# initialize data and model
# ...

for data, attributes in batches:  # batches: iterable of (data, attributes) pairs
  recon, z = model(data)

  mig.update(z, attributes)

mig_val = mig.compute()

Keras Metric API

import latte
from latte.metrics.keras.disentanglement import MutualInformationGap
from tensorflow import keras as tfk

latte.seed(42)

mig = MutualInformationGap()

# ... 
# initialize data and model
# ...

for data, attributes in batches:  # batches: iterable of (data, attributes) pairs
  recon, z = model(data)

  mig.update_state(z, attributes)

mig_val = mig.result()

Example Notebooks

See Latte in action with the Morpho-MNIST example notebooks on Google Colab.

Documentation

https://latte.readthedocs.io/en/latest

Supported metrics

🧪 Beta support | ✔️ Stable | 🔨 In Progress | 🕣 In Queue | 👀 KIV (keep in view)

| Metric | Latte Functional | Latte Modular | TorchMetrics | Keras Metric |
| --- | --- | --- | --- | --- |
| Disentanglement Metrics | | | | |
| 📝 Mutual Information Gap (MIG) | 🧪 | 🧪 | 🧪 | 🧪 |
| 📝 Dependency-blind Mutual Information Gap (DMIG) | 🧪 | 🧪 | 🧪 | 🧪 |
| 📝 Dependency-aware Mutual Information Gap (XMIG) | 🧪 | 🧪 | 🧪 | 🧪 |
| 📝 Dependency-aware Latent Information Gap (DLIG) | 🧪 | 🧪 | 🧪 | 🧪 |
| 📝 Separate Attribute Predictability (SAP) | 🧪 | 🧪 | 🧪 | 🧪 |
| 📝 Modularity | 🧪 | 🧪 | 🧪 | 🧪 |
| 📝 β-VAE Score | 👀 | 👀 | 👀 | 👀 |
| 📝 FactorVAE Score | 👀 | 👀 | 👀 | 👀 |
| 📝 DCI Score | 👀 | 👀 | 👀 | 👀 |
| 📝 Interventional Robustness Score (IRS) | 👀 | 👀 | 👀 | 👀 |
| 📝 Consistency | 👀 | 👀 | 👀 | 👀 |
| 📝 Restrictiveness | 👀 | 👀 | 👀 | 👀 |
| Interpolatability Metrics | | | | |
| 📝 Smoothness | 🧪 | 🧪 | 🧪 | 🧪 |
| 📝 Monotonicity | 🧪 | 🧪 | 🧪 | 🧪 |
| 📝 Latent Density Ratio | 🕣 | 🕣 | 🕣 | 🕣 |
| 📝 Linearity | 👀 | 👀 | 👀 | 👀 |

Bundled metric modules

🧪 Experimental (subject to changes) | ✔️ Stable | 🔨 In Progress | 🕣 In Queue

| Metric Bundle | Latte Functional | Latte Modular | TorchMetrics | Keras Metric | Included |
| --- | --- | --- | --- | --- | --- |
| Dependency-aware Disentanglement | 🧪 | 🧪 | 🧪 | 🧪 | MIG, DMIG, XMIG, DLIG |
| LIAD-based Interpolatability | 🧪 | 🧪 | 🧪 | 🧪 | Smoothness, Monotonicity |

Cite

For individual metrics, please cite the paper according to the link in the 📝 icon in front of each metric.

If you find our package useful, please cite our open-access paper in Software Impacts (Elsevier):

@article{
  watcharasupat2021latte,
  author = {Watcharasupat, Karn N. and Lee, Junyoung and Lerch, Alexander},
  title = {{Latte: Cross-framework Python Package for Evaluation of Latent-based Generative Models}},
  journal = {Software Impacts},
  volume = {11},
  pages = {100222},
  year = {2022},
  issn = {2665-9638},
  doi = {10.1016/j.simpa.2022.100222},
  url = {https://www.sciencedirect.com/science/article/pii/S2665963822000033},
}