Skip to content

Commit

Permalink
some optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
SkBlaz committed Nov 30, 2024
1 parent 92cd373 commit bbc83ea
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 41 deletions.
85 changes: 45 additions & 40 deletions rakun2/class_rakun.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.sparse as sps
import fitz

logging.basicConfig(format="%(asctime)s - %(message)s",
Expand Down Expand Up @@ -126,63 +127,67 @@ def visualize_network(
plt.show()

def compute_tf_scores(self, document: str = None) -> None:
    """Compute term-frequency (TF) scores for the current tokens.

    Args:
        document: Optional raw text. When given, it is tokenized with
            ``self.pattern.findall`` and the result replaces
            ``self.tokens``. When ``None``, the tokens already stored on
            the instance are counted as-is.

    Side effects:
        self.term_counts: plain ``dict`` mapping each token to its count.
        self.sorted_terms_tf: list of ``(token, count)`` pairs ordered by
            descending count; tokens with equal counts keep the order in
            which they were first seen.
    """
    if document is not None:
        self.tokens = self.pattern.findall(document)

    # Counter consumes the whole token list in one pass, so no per-token
    # update loop is needed.
    term_counter = Counter(self.tokens)
    self.term_counts = dict(term_counter)

    # most_common() without an argument returns every (term, count) pair,
    # highest count first -- the same ordering as sorting the items by
    # count in reverse.
    self.sorted_terms_tf = term_counter.most_common()


def pagerank_scipy_adapted(
    self,
    token_graph: nx.Graph,
    alpha: float = 0.85,
    personalization: dict = None,
    max_iter: int = 64,
    tol: float = 1.0e-2,
    weight: str = "weight",
):
    """Compute personalized PageRank scores for a token graph.

    Adapted from NetworkX's ``nx.pagerank``: the adjacency matrix is
    row-normalized and a fixed number of power iterations is run.

    Args:
        token_graph: Graph whose nodes are scored.
        alpha: Damping factor applied to the propagated scores.
        personalization: Optional mapping ``node -> weight``. Nodes absent
            from the mapping get weight 0. When ``None``, or when all
            weights sum to zero, a uniform vector is used instead so the
            scores never become NaN.
        max_iter: Maximum number of power-iteration steps.
        tol: Stop once the L1 difference between two successive score
            vectors drops below this value.
        weight: Edge attribute used as the edge weight.

    Returns:
        Dict mapping each node to its score. An empty graph yields ``{}``.
        If ``max_iter`` is exhausted before reaching ``tol``, the scores
        of the last iteration are returned.
    """
    num_nodes = len(token_graph)
    if num_nodes == 0:
        return {}

    nodelist = list(token_graph)
    # CSR is requested explicitly because the row scaling below relies on
    # the CSR ``data``/``indptr`` layout.
    token_sparse_matrix = nx.to_scipy_sparse_array(
        token_graph,
        nodelist=nodelist,
        weight=weight,
        dtype=np.float32,
        format="csr",
    )

    # Row-normalize: invert every non-zero row sum, then scale each stored
    # entry by the factor of the row it belongs to.
    row_sums = np.array(token_sparse_matrix.sum(axis=1)).flatten()
    nonzero_rows = row_sums > 0
    row_sums[nonzero_rows] = 1.0 / row_sums[nonzero_rows]
    token_sparse_matrix.data *= np.repeat(
        row_sums, np.diff(token_sparse_matrix.indptr)
    )

    # Transposed once, outside the loop; each iteration multiplies by it.
    transition_t = token_sparse_matrix.T

    uniform = np.full(num_nodes, 1.0 / num_nodes, dtype=np.float32)

    # Personalization vector, normalized to sum to one.
    if personalization is None:
        pers_array = uniform.copy()
    else:
        pers_array = np.array(
            [personalization.get(node, 0) for node in nodelist],
            dtype=np.float32,
        )
        pers_total = pers_array.sum()
        if pers_total > 0:
            pers_array /= pers_total
        else:
            # No usable mass: dividing would produce NaN scores.
            pers_array = uniform.copy()

    x_iteration = uniform
    for _ in range(max_iter):
        xlast = x_iteration
        x_iteration = (
            alpha * (transition_t @ x_iteration) + (1 - alpha) * pers_array
        )
        # L1 distance between successive iterates is the convergence test.
        if np.sum(np.abs(x_iteration - xlast)) < tol:
            break

    return dict(zip(nodelist, x_iteration))

def get_document_graph(self, weight: int = 1):
""" A method for obtaining the token graph """
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def parse_requirements(file):


setup(name='rakun2',
version='0.25',
version='0.26',
description=
"RaKUn 2.0; Better faster stronger lighter",
url='http://github.com/skblaz/rakun2',
Expand Down

0 comments on commit bbc83ea

Please sign in to comment.