From 97c960544b163c1b146c215669a2c770871ef54b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A1rton=20Kardos?=
Date: Sat, 10 Jun 2023 18:34:32 +0200
Subject: [PATCH] Experimenting with change of variable for minibatch hmc

---
 mess.py                        | 129 ++++++++++++++++++++++-----------
 tweetopic/bayesian/dmm.py      |  12 ++-
 tweetopic/bayesian/sampling.py |   5 +-
 3 files changed, 95 insertions(+), 51 deletions(-)

diff --git a/mess.py b/mess.py
index 7a72520..3d544f4 100644
--- a/mess.py
+++ b/mess.py
@@ -5,6 +5,7 @@
 import random
 from functools import partial
 
+import blackjax
 import jax
 import jax.numpy as jnp
 import numpy as np
@@ -13,65 +14,107 @@
 import scipy.stats as stats
 from sklearn.feature_extraction.text import CountVectorizer
 from sklearn.metrics import r2_score
+from tqdm import trange
 
 from tweetopic._doc import init_doc_words
-from tweetopic.bayesian.dmm import (
-    BayesianDMM,
-    posterior_predictive,
-    predict_doc,
-    sparse_multinomial_logpdf,
-)
-from tweetopic.sampling import (
-    sample_meanfield_vi,
-    sample_nuts,
-    sample_pathfinder,
-    sample_sgld,
-)
-
-texts = [line for line in open("processed_sample.txt")]
+from tweetopic.bayesian.dmm import (BayesianDMM, posterior_predictive,
+                                    predict_doc, sparse_multinomial_logpdf,
+                                    symmetric_dirichlet_logpdf,
+                                    symmetric_dirichlet_multinomial_logpdf)
+from tweetopic.bayesian.sampling import batch_data, sample_nuts
+from tweetopic.func import spread
 
-vectorizer = CountVectorizer(max_features=4000, max_df=0.3, min_df=10)
-X = vectorizer.fit_transform(random.sample(texts, 20_000))
-
-model = BayesianDMM(
-    n_components=5,
-    alpha=1.0,
-    beta=1.0,
-    sampler=partial(sample_sgld, n_samples=2000),
-)
-model.fit(X)
+alpha = 0.2
+n_features = 10
+n_docs = 1000
 
-X = X[X.getnnz(1) > 0]
+doc_lengths = np.random.randint(10, 100, size=n_docs)
+components = stats.dirichlet.rvs(alpha=np.full(n_features, alpha))
+X = np.stack([stats.multinomial.rvs(n, components[0]) for n in doc_lengths])
 X = spr.csr_matrix(X)
+X = X[X.getnnz(1) > 0]
+n_documents, n_features_in_ = X.shape
 max_unique_words = np.max(np.diff(X.indptr))
 doc_unique_words, doc_unique_word_counts = init_doc_words(
     X.tolil(),
     max_unique_words=max_unique_words,
 )
+data = dict(
+    doc_unique_words=doc_unique_words,
+    doc_unique_word_counts=doc_unique_word_counts,
+)
 
-components = np.array([sample["components"] for sample in model.samples])
-weights = np.array([sample["weights"] for sample in model.samples])
-pred = posterior_predictive(
-    doc_unique_words, doc_unique_word_counts, components, weights
-)
+def transform(component):
+    component = jnp.square(component)
+    component = component / jnp.sum(component)
+    return component
+
 
-px.box(pred[4]).show()
+def logprior_fn(params):
+    component = transform(params["component"])
+    return symmetric_dirichlet_logpdf(component, alpha=alpha)
 
-pred[0]
-px.line(weights).show()
+def loglikelihood_fn(params, data):
+    doc_likelihood = jax.vmap(
+        partial(sparse_multinomial_logpdf, component=transform(params["component"]))
+    )
+    return jnp.sum(
+        doc_likelihood(
+            unique_words=data["doc_unique_words"],
+            unique_word_counts=data["doc_unique_word_counts"],
+        )
+    )
 
-X.shape
-predict_doc(
-    unique_words=doc_unique_words[0],
-    unique_word_counts=doc_unique_word_counts[0],
-    components=components[0],
-    weights=weights[0],
+
+# Joint log-density of the model: log prior plus the full-data log likelihood.
+logdensity_fn = lambda params: logprior_fn(params) + loglikelihood_fn(
+    params, data
+)
+grad_estimator = blackjax.sgmcmc.gradients.grad_estimator(
+    logprior_fn, loglikelihood_fn, data_size=n_documents
+)
+rng_key = jax.random.PRNGKey(0)
+batch_key, warmup_key, sampling_key = jax.random.split(rng_key, 3)
+batch_idx = batch_data(batch_key, batch_size=64, data_size=n_documents)
+batches = (
+    dict(
+        doc_unique_words=doc_unique_words[idx],
+        doc_unique_word_counts=doc_unique_word_counts[idx],
+    )
+    for idx in batch_idx
 )
+position = dict(
+    component=jnp.array(
+        transform(stats.dirichlet.mean(alpha=np.full(n_features, alpha)))
+    )
+)
+
+samples, states = sample_nuts(position, logdensity_fn)
+
+
+rng_key = jax.random.PRNGKey(0)
+n_samples = 4000
+sgld = blackjax.sgld(grad_estimator)  # or blackjax.sghmc(grad_estimator, num_integration_steps=10)
+step_fn = jax.jit(sgld)  # compile the transition once instead of re-jitting every step
+step_size = 1e-8
+samples = []
+for i in trange(n_samples, desc="Sampling"):
+    _, rng_key = jax.random.split(rng_key)
+    minibatch = next(batches)
+    position = step_fn(rng_key, position, minibatch, step_size)
+    samples.append(position)
+
+densities = [jax.jit(logdensity_fn)(sample) for sample in samples]
+component_trace = jnp.stack([sample["component"] for sample in samples])
+component_trace = jax.vmap(transform)(component_trace)
+px.line(component_trace).show()
+
+for i, density in enumerate(densities):
+    if np.array(density) != -np.inf:
+        print(f"{i}: {density}")
+
 
-try:
-    predict_one_doc(doc=docs[0], samples=np.array(model.samples[:2]))
-except Exception:
-    print("oopsie")
+px.line(densities).show()
diff --git a/tweetopic/bayesian/dmm.py b/tweetopic/bayesian/dmm.py
index 904e0f0..10cb4a5 100644
--- a/tweetopic/bayesian/dmm.py
+++ b/tweetopic/bayesian/dmm.py
@@ -60,10 +60,14 @@ def sparse_multinomial_logpdf(
 def symmetric_dirichlet_logpdf(x, alpha):
     """Logdensity of a symmetric Dirichlet."""
     K = x.shape[0]
+    sums_to_one = jnp.abs(1 - jnp.sum(x)) <= 0.001
+    all_bigger_than_zero = jnp.all(x >= 0)
     return (
-        jax.lax.lgamma(alpha * K)
+        jnp.log(sums_to_one)
+        + jnp.log(all_bigger_than_zero)
+        + jax.lax.lgamma(alpha * K)
         - K * jax.lax.lgamma(alpha)
-        + (alpha - 1) * jnp.sum(jnp.log(x))
+        + (alpha - 1) * jnp.sum(jnp.nan_to_num(jnp.log(x)))
     )
 
 
@@ -223,10 +227,10 @@ def set_params(self, **params):
         return self
 
     def fit(self, X, y=None):
-        # Filtering out empty documents
-        X = X[X.getnnz(1) > 0]
         # Converting X into sparse array if it isn't one already.
         X = spr.csr_matrix(X)
+        # Filtering out empty documents
+        X = X[X.getnnz(1) > 0]
         self.n_documents, self.n_features_in_ = X.shape
         # Calculating the number of nonzero elements for each row
         # using the internal properties of CSR matrices.
diff --git a/tweetopic/bayesian/sampling.py b/tweetopic/bayesian/sampling.py
index cfcb2d5..126ba84 100644
--- a/tweetopic/bayesian/sampling.py
+++ b/tweetopic/bayesian/sampling.py
@@ -19,11 +19,9 @@
 def sample_nuts(
     initial_position: PyTree,
     logdensity_fn: Callable,
-    data: dict,
     seed: int = 0,
     n_warmup: int = 1000,
     n_samples: int = 1000,
-    data_axis: int = 0,
 ) -> tuple[list[PyTree], list[HMCState]]:
     """NUTS sampling loop for any logdensity function that
     can be JIT compiled with JAX.
@@ -50,7 +48,6 @@
         State of the Hamiltonian Monte Carlo at each step.
         Mostly useful for debugging.
     """
-    logdensity_fn = spread(partial(logdensity_fn, **data))
     rng_key = jax.random.PRNGKey(seed)
     warmup_key, sampling_key = jax.random.split(rng_key)
     print("Warmup, window adaptation")
@@ -173,7 +170,7 @@ def sample_meanfield_vi(
     return samples, states
 
 
-def batch_data(rng_key, data, batch_size: int, data_size: int):
+def batch_data(rng_key, batch_size: int, data_size: int):
     while True:
         _, rng_key = jax.random.split(rng_key)
         idx = jax.random.choice(
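
A minimal end-to-end sketch of how the refactored pieces are intended to be driven once this patch is applied. It only uses calls that appear above (sample_nuts, batch_data, blackjax.sgmcmc.gradients.grad_estimator, blackjax.sgld) and assumes the transform, logprior_fn, loglikelihood_fn, logdensity_fn, position, doc_unique_words, doc_unique_word_counts and n_documents definitions from mess.py; the batch size, step size and number of steps are arbitrary placeholders, not tuned values:

    import blackjax
    import jax
    import jax.numpy as jnp

    from tweetopic.bayesian.sampling import batch_data, sample_nuts

    # Full-data NUTS: the log-density now closes over the data, so sample_nuts
    # no longer takes a data argument.
    nuts_samples, nuts_states = sample_nuts(
        position, logdensity_fn, n_warmup=1000, n_samples=1000
    )

    # Minibatch SGLD on the unconstrained "component" variable.
    grad_estimator = blackjax.sgmcmc.gradients.grad_estimator(
        logprior_fn, loglikelihood_fn, data_size=n_documents
    )
    step_fn = jax.jit(blackjax.sgld(grad_estimator))

    rng_key = jax.random.PRNGKey(0)
    batch_key, sample_key = jax.random.split(rng_key)
    batch_idx = batch_data(batch_key, batch_size=64, data_size=n_documents)

    samples = []
    for _ in range(1000):
        sample_key, step_key = jax.random.split(sample_key)
        idx = next(batch_idx)
        minibatch = dict(
            doc_unique_words=doc_unique_words[idx],
            doc_unique_word_counts=doc_unique_word_counts[idx],
        )
        position = step_fn(step_key, position, minibatch, 1e-8)
        samples.append(position)

    # Map the unconstrained trace back onto the simplex before inspecting it.
    component_trace = jax.vmap(transform)(
        jnp.stack([sample["component"] for sample in samples])
    )

Keeping "component" unconstrained during sampling and squaring-and-renormalising it back onto the simplex only inside the densities and when inspecting the trace is, as far as the script shows, the change of variable the commit title refers to.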