Changes in API, started writing minibatch sampler
x-tabdeveloping committed Jun 8, 2023
1 parent 21316a7 commit 27dd112
Showing 3 changed files with 112 additions and 15 deletions.
11 changes: 8 additions & 3 deletions mess.py
@@ -30,8 +30,8 @@

texts = [line for line in open("processed_sample.txt")]

vectorizer = CountVectorizer(max_features=100, max_df=0.3, min_df=10)
X = vectorizer.fit_transform(random.sample(texts, 10_000))
vectorizer = CountVectorizer(max_features=4000, max_df=0.3, min_df=10)
X = vectorizer.fit_transform(random.sample(texts, 20_000))

model = BayesianDMM(
n_components=5,
@@ -64,7 +64,12 @@

X.shape

predict_doc()
predict_doc(
unique_words=doc_unique_words[0],
unique_word_counts=doc_unique_word_counts[0],
components=components[0],
weights=weights[0],
)
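
Note: `predict_doc` now takes the document's sparse representation and one posterior sample's parameters explicitly, instead of picking them up from enclosing scope. A sketch of how the padded `doc_unique_words` / `doc_unique_word_counts` arrays might be built from a `CountVectorizer` matrix (the helper below is hypothetical, not part of this commit):

```python
import numpy as np
from scipy.sparse import csr_matrix


def doc_term_arrays(X: csr_matrix, max_unique: int) -> tuple[np.ndarray, np.ndarray]:
    """Pad each doc's unique word ids and counts to a fixed width."""
    n_docs = X.shape[0]
    words = np.zeros((n_docs, max_unique), dtype=np.int64)
    counts = np.zeros((n_docs, max_unique), dtype=np.float64)
    for i in range(n_docs):
        row = X.getrow(i)
        n = min(row.indices.shape[0], max_unique)
        words[i, :n] = row.indices[:n]
        counts[i, :n] = row.data[:n]
    return words, counts
```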

try:
predict_one_doc(doc=docs[0], samples=np.array(model.samples[:2]))
48 changes: 36 additions & 12 deletions tweetopic/bayesian/dmm.py
@@ -118,12 +118,10 @@ def posterior_predictive(
return predict_all(doc_unique_words, doc_unique_word_counts)


def dirichlet_multinomial_mixture_logpdf(
def dmm_loglikelihood(
components, weights, doc_unique_words, doc_unique_word_counts, alpha, beta
):
"""Calculates logdensity of the DMM at a given point in parameter space."""
docs = jnp.stack((doc_unique_words, doc_unique_word_counts), axis=1)
n_docs = doc_unique_words.shape[0]

def doc_likelihood(doc):
unique_words, unique_word_counts = doc
@@ -137,7 +135,11 @@ def doc_likelihood(doc):
) + jnp.log(weights)
return jax.scipy.special.logsumexp(component_logprobs)

likelihood = jnp.sum(jax.lax.map(doc_likelihood, docs))
loglikelihood = jnp.sum(jax.lax.map(doc_likelihood, docs))
return loglikelihood


def dmm_logprior(components, weights, alpha, beta, n_docs):
components_prior = jnp.sum(
jax.lax.map(
partial(symmetric_dirichlet_logpdf, alpha=alpha), components
@@ -146,7 +148,24 @@ def doc_likelihood(doc):
weights_prior = symmetric_dirichlet_multinomial_logpdf(
weights, n=jnp.float64(n_docs), alpha=beta
)
return likelihood + components_prior + weights_prior
return components_prior + weights_prior


def dmm_logpdf(
components, weights, doc_unique_words, doc_unique_word_counts, alpha, beta
):
"""Calculates logdensity of the DMM at a given point in parameter space."""
n_docs = doc_unique_words.shape[0]
loglikelihood = dmm_loglikelihood(
components,
weights,
doc_unique_words,
doc_unique_word_counts,
alpha,
beta,
)
logprior = dmm_logprior(components, weights, alpha, beta, n_docs)
return logprior + loglikelihood
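
The density now factors into `dmm_loglikelihood` plus `dmm_logprior`. A minimal smoke test of the factored density on toy shapes, assuming the module's helper densities (`sparse_multinomial_logpdf`, `symmetric_dirichlet_logpdf`) behave as their names suggest:

```python
import jax.numpy as jnp

# 2 components over a 4-word vocabulary; 3 docs with up to 2 unique words.
components = jnp.full((2, 4), 0.25)
weights = jnp.array([0.5, 0.5])
doc_unique_words = jnp.array([[0, 1], [2, 3], [1, 2]])
doc_unique_word_counts = jnp.array([[2.0, 1.0], [1.0, 0.0], [3.0, 2.0]])

logdensity = dmm_logpdf(
    components,
    weights,
    doc_unique_words,
    doc_unique_word_counts,
    alpha=1.0,
    beta=1.0,
)
```

The split is exactly what a minibatch sampler needs: the likelihood term can be estimated on a random batch of documents while the prior is evaluated once per step.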


class BayesianDMM(sklearn.base.TransformerMixin, sklearn.base.BaseEstimator):
@@ -223,16 +242,21 @@ def fit(self, X, y=None):
alpha=self.alpha,
beta=self.beta,
)
logdensity_fn = spread(
partial(
dirichlet_multinomial_mixture_logpdf,
logdensity_fn = partial(
dmm_logpdf,
doc_unique_words=doc_unique_words,
doc_unique_word_counts=doc_unique_word_counts,
alpha=self.alpha,
beta=self.beta,
)
samples = self.sampler(
initial_position,
logdensity_fn,
data=dict(
doc_unique_words=doc_unique_words,
doc_unique_word_counts=doc_unique_word_counts,
alpha=self.alpha,
beta=self.beta,
)
),
)
samples = self.sampler(initial_position, logdensity_fn)
self.samples = samples
return self
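
The sampler contract changes accordingly: `fit` now hands the sampler the unbound `dmm_logpdf` together with a `data` dict, and each sampling loop binds the data itself. A hedged usage sketch (a `sampler` constructor argument is implied by `self.sampler` above; the other constructor arguments are assumptions):

```python
from functools import partial

model = BayesianDMM(
    n_components=5,
    sampler=partial(sample_nuts, n_warmup=500, n_samples=500),
)
model.fit(X)  # X: document-term matrix from CountVectorizer
samples = model.samples
```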

68 changes: 68 additions & 0 deletions tweetopic/bayesian/sampling.py
@@ -1,23 +1,29 @@
"""Sampling utilities using jax and blackjax."""
from functools import partial
from typing import Any, Callable

import blackjax
import jax
import jax.numpy as jnp
from blackjax.mcmc.hmc import HMCState
from blackjax.types import PyTree
from blackjax.vi.meanfield_vi import MFVIState
from optax import adam
from tqdm import trange

from tweetopic.func import spread

Sampler = Callable[..., tuple[list[PyTree], Any]]
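
Every loop below shares this `Sampler` signature and binds its data with `spread(partial(logdensity_fn, **data))`. `spread` lives in `tweetopic.func` and is not shown in this diff; judging from how it is used, it presumably lifts a keyword-argument density into one over a single dict-valued position, roughly:

```python
# Assumed behavior of tweetopic.func.spread (a sketch, not the actual source):
def spread(f):
    def wrapped(position: dict):
        return f(**position)

    return wrapped
```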


def sample_nuts(
initial_position: PyTree,
logdensity_fn: Callable,
data: dict,
seed: int = 0,
n_warmup: int = 1000,
n_samples: int = 1000,
data_axis: int = 0,
) -> tuple[list[PyTree], list[HMCState]]:
"""NUTS sampling loop for any logdensity function that can be JIT compiled
with JAX.
@@ -44,6 +50,7 @@ def sample_nuts(
State of the Hamiltonian Monte Carlo at each step.
Mostly useful for debugging.
"""
logdensity_fn = spread(partial(logdensity_fn, **data))
rng_key = jax.random.PRNGKey(seed)
warmup_key, sampling_key = jax.random.split(rng_key)
print("Warmup, window adaptation")
@@ -64,10 +71,12 @@
def sample_sgld(
initial_position: dict,
logdensity_fn: Callable,
data: dict,
seed: int = 0,
n_samples: int = 1000,
initial_step_size: float = 1000,
decay: float = 2.5,
data_axis: int = 0,
) -> tuple[list[PyTree], None]:
"""Stochastic Gradient Langevin Dynamics sampling loop with decaying step size
for any logdensity function that is differentiable with JAX.
@@ -98,6 +107,7 @@ def sample_sgld(
states: None
Ignored, only returned for compatibility.
"""
logdensity_fn = spread(partial(logdensity_fn, **data))
rng_key = jax.random.PRNGKey(seed)
num_training_steps = n_samples
schedule_fn = lambda k: initial_step_size * (k ** (-decay))
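
The schedule shrinks the step size polynomially; with the defaults it decays quickly. Note that `k` is assumed to start at 1, since `k = 0` would divide by zero:

```python
import numpy as np

k = np.arange(1, 6)
eps = 1000 * k ** (-2.5)  # initial_step_size * k**(-decay)
# eps ≈ [1000.0, 176.8, 64.2, 31.25, 17.9]
```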
@@ -116,9 +126,12 @@
def sample_pathfinder(
initial_position: dict,
logdensity_fn: Callable,
data: dict,
seed: int = 0,
n_samples: int = 1000,
data_axis: int = 0,
) -> list[PyTree]:
logdensity_fn = spread(partial(logdensity_fn, **data))
rng_key = jax.random.PRNGKey(seed)
optim_key, sampling_key = jax.random.split(rng_key)
pathfinder = blackjax.pathfinder(logdensity_fn)
@@ -134,12 +147,15 @@
def sample_meanfield_vi(
initial_position: dict,
logdensity_fn: Callable,
data: dict,
seed: int = 0,
n_iter: int = 20_000,
n_samples: int = 1000,
n_optim_samples: int = 20,
learning_rate: float = 0.08,
data_axis: int = 0,
) -> tuple[list[PyTree], list[MFVIState]]:
logdensity_fn = spread(partial(logdensity_fn, **data))
rng_key = jax.random.PRNGKey(seed)
optim_key, sampling_key = jax.random.split(rng_key)
optimizer = adam(learning_rate)
@@ -155,3 +171,55 @@
states.append(state)
samples = mfvi.sample(sampling_key, state, num_samples=n_samples)
return samples, states
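
Because all four loops now share the `Sampler` interface, they can be swapped without touching `fit`. A hypothetical direct call (the initial positions would normally come from the model's initializer, which this diff does not show):

```python
samples, states = sample_meanfield_vi(
    initial_position=dict(components=components_init, weights=weights_init),
    logdensity_fn=dmm_logpdf,
    data=dict(
        doc_unique_words=doc_unique_words,
        doc_unique_word_counts=doc_unique_word_counts,
        alpha=1.0,
        beta=1.0,
    ),
)
```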


def batch_data(rng_key, data, batch_size: int, data_size: int):
while True:
_, rng_key = jax.random.split(rng_key)
idx = jax.random.choice(
key=rng_key, a=jnp.arange(data_size), shape=(batch_size,)
)
yield idx


def get_batch(idx, data: dict, data_axis: int):
return {
key: jnp.take(value, idx, axis=data_axis)
for key, value in data.items()
}
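
`batch_data` is an infinite generator of index sets sampled with replacement (its `data` argument is currently unused), and `get_batch` gathers the indexed rows from every array in the data dict. A toy illustration:

```python
import jax
import jax.numpy as jnp

data = dict(x=jnp.arange(10.0), y=2.0 * jnp.arange(10.0))
batches = batch_data(jax.random.PRNGKey(0), data, batch_size=4, data_size=10)

idx = next(batches)                        # shape (4,) index array
batch = get_batch(idx, data, data_axis=0)  # dict of 4-element arrays
```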


# TODO
def sample_minibatch_hmc(
initial_position: PyTree,
logdensity_fn: Callable,
data: dict,
seed: int = 0,
batch_size: int = 512,
step_size: float = 0.001,
n_warmup: int = 100,
n_samples: int = 1000,
data_axis: int = 0,
) -> tuple[list[PyTree], list[HMCState]]:
rng_key = jax.random.PRNGKey(seed)
batch_key, warmup_key, sampling_key = jax.random.split(rng_key, 3)
    # Infer the corpus size from any array in the data dict and pass the dict
    # itself through, matching batch_data's (rng_key, data, ...) signature.
    data_size = len(next(iter(data.values())))
    batches = batch_data(batch_key, data, batch_size, data_size=data_size)
print("Warmup, window adaptation")
warmup_batch = get_batch(next(batches), data, data_axis=data_axis)
    # Bind the warmup batch and spread the position dict, as the other
    # sampling loops do.
    warmup = blackjax.window_adaptation(
        blackjax.hmc, spread(partial(logdensity_fn, **warmup_batch))
    )
(state, parameters), _ = warmup.run(
warmup_key, initial_position, num_steps=n_warmup
)
    # TODO: switch to a proper SGHMC kernel; blackjax.sghmc() cannot be
    # constructed without a minibatch gradient estimator. Until that is wired
    # up, run a NUTS kernel on a fresh minibatch each step (rebinding the
    # density rebuilds the kernel every iteration, so this sketch is slow).
    states = []
    for _ in trange(n_samples, desc="Sampling"):
        _, sampling_key = jax.random.split(sampling_key)
        batch = get_batch(next(batches), data, data_axis=data_axis)
        kernel = blackjax.nuts(
            spread(partial(logdensity_fn, **batch)), **parameters
        ).step
        state, _ = kernel(sampling_key, state)
        states.append(state)
samples = [state.position for state in states]
return samples, states
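
The `dmm_loglikelihood` / `dmm_logprior` split in dmm.py matches the shape blackjax expects for stochastic-gradient kernels, so the TODO above is presumably headed somewhere like the sketch below. Treat every name and signature here as an assumption (recalled from blackjax around 0.9, whose estimator vmaps `loglikelihood_fn` over single data points) and check it against the installed version; `data_size` stands for the number of documents:

```python
import blackjax

grad_estimator = blackjax.sgmcmc.gradients.grad_estimator(
    logprior_fn=lambda position: dmm_logprior(
        position["components"], position["weights"],
        alpha=1.0, beta=1.0, n_docs=data_size,
    ),
    # Called with one document at a time, then vmapped over the minibatch.
    loglikelihood_fn=lambda position, doc: dmm_loglikelihood(
        position["components"], position["weights"],
        doc["doc_unique_words"][None, :], doc["doc_unique_word_counts"][None, :],
        alpha=1.0, beta=1.0,
    ),
    data_size=data_size,
)
sghmc = blackjax.sghmc(grad_estimator)
```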
