diff --git a/mess.py b/mess.py
index 26a8d4c..7a72520 100644
--- a/mess.py
+++ b/mess.py
@@ -30,8 +30,8 @@
 texts = [line for line in open("processed_sample.txt")]
 
-vectorizer = CountVectorizer(max_features=100, max_df=0.3, min_df=10)
-X = vectorizer.fit_transform(random.sample(texts, 10_000))
+vectorizer = CountVectorizer(max_features=4000, max_df=0.3, min_df=10)
+X = vectorizer.fit_transform(random.sample(texts, 20_000))
 
 
 model = BayesianDMM(
     n_components=5,
@@ -64,7 +64,12 @@
 X.shape
 
-predict_doc()
+predict_doc(
+    unique_words=doc_unique_words[0],
+    unique_word_counts=doc_unique_word_counts[0],
+    components=components[0],
+    weights=weights[0],
+)
 
 try:
     predict_one_doc(doc=docs[0], samples=np.array(model.samples[:2]))
 
diff --git a/tweetopic/bayesian/dmm.py b/tweetopic/bayesian/dmm.py
index c75b51d..904e0f0 100644
--- a/tweetopic/bayesian/dmm.py
+++ b/tweetopic/bayesian/dmm.py
@@ -118,12 +118,10 @@ def posterior_predictive(
     return predict_all(doc_unique_words, doc_unique_word_counts)
 
 
-def dirichlet_multinomial_mixture_logpdf(
+def dmm_loglikelihood(
     components, weights, doc_unique_words, doc_unique_word_counts, alpha, beta
 ):
-    """Calculates logdensity of the DMM at a given point in parameter space."""
     docs = jnp.stack((doc_unique_words, doc_unique_word_counts), axis=1)
-    n_docs = doc_unique_words.shape[0]
 
     def doc_likelihood(doc):
         unique_words, unique_word_counts = doc
@@ -137,7 +135,11 @@ def doc_likelihood(doc):
         ) + jnp.log(weights)
         return jax.scipy.special.logsumexp(component_logprobs)
 
-    likelihood = jnp.sum(jax.lax.map(doc_likelihood, docs))
+    loglikelihood = jnp.sum(jax.lax.map(doc_likelihood, docs))
+    return loglikelihood
+
+
+def dmm_logprior(components, weights, alpha, beta, n_docs):
     components_prior = jnp.sum(
         jax.lax.map(
             partial(symmetric_dirichlet_logpdf, alpha=alpha), components
@@ -146,7 +148,24 @@ def doc_likelihood(doc):
     weights_prior = symmetric_dirichlet_multinomial_logpdf(
         weights, n=jnp.float64(n_docs), alpha=beta
     )
-    return likelihood + components_prior + weights_prior
+    return components_prior + weights_prior
+
+
+def dmm_logpdf(
+    components, weights, doc_unique_words, doc_unique_word_counts, alpha, beta
+):
+    """Calculates logdensity of the DMM at a given point in parameter space."""
+    n_docs = doc_unique_words.shape[0]
+    loglikelihood = dmm_loglikelihood(
+        components,
+        weights,
+        doc_unique_words,
+        doc_unique_word_counts,
+        alpha,
+        beta,
+    )
+    logprior = dmm_logprior(components, weights, alpha, beta, n_docs)
+    return logprior + loglikelihood
 
 
 class BayesianDMM(sklearn.base.TransformerMixin, sklearn.base.BaseEstimator):
@@ -223,16 +242,21 @@ def fit(self, X, y=None):
             alpha=self.alpha,
             beta=self.beta,
         )
-        logdensity_fn = spread(
-            partial(
-                dirichlet_multinomial_mixture_logpdf,
+        logdensity_fn = partial(
+            dmm_logpdf,
+            doc_unique_words=doc_unique_words,
+            doc_unique_word_counts=doc_unique_word_counts,
+            alpha=self.alpha,
+            beta=self.beta,
+        )
+        samples = self.sampler(
+            initial_position,
+            logdensity_fn,
+            data=dict(
                 doc_unique_words=doc_unique_words,
                 doc_unique_word_counts=doc_unique_word_counts,
-                alpha=self.alpha,
-                beta=self.beta,
-            )
+            ),
         )
-        samples = self.sampler(initial_position, logdensity_fn)
 
         self.samples = samples
         return self
diff --git a/tweetopic/bayesian/sampling.py b/tweetopic/bayesian/sampling.py
index b402335..cfcb2d5 100644
--- a/tweetopic/bayesian/sampling.py
+++ b/tweetopic/bayesian/sampling.py
@@ -1,23 +1,29 @@
 """Sampling utilities using jax and blackjax."""
+from functools import partial
 from typing import Any, Callable
 
 import blackjax
 import jax
+import jax.numpy as jnp
 from blackjax.mcmc.hmc import HMCState
 from blackjax.types import PyTree
 from blackjax.vi.meanfield_vi import MFVIState
 from optax import adam
 from tqdm import trange
 
+from tweetopic.func import spread
+
 Sampler = Callable[..., tuple[list[PyTree], Any]]
 
 
 def sample_nuts(
     initial_position: PyTree,
     logdensity_fn: Callable,
+    data: dict,
     seed: int = 0,
     n_warmup: int = 1000,
     n_samples: int = 1000,
+    data_axis: int = 0,
 ) -> tuple[list[PyTree], list[HMCState]]:
     """NUTS sampling loop for any logdensity function that can be JIT
     compiled with JAX.
@@ -44,6 +50,7 @@
         State of the Hamiltonian Monte Carlo at each step. Mostly useful for
         debugging.
     """
+    logdensity_fn = spread(partial(logdensity_fn, **data))
     rng_key = jax.random.PRNGKey(seed)
     warmup_key, sampling_key = jax.random.split(rng_key)
     print("Warmup, window adaptation")
@@ -64,10 +71,12 @@
 def sample_sgld(
     initial_position: dict,
     logdensity_fn: Callable,
+    data: dict,
     seed: int = 0,
     n_samples: int = 1000,
     initial_step_size: float = 1000,
     decay: float = 2.5,
+    data_axis: int = 0,
 ) -> tuple[list[PyTree], None]:
     """Stochastic Gradient Langevin Dynamics sampling loop with decaying step
     size for any logdensity function that is differentiable with JAX.
@@ -98,6 +107,7 @@
     states: None
         Ignored only returned for compatibilty.
     """
+    logdensity_fn = spread(partial(logdensity_fn, **data))
     rng_key = jax.random.PRNGKey(seed)
     num_training_steps = n_samples
     schedule_fn = lambda k: initial_step_size * (k ** (-decay))
@@ -116,9 +126,12 @@
 def sample_pathfinder(
     initial_position: dict,
     logdensity_fn: Callable,
+    data: dict,
     seed: int = 0,
     n_samples: int = 1000,
+    data_axis: int = 0,
 ) -> list[PyTree]:
+    logdensity_fn = spread(partial(logdensity_fn, **data))
     rng_key = jax.random.PRNGKey(seed)
     optim_key, sampling_key = jax.random.split(rng_key)
     pathfinder = blackjax.pathfinder(logdensity_fn)
@@ -134,12 +147,15 @@
 def sample_meanfield_vi(
     initial_position: dict,
     logdensity_fn: Callable,
+    data: dict,
     seed: int = 0,
     n_iter: int = 20_000,
     n_samples: int = 1000,
     n_optim_samples: int = 20,
     learning_rate: float = 0.08,
+    data_axis: int = 0,
 ) -> tuple[list[PyTree], list[MFVIState]]:
+    logdensity_fn = spread(partial(logdensity_fn, **data))
     rng_key = jax.random.PRNGKey(seed)
     optim_key, sampling_key = jax.random.split(rng_key)
     optimizer = adam(learning_rate)
@@ -155,3 +171,55 @@
         states.append(state)
     samples = mfvi.sample(sampling_key, state, num_samples=n_samples)
     return samples, states
+
+
+def batch_data(rng_key, data, batch_size: int, data_size: int):
+    while True:
+        _, rng_key = jax.random.split(rng_key)
+        idx = jax.random.choice(
+            key=rng_key, a=jnp.arange(data_size), shape=(batch_size,)
+        )
+        yield idx
+
+
+def get_batch(idx, data: dict, data_axis: int):
+    return {
+        key: jnp.take(value, idx, axis=data_axis)
+        for key, value in data.items()
+    }
+
+
+# TODO
+def sample_minibatch_hmc(
+    initial_position: PyTree,
+    logdensity_fn: Callable,
+    data: dict,
+    seed: int = 0,
+    batch_size: int = 512,
+    step_size: float = 0.001,
+    n_warmup: int = 100,
+    n_samples: int = 1000,
+    data_axis: int = 0,
+) -> tuple[list[PyTree], list[HMCState]]:
+    rng_key = jax.random.PRNGKey(seed)
+    batch_key, warmup_key, sampling_key = jax.random.split(rng_key, 3)
+    batches = batch_data(
+        batch_key, data, batch_size, data_size=len(data[list(data.keys())[0]])
+    )
+    print("Warmup, window adaptation")
+    warmup_batch = get_batch(next(batches), data, data_axis=data_axis)
+    warmup = blackjax.window_adaptation(
+        blackjax.hmc, partial(logdensity_fn, **warmup_batch)
+    )
+    (state, parameters), _ = warmup.run(
+        warmup_key, initial_position, num_steps=n_warmup
+    )
+    sghmc = blackjax.sghmc()
+    kernel = jax.jit(blackjax.nuts(logdensity_fn, **parameters).step)
+    states = []
+    for i in trange(n_samples, desc="Sampling"):
+        _, sampling_key = jax.random.split(sampling_key)
+        state, _ = kernel(sampling_key, state)
+        states.append(state)
+    samples = [state.position for state in states]
+    return samples, states
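
Usage sketch (not part of the patch, added for review context): this is how the split between dmm_logpdf and the data-aware samplers is meant to compose. The dummy shapes, the array contents, and the alpha/beta values below are illustrative assumptions, not values taken from the repository.

    from functools import partial

    import jax.numpy as jnp

    from tweetopic.bayesian.dmm import dmm_logpdf
    from tweetopic.bayesian.sampling import sample_nuts

    # Hypothetical preprocessed corpus: per-document unique word ids and
    # their counts, padded to a fixed width (shapes are made up here).
    doc_unique_words = jnp.zeros((100, 30), dtype=jnp.int32)
    doc_unique_word_counts = jnp.ones((100, 30))

    # The position only holds model parameters; the data is handed to the
    # sampler separately and bound onto the logdensity before sampling.
    initial_position = dict(
        components=jnp.full((5, 4000), 1 / 4000),
        weights=jnp.full(5, 1 / 5),
    )
    logdensity_fn = partial(dmm_logpdf, alpha=0.1, beta=0.1)
    samples, states = sample_nuts(
        initial_position,
        logdensity_fn,
        data=dict(
            doc_unique_words=doc_unique_words,
            doc_unique_word_counts=doc_unique_word_counts,
        ),
    )

Because every sampler now does the binding itself via spread(partial(logdensity_fn, **data)), the same call shape should also work for sample_sgld, sample_pathfinder and sample_meanfield_vi.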