diff --git a/changelog/7200.bugfix.md b/changelog/7200.bugfix.md new file mode 100644 index 000000000000..cd10e8aafc8c --- /dev/null +++ b/changelog/7200.bugfix.md @@ -0,0 +1 @@ +Fix a bug that caused only one retrieval intent to appear under the `all_retrieval_intents` key of the `ResponseSelector` output, even when the training data contained multiple retrieval intents. diff --git a/rasa/nlu/selectors/response_selector.py b/rasa/nlu/selectors/response_selector.py index c6789898e9ab..35491a6a3b2f 100644 --- a/rasa/nlu/selectors/response_selector.py +++ b/rasa/nlu/selectors/response_selector.py @@ -300,7 +300,12 @@ def preprocess_train_data(self, training_data: TrainingData) -> RasaModelData: """Prepares data for training. Performs sanity checks on training data, extracts encodings for labels. + + Args: + training_data: training data to be preprocessed. """ + # Collect all retrieval intents present in the data before filtering + self.all_retrieval_intents = list(training_data.retrieval_intents) if self.retrieval_intent: training_data = training_data.filter_training_examples( @@ -321,7 +326,6 @@ def preprocess_train_data(self, training_data: TrainingData) -> RasaModelData: ) self.responses = training_data.responses - self.all_retrieval_intents = list(training_data.retrieval_intents) if not label_id_index_mapping: # no labels are present to train diff --git a/tests/nlu/selectors/test_selectors.py b/tests/nlu/selectors/test_selectors.py index 1203407686ae..d25d28c0367e 100644 --- a/tests/nlu/selectors/test_selectors.py +++ b/tests/nlu/selectors/test_selectors.py @@ -18,6 +18,8 @@ CHECKPOINT_MODEL, ) from rasa.nlu.selectors.response_selector import ResponseSelector +from rasa.shared.nlu.training_data.message import Message +from rasa.shared.nlu.training_data.training_data import TrainingData @pytest.mark.parametrize( @@ -95,6 +97,34 @@ def test_train_selector(pipeline, component_builder, tmpdir): assert rank.get("intent_response_key") is not None +def 
test_preprocess_selector_multiple_retrieval_intents(): + + # use some available data + training_data = rasa.shared.nlu.training_data.loading.load_data( + "data/examples/rasa/demo-rasa.md" + ) + training_data_responses = rasa.shared.nlu.training_data.loading.load_data( + "data/examples/rasa/demo-rasa-responses.md" + ) + training_data_extra_intent = TrainingData( + [ + Message.build( + text="Is it possible to detect the version?", intent="faq/q1" + ), + Message.build(text="How can I get a new virtual env", intent="faq/q2"), + ] + ) + training_data = training_data.merge(training_data_responses).merge( + training_data_extra_intent + ) + + response_selector = ResponseSelector() + + response_selector.preprocess_train_data(training_data) + + assert sorted(response_selector.all_retrieval_intents) == ["chitchat", "faq"] + + @pytest.mark.parametrize( "use_text_as_label, label_values", [