Skip to content

Commit

Permalink
Update the experiments (#23)
Browse files Browse the repository at this point in the history
* Fix a bug in the unrestricted ratio estimator.

* Replace PyMC with NumPyro

* Add experiment with nearly-nonidentifiable model.

* Add workflow with misspecified model

* Minor plot improvements

* Add benchmark implementation

* Minor adjustments

* Basic utilities for single-cell section

* Plot adjustments

* Update misspecified plots to use subplots_from_axsize

* Finish single-cell workflow
  • Loading branch information
pawel-czyz committed Mar 13, 2024
1 parent acc8677 commit e2c5174
Show file tree
Hide file tree
Showing 11 changed files with 1,396 additions and 232 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# Directories for local files
local/
private/
generated/

# Data
*.json
Expand Down
5 changes: 3 additions & 2 deletions labelshift/algorithms/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
>>> import labelshift.algorithms.api as algo
"""
from labelshift.algorithms.bayesian_discrete import DiscreteCategoricalMAPEstimator
from labelshift.algorithms.bayesian_discrete import DiscreteCategoricalMeanEstimator, SamplingParams
from labelshift.algorithms.bbse import BlackBoxShiftEstimator
from labelshift.algorithms.classify_and_count import ClassifyAndCount
from labelshift.algorithms.ratio_estimator import InvariantRatioEstimator
Expand All @@ -13,7 +13,8 @@
__all__ = [
"BlackBoxShiftEstimator",
"ClassifyAndCount",
"DiscreteCategoricalMAPEstimator",
"DiscreteCategoricalMeanEstimator",
"InvariantRatioEstimator",
"SummaryStatistic",
"SamplingParams",
]
197 changes: 49 additions & 148 deletions labelshift/algorithms/bayesian_discrete.py
Original file line number Diff line number Diff line change
@@ -1,178 +1,79 @@
"""Categorical discrete Bayesian model for quantification.
Proposed in
TODO(Pawel): Add citation to pre-print after AISTATS reviews.
"""
from typing import cast, NewType, Optional, Union

import arviz as az
"""Categorical discrete Bayesian model for quantification."""
import numpy as np
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import jax
import pydantic
import pymc
import pymc as pm
from typing import Optional

from numpy.typing import ArrayLike

import labelshift.interfaces.point_estimators as pe


P_TRAIN_Y: str = "P_train(Y)"
P_TEST_Y: str = "P_test(Y)"
P_TEST_C: str = "P_test(C)"
P_C_COND_Y: str = "P(C|Y)"


class SamplingParams(pydantic.BaseModel):
"""Settings for the MCMC sampler."""

draws: pydantic.PositiveInt = pydantic.Field(default=1000)
chains: pydantic.PositiveInt = pydantic.Field(default=4)
random_seed: int = 20


DiscreteBayesianQuantificationModel = NewType(
"DiscreteBayesianQuantificationModel", pm.Model
)


def dirichlet_alphas(L: int, alpha: Union[float, ArrayLike]) -> np.ndarray:
"""Convenient initialization of alpha (pseudocounts)
parameters of the Dirichlet prior.
Args:
alpha: either an array of shape (L,) or a float.
If a float, vector (alpha, alpha, ..., alpha)
is created
Returns:
alphas, shape (L,)
"""
if isinstance(alpha, float):
return np.ones(L) * alpha
else:
alpha = np.asarray(alpha)
assert alpha.shape == (L,)
return alpha


def build_model(
n_y_and_c_labeled: ArrayLike,
n_c_unlabeled: ArrayLike,
alpha_p_y_labeled: Union[float, ArrayLike] = 1.0,
alpha_p_y_unlabeled: Union[float, ArrayLike] = 1.0,
) -> DiscreteBayesianQuantificationModel:
"""Builds the discrete Bayesian quantification model,
basing on the sufficient statistic of the data.
Args:
n_y_and_c_labeled: histogram of Y and C labels in the labeled data set,
shape (L, K)
n_c_unlabeled: histogram of C in the unlabeled data set, shape (K,)
"""
n_y_and_c_labeled = np.asarray(n_y_and_c_labeled)
n_y_labeled = n_y_and_c_labeled.sum(axis=1)
n_c_unlabeled = np.asarray(n_c_unlabeled)

assert len(n_y_and_c_labeled.shape) == 2
L, K = n_y_and_c_labeled.shape

assert n_y_labeled.shape == (L,)
assert n_c_unlabeled.shape == (K,)
warmup: pydantic.PositiveInt = pydantic.Field(default=500)
samples: pydantic.PositiveInt = pydantic.Field(default=1000)

alpha_p_y_labeled = dirichlet_alphas(L, alpha_p_y_labeled)
alpha_p_y_unlabeled = dirichlet_alphas(L, alpha_p_y_unlabeled)

model = pm.Model()
with model:
# Prior on pi, pi_, phi
pi = pm.Dirichlet(P_TRAIN_Y, alpha_p_y_labeled)
pi_ = pm.Dirichlet(P_TEST_Y, alpha_p_y_unlabeled)
p_c_cond_y = pm.Dirichlet(P_C_COND_Y, np.ones(K), shape=(L, K))

# Note: we need to silence unused variable error (F841)

# Sample N_y from P_train(Y)
N_y = pm.Multinomial( # noqa: F841
"N_y", np.sum(n_y_labeled), p=pi, observed=n_y_labeled
)

# Sample the rows
F_yc = pm.Multinomial( # noqa: F841
"F_yc", n_y_labeled, p=p_c_cond_y, observed=n_y_and_c_labeled
)

# Sample from P_test(C) = P(C | Y) P_test(Y)
p_c = pm.Deterministic(P_TEST_C, p_c_cond_y.T @ pi_)
N_c = pm.Multinomial( # noqa: F841
"N_c", np.sum(n_c_unlabeled), p=p_c, observed=n_c_unlabeled
)

return cast(DiscreteBayesianQuantificationModel, model)
P_TRAIN_Y: str = "P_train(Y)"
P_TEST_Y: str = "P_test(Y)"
P_TEST_C: str = "P_test(C)"
P_C_COND_Y: str = "P(C|Y)"


def sample_from_bayesian_discrete_model_posterior(
model: DiscreteBayesianQuantificationModel,
sampling_params: Optional[SamplingParams] = None,
) -> az.InferenceData:
"""Inference in the Bayesian model
def model(summary_statistic):
n_y_labeled = summary_statistic.n_y_labeled
n_y_and_c_labeled = summary_statistic.n_y_and_c_labeled
n_c_unlabeled = summary_statistic.n_c_unlabeled
K = len(n_c_unlabeled)
L = len(n_y_labeled)

Args:
model: built model
sampling_params: sampling parameters, will be passed to PyMC's sampling method
"""
sampling_params = SamplingParams() if sampling_params is None else sampling_params
pi = numpyro.sample(P_TRAIN_Y, dist.Dirichlet(jnp.ones(L)))
pi_ = numpyro.sample(P_TEST_Y, dist.Dirichlet(jnp.ones(L)))
p_c_cond_y = numpyro.sample(P_C_COND_Y, dist.Dirichlet(jnp.ones(K).repeat(L).reshape(L, K)))

with model:
inference_data = pm.sample(
random_seed=sampling_params.random_seed,
chains=sampling_params.chains,
draws=sampling_params.draws,
)
N_y = numpyro.sample('N_y', dist.Multinomial(jnp.sum(n_y_labeled), pi), obs=n_y_labeled)

with numpyro.plate('plate', L):
numpyro.sample('F_yc', dist.Multinomial(N_y, p_c_cond_y), obs=n_y_and_c_labeled)

return inference_data
p_c = numpyro.deterministic(P_TEST_C, jnp.einsum("yc,y->c", p_c_cond_y, pi_))
numpyro.sample('N_c', dist.Multinomial(jnp.sum(n_c_unlabeled), p_c), obs=n_c_unlabeled)


class DiscreteCategoricalMeanEstimator(pe.SummaryStatisticPrevalenceEstimator):
"""A version of Bayesian quantification which finds the mean solution.
Note that it runs the MCMC sampler in the backend.
"""

def __init__(self) -> None:
"""Not implemented yet."""
raise NotImplementedError
P_TRAIN_Y = P_TRAIN_Y
P_TEST_Y = P_TEST_Y
P_TEST_C = P_TEST_C
P_C_COND_Y = P_C_COND_Y

def __init__(self, params: Optional[SamplingParams] = None, seed: int = 42) -> None:
if params is None:
params = SamplingParams()
self._params = params
self._seed = seed

def sample_posterior(self, /, statistic: pe.SummaryStatistic):
"""Returns the samples from the MCMC sampler."""
mcmc = numpyro.infer.MCMC(
numpyro.infer.NUTS(model),
num_warmup=self._params.warmup,
num_samples=self._params.samples)
rng_key = jax.random.PRNGKey(self._seed)
mcmc.run(rng_key, summary_statistic=statistic)
return mcmc.get_samples()

def estimate_from_summary_statistic(
self, /, statistic: pe.SummaryStatistic
) -> np.ndarray:
"""Returns the mean prediction."""
raise NotImplementedError


class DiscreteCategoricalMAPEstimator(pe.SummaryStatisticPrevalenceEstimator):
"""A version of Bayesian quantification
which finds the Maximum a Posteriori solution."""

def __init__(
self, max_eval: int = 10_000, alpha_unlabeled: Union[float, ArrayLike] = 1.0
) -> None:
"""
Args:
max_eval: maximal number of evaluations of the posterior
during the optimization to find the MAP
"""
self._max_eval = max_eval
self._alpha_unlabeled = alpha_unlabeled

def estimate_from_summary_statistic(
self, /, statistic: pe.SummaryStatistic
) -> np.ndarray:
"""Finds the Maximum a Posteriori (MAP)."""
model = build_model(
n_c_unlabeled=statistic.n_c_unlabeled,
n_y_and_c_labeled=statistic.n_y_and_c_labeled,
alpha_p_y_unlabeled=self._alpha_unlabeled,
)
with model:
optimal = pymc.find_MAP(maxeval=self._max_eval)
return optimal[P_TEST_Y]
samples = self.sample_posterior(statistic)[P_TEST_Y]
return np.array(samples.mean(axis=0))
80 changes: 0 additions & 80 deletions labelshift/algorithms/gaussian_mixture_model.py

This file was deleted.

1 change: 1 addition & 0 deletions labelshift/algorithms/ratio_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,4 +235,5 @@ def estimate_from_summary_statistic(
n_y_and_c_labeled=statistic.n_y_and_c_labeled,
enforce_square=self._enforce_square,
rcond=self._rcond,
restricted=self._restricted,
)
7 changes: 7 additions & 0 deletions labelshift/datasets/discrete_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@

@dataclasses.dataclass
class SummaryMultinomialStatistic:
"""
Attributes:
n_y: shape (L,)
n_c: shape (K,)
n_y_and_c: shape (L, K)
"""
n_y: np.ndarray
n_c: np.ndarray
n_y_and_c: np.ndarray
Expand Down
8 changes: 6 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,15 @@ setup_requires =
install_requires =
arviz
numpy >= 1.12,<2
petname
pydantic
pymc
jax
jaxlib
joblib
matplotlib
numpyro
scikit-learn
scipy
subplots_from_axsize

[options.packages.find]
where = labelshift
Expand Down
Loading

0 comments on commit e2c5174

Please sign in to comment.