Experimenting with change of variable for minibatch hmc
x-tabdeveloping committed Jun 10, 2023
1 parent 27dd112 commit 97c9605
Showing 3 changed files with 95 additions and 51 deletions.
129 changes: 86 additions & 43 deletions mess.py
@@ -5,6 +5,7 @@
import random
from functools import partial

import blackjax
import jax
import jax.numpy as jnp
import numpy as np
@@ -13,65 +14,107 @@
import scipy.stats as stats
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import r2_score
from tqdm import trange

from tweetopic._doc import init_doc_words
from tweetopic.bayesian.dmm import (
    BayesianDMM,
    posterior_predictive,
    predict_doc,
    sparse_multinomial_logpdf,
)
from tweetopic.sampling import (
    sample_meanfield_vi,
    sample_nuts,
    sample_pathfinder,
    sample_sgld,
)

texts = [line for line in open("processed_sample.txt")]
from tweetopic.bayesian.dmm import (BayesianDMM, posterior_predictive,
                                    predict_doc, sparse_multinomial_logpdf,
                                    symmetric_dirichlet_logpdf,
                                    symmetric_dirichlet_multinomial_logpdf)
from tweetopic.bayesian.sampling import batch_data, sample_nuts
from tweetopic.func import spread

vectorizer = CountVectorizer(max_features=4000, max_df=0.3, min_df=10)
X = vectorizer.fit_transform(random.sample(texts, 20_000))

model = BayesianDMM(
    n_components=5,
    alpha=1.0,
    beta=1.0,
    sampler=partial(sample_sgld, n_samples=2000),
)
model.fit(X)
alpha = 0.2
n_features = 10
n_docs = 1000

X = X[X.getnnz(1) > 0]
doc_lengths = np.random.randint(10, 100, size=n_docs)
components = stats.dirichlet.rvs(alpha=np.full(n_features, alpha))
X = np.stack([stats.multinomial.rvs(n, components[0]) for n in doc_lengths])
X = spr.csr_matrix(X)
X = X[X.getnnz(1) > 0]
n_documents, n_features_in_ = X.shape
max_unique_words = np.max(np.diff(X.indptr))
doc_unique_words, doc_unique_word_counts = init_doc_words(
    X.tolil(),
    max_unique_words=max_unique_words,
)
data = dict(
    doc_unique_words=doc_unique_words,
    doc_unique_word_counts=doc_unique_word_counts,
)

components = np.array([sample["components"] for sample in model.samples])
weights = np.array([sample["weights"] for sample in model.samples])

pred = posterior_predictive(
    doc_unique_words, doc_unique_word_counts, components, weights
)
def transform(component):
    component = jnp.square(component)
    component = component / jnp.sum(component)
    return component
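# The change of variable this commit experiments with: squaring and
# renormalizing maps any real-valued vector onto the probability simplex
# (non-negative entries that sum to one), so the sampler can move in
# unconstrained space while the model only ever sees a valid component.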


px.box(pred[4]).show()
def logprior_fn(params):
    component = transform(params["component"])
    return symmetric_dirichlet_logpdf(component, alpha=alpha)

pred[0]

px.line(weights).show()
def loglikelihood_fn(params, data):
    doc_likelihood = jax.vmap(
        partial(sparse_multinomial_logpdf, component=params["component"])
    )
    return jnp.sum(
        doc_likelihood(
            unique_words=data["doc_unique_words"],
            unique_word_counts=data["doc_unique_word_counts"],
        )
    )
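# Log-likelihood of everything in `data`: sparse_multinomial_logpdf is vmapped
# over documents and the per-document terms are summed. When driven by the
# minibatch machinery below, `data` is a small batch rather than the full corpus.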

X.shape

predict_doc(
    unique_words=doc_unique_words[0],
    unique_word_counts=doc_unique_word_counts[0],
    components=components[0],
    weights=weights[0],
)
logdensity_fn(position)

logdensity_fn = lambda params: logprior_fn(params) + loglikelihood_fn(
    params, data
)
grad_estimator = blackjax.sgmcmc.gradients.grad_estimator(
    logprior_fn, loglikelihood_fn, data_size=n_documents
)
rng_key = jax.random.PRNGKey(0)
batch_key, warmup_key, sampling_key = jax.random.split(rng_key, 3)
batch_idx = batch_data(batch_key, batch_size=64, data_size=n_documents)
batches = (
    dict(
        doc_unique_words=doc_unique_words[idx],
        doc_unique_word_counts=doc_unique_word_counts[idx],
    )
    for idx in batch_idx
)
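# batch_data yields an endless stream of random index sets (see the batch_data
# loop in sampling.py below); this generator turns each index set into a
# minibatch dict with the same fields as `data`.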
position = dict(
    component=jnp.array(
        transform(stats.dirichlet.mean(alpha=np.full(n_features, alpha)))
    )
)

samples, states = sample_nuts(position, logdensity_fn)


rng_key = jax.random.PRNGKey(0)
n_samples = 4000
sghmc = blackjax.sgld(grad_estimator) # , num_integration_steps=10)
states = []
step_size = 1e-8
samples = []
for i in trange(n_samples, desc="Sampling"):
    _, rng_key = jax.random.split(rng_key)
    minibatch = next(batches)
    position = jax.jit(sghmc)(rng_key, position, minibatch, step_size)
    samples.append(position)
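# Positions are collected in the unconstrained parameterization; they are
# mapped back onto the simplex with transform() before plotting below.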

densities = [jax.jit(logdensity_fn)(sample) for sample in samples]
component_trace = jnp.stack([sample["component"] for sample in samples])
component_trace = jax.vmap(transform)(component_trace)
px.line(component_trace).show()

for i, density in enumerate(densities):
    if np.array(density) != -np.inf:
        print(f"{i}: {density}")


try:
    predict_one_doc(doc=docs[0], samples=np.array(model.samples[:2]))
except Exception:
    print("oopsie")
px.line(densities).show()
12 changes: 8 additions & 4 deletions tweetopic/bayesian/dmm.py
@@ -60,10 +60,14 @@ def sparse_multinomial_logpdf(
def symmetric_dirichlet_logpdf(x, alpha):
    """Logdensity of a symmetric Dirichlet."""
    K = x.shape[0]
    sums_to_one = jnp.abs(1 - jnp.sum(x)) <= 0.001
    all_bigger_than_zero = jnp.all(x >= 0)
    return (
        jax.lax.lgamma(alpha * K)
        jnp.log(sums_to_one)
        + jnp.log(all_bigger_than_zero)
        + jax.lax.lgamma(alpha * K)
        - K * jax.lax.lgamma(alpha)
        + (alpha - 1) * jnp.sum(jnp.log(x))
        + (alpha - 1) * jnp.sum(jnp.nan_to_num(jnp.log(x)))
    )


@@ -223,10 +227,10 @@ def set_params(self, **params):
        return self

    def fit(self, X, y=None):
        # Filtering out empty documents
        X = X[X.getnnz(1) > 0]
        # Converting X into sparse array if it isn't one already.
        X = spr.csr_matrix(X)
        # Filtering out empty documents
        X = X[X.getnnz(1) > 0]
        self.n_documents, self.n_features_in_ = X.shape
        # Calculating the number of nonzero elements for each row
        # using the internal properties of CSR matrices.
5 changes: 1 addition & 4 deletions tweetopic/bayesian/sampling.py
@@ -19,11 +19,9 @@
def sample_nuts(
    initial_position: PyTree,
    logdensity_fn: Callable,
    data: dict,
    seed: int = 0,
    n_warmup: int = 1000,
    n_samples: int = 1000,
    data_axis: int = 0,
) -> tuple[list[PyTree], list[HMCState]]:
    """NUTS sampling loop for any logdensity function that can be JIT compiled
    with JAX.
@@ -50,7 +48,6 @@ def sample_nuts(
        State of the Hamiltonian Monte Carlo at each step.
        Mostly useful for debugging.
    """
    logdensity_fn = spread(partial(logdensity_fn, **data))
    rng_key = jax.random.PRNGKey(seed)
    warmup_key, sampling_key = jax.random.split(rng_key)
    print("Warmup, window adaptation")
@@ -173,7 +170,7 @@ def sample_meanfield_vi(
    return samples, states


def batch_data(rng_key, data, batch_size: int, data_size: int):
def batch_data(rng_key, batch_size: int, data_size: int):
    while True:
        _, rng_key = jax.random.split(rng_key)
        idx = jax.random.choice(
