diff --git a/ragatouille/data/training_data_processor.py b/ragatouille/data/training_data_processor.py index 689c4a1..696005a 100644 --- a/ragatouille/data/training_data_processor.py +++ b/ragatouille/data/training_data_processor.py @@ -33,11 +33,12 @@ def process_raw_data( negative_label: int = 0, hard_negative_minimum_rank: int = 10, ): - self.negative_miner.min_rank = hard_negative_minimum_rank if self.negative_miner is None and mine_hard_negatives: raise ValueError( "mine_hard_negatives is True but no negative miner was provided!" ) + if self.negative_miner: + self.negative_miner.min_rank = hard_negative_minimum_rank if data_type == "pairs": self._process_raw_pairs( raw_data=raw_data, diff --git a/ragatouille/models/colbert.py b/ragatouille/models/colbert.py index 4d73c62..3b65289 100644 --- a/ragatouille/models/colbert.py +++ b/ragatouille/models/colbert.py @@ -11,6 +11,7 @@ from colbert import Indexer, IndexUpdater, Searcher, Trainer from colbert.infra import ColBERTConfig, Run, RunConfig from colbert.modeling.checkpoint import Checkpoint + from ragatouille.models.base import LateInteractionModel diff --git a/tests/test_training_data_processor.py b/tests/test_training_data_processor.py new file mode 100644 index 0000000..c51a092 --- /dev/null +++ b/tests/test_training_data_processor.py @@ -0,0 +1,36 @@ +from unittest.mock import MagicMock + +import pytest + +from ragatouille.data import TrainingDataProcessor + + +@pytest.fixture +def collection(): + return ["doc1", "doc2", "doc3"] + + +@pytest.fixture +def queries(): + return ["query1", "query2"] + + +def test_process_raw_data_without_miner(collection, queries): + processor = TrainingDataProcessor(collection, queries, None) + processor._process_raw_pairs = MagicMock(return_value=None) + + processor.process_raw_data( + raw_data=[], data_type="pairs", data_dir="./", mine_hard_negatives=False + ) + + processor._process_raw_pairs.assert_called_once() + + +def test_process_raw_data_with_miner(collection, queries): + negative_miner = MagicMock() + processor = TrainingDataProcessor(collection, queries, negative_miner) + processor._process_raw_pairs = MagicMock(return_value=None) + + processor.process_raw_data(raw_data=[], data_type="pairs", data_dir="./") + + processor._process_raw_pairs.assert_called_once()