From 36442257ace44927ff65b0b2831647b1aebdc5fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maik=20Fr=C3=B6be?= Date: Mon, 16 Sep 2024 03:52:21 +0200 Subject: [PATCH] Prepare usage of pre-tokenized index #50 --- ir_axioms/axiom/preconditions.py | 4 ++-- ir_axioms/axiom/query_aspects.py | 12 ++++++---- ir_axioms/backend/pyterrier/transformers.py | 26 ++++++++++++--------- ir_axioms/model/context.py | 3 +++ 4 files changed, 28 insertions(+), 17 deletions(-) diff --git a/ir_axioms/axiom/preconditions.py b/ir_axioms/axiom/preconditions.py index f3c3ae3..fc461f1 100644 --- a/ir_axioms/axiom/preconditions.py +++ b/ir_axioms/axiom/preconditions.py @@ -12,8 +12,8 @@ def approximately_same_length( margin_fraction: float, ) -> bool: return approximately_equal( - len(context.terms(document1)), - len(context.terms(document2)), + context.document_length(document1), + context.document_length(document2), margin_fraction=margin_fraction ) diff --git a/ir_axioms/axiom/query_aspects.py b/ir_axioms/axiom/query_aspects.py index ee839ae..949fdd4 100644 --- a/ir_axioms/axiom/query_aspects.py +++ b/ir_axioms/axiom/query_aspects.py @@ -197,10 +197,14 @@ def preference( document2: RankedDocument ): query_terms = context.term_set(query) - document1_terms = context.term_set(document1) - document2_terms = context.term_set(document2) - s1 = query_terms.issubset(document1_terms) - s2 = query_terms.issubset(document2_terms) + s1, s2 = set(), set() + + for query_term in query_terms: + if context.term_frequency(document1, query_term) > 0: + s1.add(query_term) + if context.term_frequency(document2, query_term) > 0: + s2.add(query_term) + return strictly_greater(s1, s2) diff --git a/ir_axioms/backend/pyterrier/transformers.py b/ir_axioms/backend/pyterrier/transformers.py index 59aae7b..b1b02c0 100644 --- a/ir_axioms/backend/pyterrier/transformers.py +++ b/ir_axioms/backend/pyterrier/transformers.py @@ -68,9 +68,10 @@ def transform(self, topics_or_res: DataFrame) -> DataFrame: class AxiomTransformer(PerGroupTransformer, ABC): - index: Union[Index, IndexRef, Path, str] + index: Optional[Union[Index, IndexRef, Path, str]] = None dataset: Optional[Union[Dataset, str, IRDSDataset]] = None contents_accessor: Optional[ContentsAccessor] = "text" + context: Optional[IndexContext] = None tokeniser: Optional[Tokeniser] = None cache_dir: Optional[Path] = None verbose: bool = False @@ -80,15 +81,17 @@ class AxiomTransformer(PerGroupTransformer, ABC): optional_group_columns = {"qid", "name"} unit = "query" - @cached_property + @property def _context(self) -> IndexContext: - return TerrierIndexContext( - index_location=self.index, - dataset=self.dataset, - contents_accessor=self.contents_accessor, - tokeniser=self.tokeniser, - cache_dir=self.cache_dir, - ) + if not self.context: + self.context = TerrierIndexContext( + index_location=self.index, + dataset=self.dataset, + contents_accessor=self.contents_accessor, + tokeniser=self.tokeniser, + cache_dir=self.cache_dir, + ) + return self.context @final def transform_group(self, topics_or_res: DataFrame) -> DataFrame: @@ -124,8 +127,9 @@ class KwikSortReranker(AxiomTransformer): description = "Reranking query axiomatically" axiom: AxiomLike - index: Union[Index, IndexRef, Path, str] + index: Optional[Union[Index, IndexRef, Path, str]] = None dataset: Optional[Union[Dataset, str, IRDSDataset]] = None + context: Optional[IndexContext] = None contents_accessor: Optional[ContentsAccessor] = "text" pivot_selection: PivotSelection = RandomPivotSelection() tokeniser: Optional[Tokeniser] = None @@ -170,8 +174,8 @@ class AggregatedAxiomaticPreferences(AxiomTransformer): description = "Aggregating query axiom preferences" axioms: Sequence[AxiomLike] - index: Union[Index, IndexRef, Path, str] aggregations: Sequence[Callable[[Sequence[float]], float]] + index: Optional[Union[Index, IndexRef, Path, str]] = None dataset: Optional[Union[Dataset, str, IRDSDataset]] = None contents_accessor: Optional[ContentsAccessor] = "text" filter_pairs: Optional[Callable[ diff --git a/ir_axioms/model/context.py b/ir_axioms/model/context.py index d183d0b..9c5afa3 100644 --- a/ir_axioms/model/context.py +++ b/ir_axioms/model/context.py @@ -64,6 +64,9 @@ def term_set( ) -> FrozenSet[str]: return frozenset(self.terms(query_or_document)) + def document_length(self, document: Document): + return len(self.terms(document)) + @lru_cache(None) def term_frequency( self,