[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Apr 3, 2023
commit b81bf3a (1 parent: 7eb98f4)
Showing 7 changed files with 58 additions and 25 deletions.
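
The fixes below (import sorting, one-argument-per-line call wrapping with trailing commas, and docstring punctuation and wrapping) are the typical output of a .pre-commit-config.yaml along these lines. This is a minimal sketch: the hook selection and revision pins are assumptions inferred from the style of the changes, not taken from this repository.

# Hypothetical .pre-commit-config.yaml; hook choices and rev pins are
# illustrative assumptions, not this repository's actual config.
repos:
  - repo: https://github.com/PyCQA/isort
    rev: 5.12.0
    hooks:
      - id: isort                 # groups and sorts imports
  - repo: https://github.com/asottile/add-trailing-comma
    rev: v2.4.0
    hooks:
      - id: add-trailing-comma    # adds trailing commas to wrapped calls
  - repo: https://github.com/PyCQA/docformatter
    rev: v1.5.1
    hooks:
      - id: docformatter          # wraps docstrings, enforces trailing period

Locally, fixes of this kind can usually be reproduced with pre-commit run --all-files (after pip install pre-commit); on pre-commit.ci the bot applies and commits them automatically, which is what happened here.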
6 changes: 5 additions & 1 deletion docs/conf.py
@@ -17,7 +17,11 @@
 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
 
-extensions = ["sphinx.ext.autodoc", "sphinx.ext.coverage", "sphinx.ext.napoleon"]
+extensions = [
+    "sphinx.ext.autodoc",
+    "sphinx.ext.coverage",
+    "sphinx.ext.napoleon",
+]
 
 templates_path = ["_templates"]
 exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
2 changes: 1 addition & 1 deletion tweetopic/__init__.py
@@ -1,3 +1,3 @@
-from tweetopic.dmm import DMM  # noqa: F401
 from tweetopic.btm import BTM  # noqa: F401
+from tweetopic.dmm import DMM  # noqa: F401
 from tweetopic.pipeline import TopicPipeline  # noqa: F401
44 changes: 33 additions & 11 deletions tweetopic/_btm.py
@@ -1,17 +1,19 @@
-"""Module for utility functions for fitting BTMs"""
+"""Module for utility functions for fitting BTMs."""
 
 import random
 from typing import Dict, Tuple, TypeVar
 
 import numba
 import numpy as np
 from numba import njit
 
 from tweetopic._prob import norm_prob, sample_categorical
 
 
 @njit
 def doc_unique_biterms(
-    doc_unique_words: np.ndarray, doc_unique_word_counts: np.ndarray
+    doc_unique_words: np.ndarray,
+    doc_unique_word_counts: np.ndarray,
 ) -> Dict[Tuple[int, int], int]:
     (n_max_unique_words,) = doc_unique_words.shape
     biterm_counts = dict()
@@ -42,7 +44,7 @@ def doc_unique_biterms(
 
 @njit
 def nb_add_counter(dest: Dict[T, int], source: Dict[T, int]):
-    """Adds one counter dict to another in place with Numba"""
+    """Adds one counter dict to another in place with Numba."""
     for key in source:
         if key in dest:
             dest[key] += source[key]
@@ -52,25 +54,28 @@ def nb_add_counter(dest: Dict[T, int], source: Dict[T, int]):
 
 @njit
 def corpus_unique_biterms(
-    doc_unique_words: np.ndarray, doc_unique_word_counts: np.ndarray
+    doc_unique_words: np.ndarray,
+    doc_unique_word_counts: np.ndarray,
 ) -> Dict[Tuple[int, int], int]:
     n_documents, _ = doc_unique_words.shape
     biterm_counts = doc_unique_biterms(
-        doc_unique_words[0], doc_unique_word_counts[0]
+        doc_unique_words[0],
+        doc_unique_word_counts[0],
     )
     for i_doc in range(1, n_documents):
         doc_unique_words_i = doc_unique_words[i_doc]
         doc_unique_word_counts_i = doc_unique_word_counts[i_doc]
         doc_biterms = doc_unique_biterms(
-            doc_unique_words_i, doc_unique_word_counts_i
+            doc_unique_words_i,
+            doc_unique_word_counts_i,
         )
         nb_add_counter(biterm_counts, doc_biterms)
     return biterm_counts
 
 
 @njit
 def compute_biterm_set(
-    biterm_counts: Dict[Tuple[int, int], int]
+    biterm_counts: Dict[Tuple[int, int], int],
 ) -> np.ndarray:
     return np.array(list(biterm_counts.keys()))
 
@@ -115,7 +120,12 @@ def add_biterm(
     topic_biterm_count: np.ndarray,
 ) -> None:
     add_remove_biterm(
-        True, i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+        True,
+        i_biterm,
+        i_topic,
+        biterms,
+        topic_word_count,
+        topic_biterm_count,
     )
 
 
@@ -128,7 +138,12 @@ def remove_biterm(
     topic_biterm_count: np.ndarray,
 ) -> None:
     add_remove_biterm(
-        False, i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+        False,
+        i_biterm,
+        i_topic,
+        biterms,
+        topic_word_count,
+        topic_biterm_count,
     )
 
 
@@ -146,7 +161,11 @@ def init_components(
         i_topic = random.randint(0, n_components - 1)
         biterm_topic_assignments[i_biterm] = i_topic
         add_biterm(
-            i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+            i_biterm,
+            i_topic,
+            biterms,
+            topic_word_count,
+            topic_biterm_count,
         )
     return biterm_topic_assignments, topic_word_count, topic_biterm_count
 
@@ -360,7 +379,10 @@ def predict_docs(
         )
         biterms = doc_unique_biterms(words, word_counts)
         prob_topic_given_document(
-            pred, biterms, topic_distribution, topic_word_distribution
+            pred,
+            biterms,
+            topic_distribution,
+            topic_word_distribution,
         )
         predictions[i_doc, :] = pred
     return predictions
5 changes: 3 additions & 2 deletions tweetopic/_dmm.py
@@ -1,12 +1,13 @@
-"""Module containing tools for fitting a Dirichlet Multinomial Mixture Model."""
+"""Module containing tools for fitting a Dirichlet Multinomial Mixture
+Model."""
 from __future__ import annotations
 
 from math import exp, log
 
 import numpy as np
 from numba import njit
 
-from tweetopic._prob import sample_categorical, norm_prob
+from tweetopic._prob import norm_prob, sample_categorical
 
 
 @njit
2 changes: 1 addition & 1 deletion tweetopic/_doc.py
@@ -11,7 +11,7 @@ def init_doc_words(
     n_docs, _ = doc_term_matrix.shape
     doc_unique_words = np.zeros((n_docs, max_unique_words)).astype(np.uint32)
     doc_unique_word_counts = np.zeros((n_docs, max_unique_words)).astype(
-        np.uint32
+        np.uint32,
    )
    for i_doc in range(n_docs):
        unique_words = doc_term_matrix[i_doc].rows[0]  # type: ignore
18 changes: 11 additions & 7 deletions tweetopic/btm.py
@@ -7,15 +7,19 @@
 import scipy.sparse as spr
 import sklearn
 from numpy.typing import ArrayLike
-from tweetopic._btm import (compute_biterm_set, corpus_unique_biterms,
-                            fit_model, predict_docs)
+
+from tweetopic._btm import (
+    compute_biterm_set,
+    corpus_unique_biterms,
+    fit_model,
+    predict_docs,
+)
 from tweetopic._doc import init_doc_words
 from tweetopic.exceptions import NotFittedException
 
 
 class BTM(sklearn.base.TransformerMixin, sklearn.base.BaseEstimator):
-    """Implementation of the Biterm Topic Model with Gibbs Sampling
-    solver.
+    """Implementation of the Biterm Topic Model with Gibbs Sampling solver.
 
     Parameters
     ----------
@@ -136,7 +140,8 @@ def fit(self, X: Union[spr.spmatrix, ArrayLike], y: None = None):
             max_unique_words=max_unique_words,
         )
         biterms = corpus_unique_biterms(
-            doc_unique_words, doc_unique_word_counts
+            doc_unique_words,
+            doc_unique_word_counts,
         )
         biterm_set = compute_biterm_set(biterms)
         self.topic_distribution, self.components_ = fit_model(
@@ -152,8 +157,7 @@ def fit(self, X: Union[spr.spmatrix, ArrayLike], y: None = None):
     # TODO: Something goes terribly wrong here, fix this
 
     def transform(self, X: Union[spr.spmatrix, ArrayLike]) -> np.ndarray:
-        """Predicts probabilities for each document belonging to each
-        topic.
+        """Predicts probabilities for each document belonging to each topic.
 
         Parameters
         ----------
6 changes: 4 additions & 2 deletions tweetopic/pipeline.py
@@ -47,7 +47,8 @@ def fit(self, texts: Iterable[str]) -> TopicPipeline:
         return self
 
     def fit_transform(
-        self, texts: Iterable[str]
+        self,
+        texts: Iterable[str],
     ) -> Union[ArrayLike, spr.spmatrix]:
         """Fits vectorizer and topic model and transforms the given text.
@@ -65,7 +66,8 @@ def fit_transform(
         return self.topic_model.fit_transform(doc_term_matrix)
 
     def transform(
-        self, texts: Iterable[str]
+        self,
+        texts: Iterable[str],
     ) -> Union[ArrayLike, spr.spmatrix]:
         """Transforms given texts with the fitted pipeline.
