
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Oct 2, 2023
1 parent 327a529 commit 485a9d2
Showing 8 changed files with 126 additions and 61 deletions.
27 changes: 17 additions & 10 deletions mess.py
@@ -17,10 +17,14 @@
 from tqdm import trange

 from tweetopic._doc import init_doc_words
-from tweetopic.bayesian.dmm import (BayesianDMM, posterior_predictive,
-                                    predict_doc, sparse_multinomial_logpdf,
-                                    symmetric_dirichlet_logpdf,
-                                    symmetric_dirichlet_multinomial_logpdf)
+from tweetopic.bayesian.dmm import (
+    BayesianDMM,
+    posterior_predictive,
+    predict_doc,
+    sparse_multinomial_logpdf,
+    symmetric_dirichlet_logpdf,
+    symmetric_dirichlet_multinomial_logpdf,
+)
 from tweetopic.bayesian.sampling import batch_data, sample_nuts
 from tweetopic.func import spread

@@ -58,23 +62,26 @@ def logprior_fn(params):

 def loglikelihood_fn(params, data):
     doc_likelihood = jax.vmap(
-        partial(sparse_multinomial_logpdf, component=params["component"])
+        partial(sparse_multinomial_logpdf, component=params["component"]),
     )
     return jnp.sum(
         doc_likelihood(
             unique_words=data["doc_unique_words"],
             unique_word_counts=data["doc_unique_word_counts"],
-        )
+        ),
     )


 logdensity_fn(position)

 logdensity_fn = lambda params: logprior_fn(params) + loglikelihood_fn(
-    params, data
+    params,
+    data,
 )
 grad_estimator = blackjax.sgmcmc.gradients.grad_estimator(
-    logprior_fn, loglikelihood_fn, data_size=n_documents
+    logprior_fn,
+    loglikelihood_fn,
+    data_size=n_documents,
 )
 rng_key = jax.random.PRNGKey(0)
 batch_key, warmup_key, sampling_key = jax.random.split(rng_key, 3)
@@ -88,8 +95,8 @@ def loglikelihood_fn(params, data):
 )
 position = dict(
     component=jnp.array(
-        transform(stats.dirichlet.mean(alpha=np.full(n_features, alpha)))
-    )
+        transform(stats.dirichlet.mean(alpha=np.full(n_features, alpha))),
+    ),
 )

 samples, states = sample_nuts(position, logdensity_fn)
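
A note on the pattern this file exercises (a minimal, self-contained sketch, not code from the repository; toy_doc_logpdf is a hypothetical stand-in for sparse_multinomial_logpdf): loglikelihood_fn pins the shared parameters with functools.partial and lets jax.vmap batch the per-document density over the leading document axis.

    from functools import partial

    import jax
    import jax.numpy as jnp

    def toy_doc_logpdf(unique_words, unique_word_counts, component):
        # Count-weighted log-probability of one document's words
        # under a single component distribution.
        return jnp.sum(unique_word_counts * jnp.log(component[unique_words]))

    component = jnp.array([0.2, 0.3, 0.5])
    doc_unique_words = jnp.array([[0, 2], [1, 2]])  # (n_docs, max_unique_words)
    doc_unique_word_counts = jnp.array([[1.0, 3.0], [2.0, 2.0]])

    # Same shape as loglikelihood_fn above: partial fixes the parameters,
    # vmap maps the remaining array arguments over axis 0 (documents).
    doc_likelihood = jax.vmap(partial(toy_doc_logpdf, component=component))
    total_loglikelihood = jnp.sum(
        doc_likelihood(doc_unique_words, doc_unique_word_counts),
    )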
43 changes: 32 additions & 11 deletions tweetopic/_btm.py
@@ -1,4 +1,4 @@
-"""Module for utility functions for fitting BTMs"""
+"""Module for utility functions for fitting BTMs."""

 import random
 from typing import Dict, Tuple, TypeVar
@@ -12,7 +12,8 @@

 @njit
 def doc_unique_biterms(
-    doc_unique_words: np.ndarray, doc_unique_word_counts: np.ndarray
+    doc_unique_words: np.ndarray,
+    doc_unique_word_counts: np.ndarray,
 ) -> Dict[Tuple[int, int], int]:
     (n_max_unique_words,) = doc_unique_words.shape
     biterm_counts = dict()
@@ -43,7 +44,7 @@ def doc_unique_biterms(

 @njit
 def nb_add_counter(dest: Dict[T, int], source: Dict[T, int]):
-    """Adds one counter dict to another in place with Numba"""
+    """Adds one counter dict to another in place with Numba."""
     for key in source:
         if key in dest:
             dest[key] += source[key]
@@ -53,25 +54,28 @@ def nb_add_counter(dest: Dict[T, int], source: Dict[T, int]):

 @njit
 def corpus_unique_biterms(
-    doc_unique_words: np.ndarray, doc_unique_word_counts: np.ndarray
+    doc_unique_words: np.ndarray,
+    doc_unique_word_counts: np.ndarray,
 ) -> Dict[Tuple[int, int], int]:
     n_documents, _ = doc_unique_words.shape
     biterm_counts = doc_unique_biterms(
-        doc_unique_words[0], doc_unique_word_counts[0]
+        doc_unique_words[0],
+        doc_unique_word_counts[0],
     )
     for i_doc in range(1, n_documents):
         doc_unique_words_i = doc_unique_words[i_doc]
         doc_unique_word_counts_i = doc_unique_word_counts[i_doc]
         doc_biterms = doc_unique_biterms(
-            doc_unique_words_i, doc_unique_word_counts_i
+            doc_unique_words_i,
+            doc_unique_word_counts_i,
         )
         nb_add_counter(biterm_counts, doc_biterms)
     return biterm_counts


 @njit
 def compute_biterm_set(
-    biterm_counts: Dict[Tuple[int, int], int]
+    biterm_counts: Dict[Tuple[int, int], int],
 ) -> np.ndarray:
     return np.array(list(biterm_counts.keys()))

@@ -116,7 +120,12 @@ def add_biterm(
     topic_biterm_count: np.ndarray,
 ) -> None:
     add_remove_biterm(
-        True, i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+        True,
+        i_biterm,
+        i_topic,
+        biterms,
+        topic_word_count,
+        topic_biterm_count,
     )


@@ -129,7 +138,12 @@ def remove_biterm(
     topic_biterm_count: np.ndarray,
 ) -> None:
     add_remove_biterm(
-        False, i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+        False,
+        i_biterm,
+        i_topic,
+        biterms,
+        topic_word_count,
+        topic_biterm_count,
     )


@@ -147,7 +161,11 @@ def init_components(
         i_topic = random.randint(0, n_components - 1)
         biterm_topic_assignments[i_biterm] = i_topic
         add_biterm(
-            i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+            i_biterm,
+            i_topic,
+            biterms,
+            topic_word_count,
+            topic_biterm_count,
         )
     return biterm_topic_assignments, topic_word_count, topic_biterm_count

@@ -448,7 +466,10 @@ def predict_docs(
         )
         biterms = doc_unique_biterms(words, word_counts)
         prob_topic_given_document(
-            pred, biterms, topic_distribution, topic_word_distribution
+            pred,
+            biterms,
+            topic_distribution,
+            topic_word_distribution,
         )
         predictions[i_doc, :] = pred
     return predictions
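
For orientation, since the function bodies are collapsed in this view: in a biterm topic model, a biterm is an unordered pair of words co-occurring in one document, and doc_unique_biterms returns a dict mapping each pair to its count. A plain-Python sketch of the assumed semantics (doc_biterms_sketch is hypothetical, not the Numba implementation):

    def doc_biterms_sketch(unique_words, unique_word_counts):
        # unique_word_counts[i] is the in-document frequency of unique_words[i].
        biterms = {}
        n = len(unique_words)
        for i in range(n):
            for j in range(i, n):
                if i == j:
                    # pairs drawn within one word type: C(count, 2)
                    k = unique_word_counts[i] * (unique_word_counts[i] - 1) // 2
                else:
                    k = unique_word_counts[i] * unique_word_counts[j]
                if k > 0:
                    pair = tuple(sorted((unique_words[i], unique_words[j])))
                    biterms[pair] = biterms.get(pair, 0) + k
        return biterms

    doc_biterms_sketch([3, 7, 9], [2, 1, 1])
    # -> {(3, 3): 1, (3, 7): 2, (3, 9): 2, (7, 9): 1}

corpus_unique_biterms then merges these per-document dicts with nb_add_counter, as the visible loop shows.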
3 changes: 2 additions & 1 deletion tweetopic/_dmm.py
@@ -1,4 +1,5 @@
-"""Module containing tools for fitting a Dirichlet Multinomial Mixture Model."""
+"""Module containing tools for fitting a Dirichlet Multinomial Mixture
+Model."""
 from __future__ import annotations

 from math import exp, log
2 changes: 1 addition & 1 deletion tweetopic/_doc.py
@@ -11,7 +11,7 @@ def init_doc_words(
     n_docs, _ = doc_term_matrix.shape
     doc_unique_words = np.zeros((n_docs, max_unique_words)).astype(np.uint32)
     doc_unique_word_counts = np.zeros((n_docs, max_unique_words)).astype(
-        np.uint32
+        np.uint32,
     )
     for i_doc in range(n_docs):
         unique_words = doc_term_matrix[i_doc].rows[0]  # type: ignore
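
The .rows[0] access (and the # type: ignore) suggest doc_term_matrix is a scipy.sparse.lil_matrix, whose rows and data attributes hold, for each row, the column indices and values of its nonzero entries. A small illustration of that layout (values are made up):

    import numpy as np
    from scipy.sparse import lil_matrix

    dtm = lil_matrix(np.array([[0, 2, 0, 1], [3, 0, 0, 0]]))
    dtm[0].rows[0]  # [1, 3] -- ids of the words present in document 0
    dtm[0].data[0]  # [2, 1] -- their counts, zero-padded into the uint32 arrays above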
68 changes: 46 additions & 22 deletions tweetopic/bayesian/dmm.py
@@ -1,5 +1,5 @@
-"""JAX implementation of probability densities and parameter initialization
-for the Dirichlet Multinomial Mixture Model."""
+"""JAX implementation of probability densities and parameter initialization for
+the Dirichlet Multinomial Mixture Model."""
 from functools import partial

 import jax
@@ -22,12 +22,18 @@ def symmetric_dirichlet_multinomial_mean(alpha: float, n: int, K: int):


 def init_parameters(
-    n_docs: int, n_vocab: int, n_components: int, alpha: float, beta: float
+    n_docs: int,
+    n_vocab: int,
+    n_components: int,
+    alpha: float,
+    beta: float,
 ) -> dict:
     """Initializes the parameters of the dmm to the mean of the prior."""
     return dict(
         weights=symmetric_dirichlet_multinomial_mean(
-            alpha, n_docs, n_components
+            alpha,
+            n_docs,
+            n_components,
         ),
         components=np.broadcast_to(
             scipy.stats.dirichlet.mean(np.full(n_vocab, beta)),
@@ -41,13 +47,15 @@ def sparse_multinomial_logpdf(
     unique_words,
     unique_word_counts,
 ):
-    """Calculates joint multinomial probability of a sparse representation"""
+    """Calculates joint multinomial probability of a sparse representation."""
     unique_word_counts = jnp.float64(unique_word_counts)
     n_words = jnp.sum(unique_word_counts)
     n_factorial = jax.lax.lgamma(n_words + 1)
     word_count_factorial = jax.lax.lgamma(unique_word_counts + 1)
     word_count_factorial = jnp.where(
-        unique_word_counts != 0, word_count_factorial, 0
+        unique_word_counts != 0,
+        word_count_factorial,
+        0,
     )
     denominator = jnp.sum(word_count_factorial)
     probs = component[unique_words]
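
For reference, the visible terms assemble the log of the multinomial pmf over the words actually present in a document (the tail of the function is collapsed here, but it presumably adds the count-weighted log-probabilities from probs):

    log P(x | p) = lgamma(n + 1) - sum_i lgamma(c_i + 1) + sum_i c_i * log p_i

where the c_i are the unique-word counts, n = sum_i c_i, and p_i = component[unique_words][i]. The jnp.where guard zeroes the lgamma terms of the zero-padded entries so padding cannot contribute to the denominator.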
@@ -84,18 +92,18 @@ def symmetric_dirichlet_multinomial_logpdf(x, n, alpha):


 def predict_doc(components, weights, unique_words, unique_word_counts):
-    """Predicts likelihood of a document belonging to
-    each cluster based on given parameters."""
+    """Predicts likelihood of a document belonging to each cluster based on
+    given parameters."""
     component_logpdf = partial(
         sparse_multinomial_logpdf,
         unique_words=unique_words,
         unique_word_counts=unique_word_counts,
     )
     component_logprobs = jax.lax.map(component_logpdf, components) + jnp.log(
-        weights
+        weights,
     )
     norm_probs = jnp.exp(
-        component_logprobs - jax.scipy.special.logsumexp(component_logprobs)
+        component_logprobs - jax.scipy.special.logsumexp(component_logprobs),
     )
     return norm_probs

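The normalization at the end of predict_doc (exponentiating log-probabilities shifted by their logsumexp) is the standard numerically stable softmax over the component axis; a quick sanity check of the identity:

    import jax.numpy as jnp
    from jax.nn import softmax
    from jax.scipy.special import logsumexp

    logp = jnp.array([-3.0, -1.0, -2.0])
    manual = jnp.exp(logp - logsumexp(logp))  # as in predict_doc
    assert jnp.allclose(manual, softmax(logp))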
@@ -106,24 +114,31 @@ def predict_one(unique_words, unique_word_counts, components, weights):
         predict_doc,
         unique_words=unique_words,
         unique_word_counts=unique_word_counts,
-        )
+        ),
     )(components, weights)


 def posterior_predictive(
-    doc_unique_words, doc_unique_word_counts, components, weights
+    doc_unique_words,
+    doc_unique_word_counts,
+    components,
+    weights,
 ):
-    """Predicts probability of a document belonging to each component
-    for all posterior samples.
-    """
+    """Predicts probability of a document belonging to each component for all
+    posterior samples."""
     predict_all = jax.vmap(
-        partial(predict_one, components=components, weights=weights)
+        partial(predict_one, components=components, weights=weights),
     )
     return predict_all(doc_unique_words, doc_unique_word_counts)


 def dmm_loglikelihood(
-    components, weights, doc_unique_words, doc_unique_word_counts, alpha, beta
+    components,
+    weights,
+    doc_unique_words,
+    doc_unique_word_counts,
+    alpha,
+    beta,
 ):
     docs = jnp.stack((doc_unique_words, doc_unique_word_counts), axis=1)

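Reading the two vmaps together (shapes inferred from the nesting; parts of the surrounding code are collapsed here): predict_one maps predict_doc over the leading sample axis of components and weights, while posterior_predictive maps predict_one over the leading document axis of the word arrays, so with S posterior samples, D documents, and K components the result holds per-document, per-sample component probabilities of shape (D, S, K).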
@@ -135,7 +150,8 @@ def doc_likelihood(doc):
             unique_word_counts=unique_word_counts,
         )
         component_logprobs = jax.lax.map(
-            component_logpdf, components
+            component_logpdf,
+            components,
         ) + jnp.log(weights)
         return jax.scipy.special.logsumexp(component_logprobs)

@@ -146,17 +162,25 @@ def doc_likelihood(doc):
 def dmm_logprior(components, weights, alpha, beta, n_docs):
     components_prior = jnp.sum(
         jax.lax.map(
-            partial(symmetric_dirichlet_logpdf, alpha=alpha), components
-        )
+            partial(symmetric_dirichlet_logpdf, alpha=alpha),
+            components,
+        ),
     )
     weights_prior = symmetric_dirichlet_multinomial_logpdf(
-        weights, n=jnp.float64(n_docs), alpha=beta
+        weights,
+        n=jnp.float64(n_docs),
+        alpha=beta,
     )
     return components_prior + weights_prior


 def dmm_logpdf(
-    components, weights, doc_unique_words, doc_unique_word_counts, alpha, beta
+    components,
+    weights,
+    doc_unique_words,
+    doc_unique_word_counts,
+    alpha,
+    beta,
 ):
     """Calculates logdensity of the DMM at a given point in parameter space."""
     n_docs = doc_unique_words.shape[0]
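
Taken together, and matching the docstring of dmm_logpdf (its body is collapsed here), the pieces above compose the unnormalized posterior log-density as prior plus likelihood:

    log p(components, weights | docs) = sum_d logsumexp_k [ log w_k + log P(doc_d | component_k) ]
                                        + sum_k log Dir(component_k; alpha)
                                        + log DirMult(weights; n_docs, beta)

with alpha entering through the symmetric Dirichlet prior on each component and beta through the Dirichlet-multinomial prior on the weights, as in the dmm_logprior calls visible above.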

