diff --git a/rasa/shared/nlu/training_data/formats/rasa_yaml.py b/rasa/shared/nlu/training_data/formats/rasa_yaml.py index ce254b109b01..c41884b94386 100644 --- a/rasa/shared/nlu/training_data/formats/rasa_yaml.py +++ b/rasa/shared/nlu/training_data/formats/rasa_yaml.py @@ -8,6 +8,7 @@ from rasa.shared.exceptions import YamlException from rasa.shared.utils import validation from ruamel.yaml import StringIO +from ruamel.yaml.scalarstring import LiteralScalarString from rasa.shared.constants import ( DOCS_URL_TRAINING_DATA, @@ -36,6 +37,7 @@ KEY_LOOKUP = "lookup" KEY_LOOKUP_EXAMPLES = "examples" KEY_METADATA = "metadata" +KEY_METADATA_EXAMPLE = "example" MULTILINE_TRAINING_EXAMPLE_LEADING_SYMBOL = "-" @@ -470,21 +472,57 @@ def process_training_examples_by_key( key_examples: Text, example_extraction_predicate=lambda x: x, ) -> List[OrderedDict]: - from ruamel.yaml.scalarstring import LiteralScalarString - result = [] + for entity_key, examples in training_examples.items(): + converted_examples = [] + render_as_objects = False + for example in examples: + converted_example = { + KEY_INTENT_TEXT: example_extraction_predicate(example) + } - converted_examples = [ - TrainingDataWriter.generate_list_item( - example_extraction_predicate(example).strip(STRIP_SYMBOLS) - ) - for example in examples - ] + if isinstance(example, dict) and KEY_METADATA_EXAMPLE in example.get( + KEY_METADATA, {} + ): + render_as_objects = True + converted_example[KEY_METADATA] = example[KEY_METADATA]["example"] + + converted_examples.append(converted_example) next_item = OrderedDict() next_item[key_name] = entity_key - next_item[key_examples] = LiteralScalarString("".join(converted_examples)) + + if render_as_objects: + rendered_examples = RasaYAMLWriter._render_training_examples_as_objects( + converted_examples + ) + else: + rendered_examples = RasaYAMLWriter._render_training_examples_as_text( + converted_examples + ) + next_item[key_examples] = rendered_examples + result.append(next_item) return result + + @staticmethod + def _render_training_examples_as_objects(examples: List[Dict]) -> List[Dict]: + def render(example: Dict) -> Dict: + value = example[KEY_INTENT_TEXT] + example[KEY_INTENT_TEXT] = LiteralScalarString( + TrainingDataWriter.generate_string_item(value) + ) + return example + + return [render(ex) for ex in examples] + + @staticmethod + def _render_training_examples_as_text(examples: List[Dict]) -> List[Text]: + def render(example: Dict) -> Text: + return TrainingDataWriter.generate_list_item( + example[KEY_INTENT_TEXT].strip(STRIP_SYMBOLS) + ) + + return LiteralScalarString("".join([render(example) for example in examples])) diff --git a/rasa/shared/nlu/training_data/formats/readerwriter.py b/rasa/shared/nlu/training_data/formats/readerwriter.py index 2a538833b929..5e2bba93fe89 100644 --- a/rasa/shared/nlu/training_data/formats/readerwriter.py +++ b/rasa/shared/nlu/training_data/formats/readerwriter.py @@ -69,8 +69,12 @@ def prepare_training_examples(training_data: "TrainingData") -> OrderedDict: @staticmethod def generate_list_item(text: Text) -> Text: """Generates text for a list item.""" + return f"- {TrainingDataWriter.generate_string_item(text)}" - return f"- {rasa.shared.nlu.training_data.util.encode_string(text)}\n" + @staticmethod + def generate_string_item(text: Text) -> Text: + """Generates text for a string item.""" + return f"{rasa.shared.nlu.training_data.util.encode_string(text)}\n" @staticmethod def generate_message(message: Dict[Text, Any]) -> Text: diff --git a/tests/shared/nlu/training_data/formats/test_rasa_yaml.py b/tests/shared/nlu/training_data/formats/test_rasa_yaml.py index 842e504b9ba4..383036b4d43f 100644 --- a/tests/shared/nlu/training_data/formats/test_rasa_yaml.py +++ b/tests/shared/nlu/training_data/formats/test_rasa_yaml.py @@ -66,6 +66,25 @@ - Hello """ +INTENT_EXAMPLES_WITH_METADATA_ROUNDTRIP = f"""version: "{LATEST_TRAINING_DATA_FORMAT_VERSION}" +nlu: +- intent: intent_name + examples: + - text: | + how much CO2 will that use? + metadata: + sentiment: positive + - text: | + how much carbon will a one way flight from [new york]{{"entity": "city", "role": "from"}} to california produce? + metadata: co2-trip-calculation + - text: | + how much CO2 to [new york]{{"entity": "city", "role": "to"}}? +- intent: greet + examples: | + - Hi + - Hello +""" + MINIMAL_VALID_EXAMPLE = """ nlu:\n stories: @@ -177,6 +196,22 @@ def test_intent_with_metadata_is_parsed(): } +def test_metadata_roundtrip(): + reader = RasaYAMLReader() + result = reader.reads(INTENT_EXAMPLES_WITH_METADATA_ROUNDTRIP) + + dumped = RasaYAMLWriter().dumps(result) + assert dumped == INTENT_EXAMPLES_WITH_METADATA_ROUNDTRIP + + validation_reader = RasaYAMLReader() + dumped_result = validation_reader.reads(dumped) + + assert dumped_result.training_examples == result.training_examples + + # dumping again should also not change the format + assert dumped == RasaYAMLWriter().dumps(dumped_result) + + # This test would work only with examples that have a `version` key specified @pytest.mark.parametrize( "example",