From f20f685e3df89692fc880e0406ab084ddf34cd3b Mon Sep 17 00:00:00 2001 From: Alvaro Correa Date: Tue, 6 Feb 2024 17:47:23 +0100 Subject: [PATCH 1/3] fix: add test for TrainingDataProcessor process_raw_data --- tests/test_training_data_processor.py | 33 +++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 tests/test_training_data_processor.py diff --git a/tests/test_training_data_processor.py b/tests/test_training_data_processor.py new file mode 100644 index 0000000..290e778 --- /dev/null +++ b/tests/test_training_data_processor.py @@ -0,0 +1,33 @@ +import pytest +from unittest.mock import MagicMock + +from ragatouille.data import TrainingDataProcessor + + +@pytest.fixture +def collection(): + return ['doc1', 'doc2', 'doc3'] + + +@pytest.fixture +def queries(): + return ['query1', 'query2'] + + +def test_process_raw_data_without_miner(collection, queries): + processor = TrainingDataProcessor(collection, queries, None) + processor._process_raw_pairs = MagicMock(return_value=None) + + processor.process_raw_data(raw_data=[], data_type="pairs", data_dir="./", mine_hard_negatives=False) + + processor._process_raw_pairs.assert_called_once() + + +def test_process_raw_data_with_miner(collection, queries): + negative_miner = MagicMock() + processor = TrainingDataProcessor(collection, queries, negative_miner) + processor._process_raw_pairs = MagicMock(return_value=None) + + processor.process_raw_data(raw_data=[], data_type="pairs", data_dir="./") + + processor._process_raw_pairs.assert_called_once() \ No newline at end of file From 2bf837059a14b056177c71bd701e3b19f83bf045 Mon Sep 17 00:00:00 2001 From: Alvaro Correa Date: Tue, 6 Feb 2024 17:48:29 +0100 Subject: [PATCH 2/3] fix: set negative miner min rank only if miner present --- ragatouille/data/training_data_processor.py | 3 ++- tests/test_training_data_processor.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/ragatouille/data/training_data_processor.py b/ragatouille/data/training_data_processor.py index 689c4a1..696005a 100644 --- a/ragatouille/data/training_data_processor.py +++ b/ragatouille/data/training_data_processor.py @@ -33,11 +33,12 @@ def process_raw_data( negative_label: int = 0, hard_negative_minimum_rank: int = 10, ): - self.negative_miner.min_rank = hard_negative_minimum_rank if self.negative_miner is None and mine_hard_negatives: raise ValueError( "mine_hard_negatives is True but no negative miner was provided!" ) + if self.negative_miner: + self.negative_miner.min_rank = hard_negative_minimum_rank if data_type == "pairs": self._process_raw_pairs( raw_data=raw_data, diff --git a/tests/test_training_data_processor.py b/tests/test_training_data_processor.py index 290e778..474ea0e 100644 --- a/tests/test_training_data_processor.py +++ b/tests/test_training_data_processor.py @@ -1,6 +1,7 @@ -import pytest from unittest.mock import MagicMock +import pytest + from ragatouille.data import TrainingDataProcessor @@ -30,4 +31,4 @@ def test_process_raw_data_with_miner(collection, queries): processor.process_raw_data(raw_data=[], data_type="pairs", data_dir="./") - processor._process_raw_pairs.assert_called_once() \ No newline at end of file + processor._process_raw_pairs.assert_called_once() From 9fea94d60608cbd182ab1a4a591283c5ab0fff89 Mon Sep 17 00:00:00 2001 From: bclavie Date: Wed, 7 Feb 2024 19:54:11 +0100 Subject: [PATCH 3/3] chore: linting --- ragatouille/models/colbert.py | 1 + tests/test_training_data_processor.py | 8 +++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/ragatouille/models/colbert.py b/ragatouille/models/colbert.py index 4d73c62..3b65289 100644 --- a/ragatouille/models/colbert.py +++ b/ragatouille/models/colbert.py @@ -11,6 +11,7 @@ from colbert import Indexer, IndexUpdater, Searcher, Trainer from colbert.infra import ColBERTConfig, Run, RunConfig from colbert.modeling.checkpoint import Checkpoint + from ragatouille.models.base import LateInteractionModel diff --git a/tests/test_training_data_processor.py b/tests/test_training_data_processor.py index 474ea0e..c51a092 100644 --- a/tests/test_training_data_processor.py +++ b/tests/test_training_data_processor.py @@ -7,19 +7,21 @@ @pytest.fixture def collection(): - return ['doc1', 'doc2', 'doc3'] + return ["doc1", "doc2", "doc3"] @pytest.fixture def queries(): - return ['query1', 'query2'] + return ["query1", "query2"] def test_process_raw_data_without_miner(collection, queries): processor = TrainingDataProcessor(collection, queries, None) processor._process_raw_pairs = MagicMock(return_value=None) - processor.process_raw_data(raw_data=[], data_type="pairs", data_dir="./", mine_hard_negatives=False) + processor.process_raw_data( + raw_data=[], data_type="pairs", data_dir="./", mine_hard_negatives=False + ) processor._process_raw_pairs.assert_called_once()