Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize MLLM ambiguity calculation #825

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions annif/lexical/tokenset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(
subject_id: int | None = None,
is_pref: bool = False,
) -> None:
self._tokens = set(tokens)
self._tokens = frozenset(tokens)
self.key = tokens[0] if len(tokens) else None
self.subject_id = subject_id
self.is_pref = is_pref
Expand All @@ -31,6 +31,10 @@ def __len__(self) -> int:
def __iter__(self):
return iter(self._tokens)

@property
def tokens(self) -> frozenset:
    """The tokens of this TokenSet as an immutable frozenset.

    Exposed read-only so the token combination can be used as a
    hashable grouping key (e.g. as a dict key)."""
    return self._tokens

def contains(self, other: TokenSet) -> bool:
"""Returns True iff the tokens in the other TokenSet are all
included within this TokenSet."""
Expand Down Expand Up @@ -68,21 +72,25 @@ def _find_subj_tsets(self, tset: TokenSet) -> dict[int | None, TokenSet]:

return subj_tsets

def _find_subj_ambiguity(self, tsets):
def _find_subj_ambiguity(self, tsets: list[TokenSet]):
"""calculate the ambiguity values (the number of other TokenSets
that also match the same tokens) for the given TokenSets and return
them as a dict-like object (subject_id : ambiguity_value)"""

# group the TokenSets by their tokens, so that TokenSets with
# identical tokens can be considered together all in one go
elim_tsets = collections.defaultdict(set)
for ts in tsets:
elim_tsets[ts.tokens].add(ts.subject_id)

subj_ambiguity = collections.Counter()

subj_ambiguity.update(
[
ts.subject_id
for ts in tsets
for other in tsets
if ts != other and other.contains(ts)
]
)
for tokens1, subjs1 in elim_tsets.items():
for tokens2, subjs2 in elim_tsets.items():
if not tokens2.issuperset(tokens1):
continue
for subj in subjs1:
subj_ambiguity[subj] += len(subjs2) - int(subj in subjs2)

return subj_ambiguity

Expand Down
Loading