Commit
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jul 1, 2024
1 parent 2c2e92c commit c6cf4d4
Showing 6 changed files with 67 additions and 32 deletions.
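
Every hunk in this commit is one of two mechanical rewrites: multi-line calls and signatures gain a trailing comma and are split to one argument per line, and docstrings are rewrapped to fit the line length and end in a period. This matches hooks such as add-trailing-comma and docformatter; the repository's .pre-commit-config.yaml is not shown here, so the exact hook list is an assumption. A minimal sketch of the argument style the fixes converge on:

```python
# Hypothetical example, not taken from the tweetopic codebase.

# Before: arguments packed on one line, no trailing comma. Adding an
# argument later rewrites the whole line and produces a noisy diff.
params_before = dict(component=1, weights=2)

# After: the "magic" trailing comma pins each argument to its own line,
# so a future change touches exactly one line of the file.
params_after = dict(
    component=1,
    weights=2,
)
```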
27 changes: 17 additions & 10 deletions mess.py
@@ -17,10 +17,14 @@
 from tqdm import trange
 
 from tweetopic._doc import init_doc_words
-from tweetopic.bayesian.dmm import (BayesianDMM, posterior_predictive,
-                                    predict_doc, sparse_multinomial_logpdf,
-                                    symmetric_dirichlet_logpdf,
-                                    symmetric_dirichlet_multinomial_logpdf)
+from tweetopic.bayesian.dmm import (
+    BayesianDMM,
+    posterior_predictive,
+    predict_doc,
+    sparse_multinomial_logpdf,
+    symmetric_dirichlet_logpdf,
+    symmetric_dirichlet_multinomial_logpdf,
+)
 from tweetopic.bayesian.sampling import batch_data, sample_nuts
 from tweetopic.func import spread
 
@@ -58,23 +62,26 @@ def logprior_fn(params):
 
 def loglikelihood_fn(params, data):
     doc_likelihood = jax.vmap(
-        partial(sparse_multinomial_logpdf, component=params["component"])
+        partial(sparse_multinomial_logpdf, component=params["component"]),
     )
     return jnp.sum(
         doc_likelihood(
             unique_words=data["doc_unique_words"],
             unique_word_counts=data["doc_unique_word_counts"],
-        )
+        ),
     )
 
 
 logdensity_fn(position)
 
 logdensity_fn = lambda params: logprior_fn(params) + loglikelihood_fn(
-    params, data
+    params,
+    data,
 )
 grad_estimator = blackjax.sgmcmc.gradients.grad_estimator(
-    logprior_fn, loglikelihood_fn, data_size=n_documents
+    logprior_fn,
+    loglikelihood_fn,
+    data_size=n_documents,
 )
 rng_key = jax.random.PRNGKey(0)
 batch_key, warmup_key, sampling_key = jax.random.split(rng_key, 3)
@@ -88,8 +95,8 @@ def loglikelihood_fn(params, data):
 )
 position = dict(
     component=jnp.array(
-        transform(stats.dirichlet.mean(alpha=np.full(n_features, alpha)))
-    )
+        transform(stats.dirichlet.mean(alpha=np.full(n_features, alpha))),
+    ),
 )
 
 samples, states = sample_nuts(position, logdensity_fn)
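
The script follows the standard Bayesian decomposition: the unnormalized log density handed to the NUTS sampler is a log prior over the parameters plus a log likelihood summed over documents, and the gradient estimator only needs data_size so that minibatch gradients can be rescaled to the full corpus. A self-contained sketch of the same composition, with toy Gaussian densities standing in for tweetopic's actual ones:

```python
import jax
import jax.numpy as jnp

# Toy stand-ins for tweetopic's densities; for illustration only.
def logprior_fn(params):
    # Standard normal prior on the parameter.
    return -0.5 * jnp.sum(params["component"] ** 2)

def loglikelihood_fn(params, data):
    # Per-observation log likelihood, vectorized with vmap the same way
    # the diff vectorizes sparse_multinomial_logpdf over documents.
    per_obs = jax.vmap(lambda x: -0.5 * (x - params["component"]) ** 2)
    return jnp.sum(per_obs(data))

data = jnp.array([0.5, -1.0, 2.0])
logdensity_fn = lambda params: logprior_fn(params) + loglikelihood_fn(params, data)

position = {"component": jnp.array(0.0)}
# A NUTS implementation consumes exactly these two quantities.
value, grad = jax.value_and_grad(logdensity_fn)(position)
```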
43 changes: 32 additions & 11 deletions tweetopic/_btm.py
@@ -1,4 +1,4 @@
"""Module for utility functions for fitting BTMs"""
"""Module for utility functions for fitting BTMs."""

import random
from typing import Dict, Tuple, TypeVar
@@ -12,7 +12,8 @@
 
 @njit
 def doc_unique_biterms(
-    doc_unique_words: np.ndarray, doc_unique_word_counts: np.ndarray
+    doc_unique_words: np.ndarray,
+    doc_unique_word_counts: np.ndarray,
 ) -> Dict[Tuple[int, int], int]:
     (n_max_unique_words,) = doc_unique_words.shape
     biterm_counts = dict()
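
For orientation: a biterm is an unordered pair of distinct words that co-occur in the same document, and BTM models the corpus as a bag of such pairs. A rough pure-Python equivalent of what doc_unique_biterms computes, as a sketch (the Numba version works on padded id arrays and also counts pairs formed by repeated words):

```python
from itertools import combinations

def doc_biterms_sketch(words: list) -> dict:
    """Count unordered word pairs co-occurring in one document."""
    counts = {}
    for pair in combinations(sorted(set(words)), 2):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

doc_biterms_sketch(["cat", "sat", "mat"])
# -> {("cat", "mat"): 1, ("cat", "sat"): 1, ("mat", "sat"): 1}
```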
@@ -43,7 +44,7 @@ def doc_unique_biterms
 
 @njit
 def nb_add_counter(dest: Dict[T, int], source: Dict[T, int]):
-    """Adds one counter dict to another in place with Numba"""
+    """Adds one counter dict to another in place with Numba."""
     for key in source:
         if key in dest:
             dest[key] += source[key]
@@ -53,25 +54,28 @@ def nb_add_counter(dest: Dict[T, int], source: Dict[T, int]):
 
 @njit
 def corpus_unique_biterms(
-    doc_unique_words: np.ndarray, doc_unique_word_counts: np.ndarray
+    doc_unique_words: np.ndarray,
+    doc_unique_word_counts: np.ndarray,
 ) -> Dict[Tuple[int, int], int]:
     n_documents, _ = doc_unique_words.shape
     biterm_counts = doc_unique_biterms(
-        doc_unique_words[0], doc_unique_word_counts[0]
+        doc_unique_words[0],
+        doc_unique_word_counts[0],
     )
     for i_doc in range(1, n_documents):
         doc_unique_words_i = doc_unique_words[i_doc]
         doc_unique_word_counts_i = doc_unique_word_counts[i_doc]
         doc_biterms = doc_unique_biterms(
-            doc_unique_words_i, doc_unique_word_counts_i
+            doc_unique_words_i,
+            doc_unique_word_counts_i,
         )
         nb_add_counter(biterm_counts, doc_biterms)
     return biterm_counts
 
 
 @njit
 def compute_biterm_set(
-    biterm_counts: Dict[Tuple[int, int], int]
+    biterm_counts: Dict[Tuple[int, int], int],
 ) -> np.ndarray:
     return np.array(list(biterm_counts.keys()))
 
@@ -116,7 +120,12 @@ def add_biterm(
     topic_biterm_count: np.ndarray,
 ) -> None:
     add_remove_biterm(
-        True, i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+        True,
+        i_biterm,
+        i_topic,
+        biterms,
+        topic_word_count,
+        topic_biterm_count,
     )


@@ -129,7 +138,12 @@ def remove_biterm(
     topic_biterm_count: np.ndarray,
 ) -> None:
     add_remove_biterm(
-        False, i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+        False,
+        i_biterm,
+        i_topic,
+        biterms,
+        topic_word_count,
+        topic_biterm_count,
     )


@@ -147,7 +161,11 @@ def init_components(
         i_topic = random.randint(0, n_components - 1)
         biterm_topic_assignments[i_biterm] = i_topic
         add_biterm(
-            i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+            i_biterm,
+            i_topic,
+            biterms,
+            topic_word_count,
+            topic_biterm_count,
         )
     return biterm_topic_assignments, topic_word_count, topic_biterm_count
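
add_biterm and remove_biterm are the two halves of the usual collapsed Gibbs bookkeeping: a biterm is unassigned from its old topic, a new topic is sampled, and the counts are restored. add_remove_biterm's body is not part of this diff, so the following is an assumption based on standard BTM samplers, not the library's code:

```python
import numpy as np

def add_remove_biterm_sketch(
    add: bool,                       # True to add counts, False to remove them
    i_biterm: int,
    i_topic: int,
    biterms: np.ndarray,             # (n_biterms, 2) word ids per biterm
    topic_word_count: np.ndarray,    # (n_topics, n_vocab)
    topic_biterm_count: np.ndarray,  # (n_topics,)
) -> None:
    delta = 1 if add else -1
    w1, w2 = biterms[i_biterm]
    topic_word_count[i_topic, w1] += delta
    topic_word_count[i_topic, w2] += delta
    topic_biterm_count[i_topic] += delta
```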

@@ -448,7 +466,10 @@ def predict_docs(
         )
         biterms = doc_unique_biterms(words, word_counts)
         prob_topic_given_document(
-            pred, biterms, topic_distribution, topic_word_distribution
+            pred,
+            biterms,
+            topic_distribution,
+            topic_word_distribution,
         )
         predictions[i_doc, :] = pred
     return predictions
4 changes: 3 additions & 1 deletion tweetopic/_dmm.py
@@ -1,4 +1,6 @@
"""Module containing tools for fitting a Dirichlet Multinomial Mixture Model."""
"""Module containing tools for fitting a Dirichlet Multinomial Mixture
Model."""

from __future__ import annotations

from math import exp, log
2 changes: 1 addition & 1 deletion tweetopic/_doc.py
@@ -11,7 +11,7 @@ def init_doc_words(
     n_docs, _ = doc_term_matrix.shape
     doc_unique_words = np.zeros((n_docs, max_unique_words)).astype(np.uint32)
     doc_unique_word_counts = np.zeros((n_docs, max_unique_words)).astype(
-        np.uint32
+        np.uint32,
     )
     for i_doc in range(n_docs):
         unique_words = doc_term_matrix[i_doc].rows[0]  # type: ignore
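
init_doc_words turns a sparse document-term matrix into two fixed-width, zero-padded arrays of unique word ids and their counts, which is the layout the Numba functions in _btm.py consume. A sketch of the transformation for a single document (behavior inferred from the visible lines and call sites):

```python
import numpy as np
import scipy.sparse as spr

# One document over a 4-word vocabulary: word 1 appears twice, word 3 once.
doc_term_matrix = spr.lil_matrix(np.array([[0, 2, 0, 1]]))

max_unique_words = 3
row = doc_term_matrix[0].rows[0]   # column ids of nonzero entries: [1, 3]
data = doc_term_matrix[0].data[0]  # the counts themselves: [2, 1]

doc_unique_words = np.zeros((1, max_unique_words), dtype=np.uint32)
doc_unique_word_counts = np.zeros((1, max_unique_words), dtype=np.uint32)
doc_unique_words[0, : len(row)] = row
doc_unique_word_counts[0, : len(data)] = data
# doc_unique_words       -> [[1, 3, 0]]
# doc_unique_word_counts -> [[2, 1, 0]]
```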
18 changes: 11 additions & 7 deletions tweetopic/btm.py
@@ -9,16 +9,19 @@
 import sklearn
 from numpy.typing import ArrayLike
 
-from tweetopic._btm import (compute_biterm_set, corpus_unique_biterms,
-                            fit_model, predict_docs)
+from tweetopic._btm import (
+    compute_biterm_set,
+    corpus_unique_biterms,
+    fit_model,
+    predict_docs,
+)
 from tweetopic._doc import init_doc_words
 from tweetopic.exceptions import NotFittedException
 from tweetopic.utils import set_numba_seed
 
 
 class BTM(sklearn.base.TransformerMixin, sklearn.base.BaseEstimator):
-    """Implementation of the Biterm Topic Model with Gibbs Sampling
-    solver.
+    """Implementation of the Biterm Topic Model with Gibbs Sampling solver.
 
     Parameters
     ----------
@@ -144,7 +147,9 @@ def fit(self, X: Union[spr.spmatrix, ArrayLike], y: None = None):
             X.tolil(),
             max_unique_words=max_unique_words,
         )
-        biterms = corpus_unique_biterms(doc_unique_words, doc_unique_word_counts)
+        biterms = corpus_unique_biterms(
+            doc_unique_words, doc_unique_word_counts
+        )
         biterm_set = compute_biterm_set(biterms)
         self.topic_distribution, self.components_ = fit_model(
             n_iter=self.n_iterations,
@@ -159,8 +164,7 @@ def fit(self, X: Union[spr.spmatrix, ArrayLike], y: None = None):
     # TODO: Something goes terribly wrong here, fix this
 
     def transform(self, X: Union[spr.spmatrix, ArrayLike]) -> np.ndarray:
-        """Predicts probabilities for each document belonging to each
-        topic.
+        """Predicts probabilities for each document belonging to each topic.
 
         Parameters
         ----------
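
Since BTM subclasses the scikit-learn base estimator and transformer mixins, usage follows the familiar fit/transform pattern. A sketch, with the constructor arguments assumed rather than taken from this diff:

```python
from sklearn.feature_extraction.text import CountVectorizer
from tweetopic.btm import BTM

texts = ["cats sit on mats", "dogs chase cats", "stocks fell sharply today"]

X = CountVectorizer().fit_transform(texts)

# n_components is a hypothetical hyperparameter name; the constructor
# signature is not part of this diff.
model = BTM(n_components=2)
doc_topic = model.fit(X).transform(X)  # shape (n_documents, n_components)
```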
5 changes: 3 additions & 2 deletions tweetopic/func.py
@@ -1,11 +1,12 @@
 """Utility functions for use in the library."""
 
 from functools import wraps
 from typing import Callable
 
 
 def spread(fn: Callable):
-    """Creates a new function from the given function so that it takes one
-    dict (PyTree) and spreads the arguments."""
+    """Creates a new function from the given function so that it takes one dict
+    (PyTree) and spreads the arguments."""
 
     @wraps(fn)
     def inner(kwargs):
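
spread adapts a keyword-argument function into one that takes a single dict, which is convenient when parameters travel through JAX as one PyTree. inner's body is truncated in this diff; presumably it returns fn(**kwargs), as in this sketch:

```python
def spread_sketch(fn):
    """Presumed equivalent of tweetopic.func.spread (body truncated above)."""
    def inner(kwargs):
        return fn(**kwargs)
    return inner

def logdensity(component, weights):
    return component + weights  # stand-in computation

spread_sketch(logdensity)({"component": 1.0, "weights": 2.0})
# == logdensity(component=1.0, weights=2.0) == 3.0
```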
