diff --git a/changelog/5106.bugfix.rst b/changelog/5106.bugfix.rst new file mode 100644 index 000000000000..eb918520ae4f --- /dev/null +++ b/changelog/5106.bugfix.rst @@ -0,0 +1,2 @@ +Fixed file loading of non proper UTF-8 story files, failing properly when checking for +story files. diff --git a/rasa/data.py b/rasa/data.py index 15a438637624..709233344ba2 100644 --- a/rasa/data.py +++ b/rasa/data.py @@ -146,13 +146,25 @@ def is_story_file(file_path: Text) -> bool: Returns: `True` if it's a story file, otherwise `False`. """ - _is_story_file = False - if file_path.endswith(".md"): - with open(file_path, encoding=DEFAULT_ENCODING) as f: - _is_story_file = any(_contains_story_pattern(l) for l in f) - - return _is_story_file + if not file_path.endswith(".md"): + return False + + try: + with open( + file_path, encoding=DEFAULT_ENCODING, errors="surrogateescape" + ) as lines: + return any(_contains_story_pattern(line) for line in lines) + except Exception as e: + # catch-all because we might be loading files we are not expecting to load + logger.error( + f"Tried to check if '{file_path}' is a story file, but failed to " + f"read it. If this file contains story data, you should " + f"investigate this error, otherwise it is probably best to " + f"move the file to a different location. " + f"Error: {e}" + ) + return False def _contains_story_pattern(text: Text) -> bool: diff --git a/tests/core/test_dsl.py b/tests/core/test_dsl.py index 195e9e7714f6..13a129ba99b4 100644 --- a/tests/core/test_dsl.py +++ b/tests/core/test_dsl.py @@ -25,6 +25,7 @@ MaxHistoryTrackerFeaturizer, BinarySingleStateFeaturizer, ) +from rasa.utils.io import DEFAULT_ENCODING async def test_can_read_test_story(default_domain): @@ -73,7 +74,7 @@ async def test_persist_and_read_test_story_graph(tmpdir, default_domain): "data/test_stories/stories.md", default_domain ) out_path = tmpdir.join("persisted_story.md") - with open(out_path.strpath, "w", encoding="utf-8") as f: + with open(out_path.strpath, "w", encoding=DEFAULT_ENCODING) as f: f.write(graph.as_story_string()) recovered_trackers = await training.load_data( diff --git a/tests/nlu/base/test_training_data.py b/tests/nlu/base/test_training_data.py index 9db3f2d56f8d..935d00354843 100644 --- a/tests/nlu/base/test_training_data.py +++ b/tests/nlu/base/test_training_data.py @@ -268,7 +268,7 @@ def test_repeated_entities(): } }""" with tempfile.NamedTemporaryFile(suffix="_tmp_training_data.json") as f: - f.write(data.encode("utf-8")) + f.write(data.encode(io_utils.DEFAULT_ENCODING)) f.flush() td = training_data.load_data(f.name) assert len(td.entity_examples) == 1 @@ -302,7 +302,7 @@ def test_multiword_entities(): } }""" with tempfile.NamedTemporaryFile(suffix="_tmp_training_data.json") as f: - f.write(data.encode("utf-8")) + f.write(data.encode(io_utils.DEFAULT_ENCODING)) f.flush() td = training_data.load_data(f.name) assert len(td.entity_examples) == 1 @@ -334,7 +334,7 @@ def test_nonascii_entities(): ] }""" with tempfile.NamedTemporaryFile(suffix="_tmp_training_data.json") as f: - f.write(data.encode("utf-8")) + f.write(data.encode(io_utils.DEFAULT_ENCODING)) f.flush() td = training_data.load_data(f.name) assert len(td.entity_examples) == 1 @@ -387,7 +387,7 @@ def test_entities_synonyms(): } }""" with tempfile.NamedTemporaryFile(suffix="_tmp_training_data.json") as f: - f.write(data.encode("utf-8")) + f.write(data.encode(io_utils.DEFAULT_ENCODING)) f.flush() td = training_data.load_data(f.name) assert td.entity_synonyms["New York City"] == "nyc" @@ -515,7 +515,9 @@ def test_url_data_format(): } }""" fname = io_utils.create_temporary_file( - data.encode("utf-8"), suffix="_tmp_training_data.json", mode="w+b" + data.encode(io_utils.DEFAULT_ENCODING), + suffix="_tmp_training_data.json", + mode="w+b", ) data = io_utils.read_json_file(fname) assert data is not None diff --git a/tests/test_data.py b/tests/test_data.py new file mode 100644 index 000000000000..4dcdb4bb6d82 --- /dev/null +++ b/tests/test_data.py @@ -0,0 +1,33 @@ +import glob +import os + +from pathlib import Path + +from rasa.data import is_story_file +from rasa.utils.io import write_text_file + + +def test_story_file_can_not_be_yml(tmpdir: Path): + p = tmpdir / "test_non_md.yml" + Path(p).touch() + assert is_story_file(str()) is False + + +def test_empty_story_file_is_not_story_file(tmpdir: Path): + p = tmpdir / "test_non_md.md" + Path(p).touch() + assert is_story_file(str(p)) is False + + +def test_story_file_with_minimal_story_is_story_file(tmpdir: Path): + p = tmpdir / "story.md" + s = """ +## my story + """ + write_text_file(s, p) + assert is_story_file(str(p)) + + +def test_default_story_files_are_story_files(): + for fn in glob.glob(os.path.join("data", "test_stories", "*")): + assert is_story_file(fn) diff --git a/tests/utils/test_io.py b/tests/utils/test_io.py index 11796b2574bf..c870d7f2449f 100644 --- a/tests/utils/test_io.py +++ b/tests/utils/test_io.py @@ -146,7 +146,7 @@ def test_emojis_in_tmp_file(): - two £ (?u)\\b\\w+\\b f\u00fcr """ test_file = io_utils.create_temporary_file(test_data) - with open(test_file, mode="r", encoding="utf-8") as f: + with open(test_file, mode="r", encoding=io_utils.DEFAULT_ENCODING) as f: content = f.read() content = io_utils.read_yaml(content) @@ -280,7 +280,7 @@ def test_list_directory( sub_sub_directory.mkdir() sub_sub_file = sub_sub_directory / "sub_file.txt" - sub_sub_file.write_text("", encoding="utf-8") + sub_sub_file.write_text("", encoding=io_utils.DEFAULT_ENCODING) file1 = subdirectory / "file.txt" file1.write_text("", encoding="utf-8")