Skip to content

Commit

Permalink
improved utf-8 file loading. #5106
Browse files Browse the repository at this point in the history
  • Loading branch information
tmbo committed Feb 10, 2020
1 parent a73989c commit 01bc984
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 14 deletions.
2 changes: 2 additions & 0 deletions changelog/5106.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fixed loading of story files that are not valid UTF-8: checking whether such a file
is a story file now fails gracefully instead of crashing.
24 changes: 18 additions & 6 deletions rasa/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,25 @@ def is_story_file(file_path: Text) -> bool:
Returns:
`True` if it's a story file, otherwise `False`.
"""
_is_story_file = False

if file_path.endswith(".md"):
with open(file_path, encoding=DEFAULT_ENCODING) as f:
_is_story_file = any(_contains_story_pattern(l) for l in f)

return _is_story_file
if not file_path.endswith(".md"):
return False

try:
with open(
file_path, encoding=DEFAULT_ENCODING, errors="surrogateescape"
) as lines:
return any(_contains_story_pattern(line) for line in lines)
except Exception as e:
# catch-all because we might be loading files we are not expecting to load
logger.error(
f"Tried to check if '{file_path}' is a story file, but failed to "
f"read it. If this file contains story data, you should "
f"investigate this error, otherwise it is probably best to "
f"move the file to a different location. "
f"Error: {e}"
)
return False


def _contains_story_pattern(text: Text) -> bool:
Expand Down
3 changes: 2 additions & 1 deletion tests/core/test_dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
MaxHistoryTrackerFeaturizer,
BinarySingleStateFeaturizer,
)
from rasa.utils.io import DEFAULT_ENCODING


async def test_can_read_test_story(default_domain):
Expand Down Expand Up @@ -73,7 +74,7 @@ async def test_persist_and_read_test_story_graph(tmpdir, default_domain):
"data/test_stories/stories.md", default_domain
)
out_path = tmpdir.join("persisted_story.md")
with open(out_path.strpath, "w", encoding="utf-8") as f:
with open(out_path.strpath, "w", encoding=DEFAULT_ENCODING) as f:
f.write(graph.as_story_string())

recovered_trackers = await training.load_data(
Expand Down
12 changes: 7 additions & 5 deletions tests/nlu/base/test_training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def test_repeated_entities():
}
}"""
with tempfile.NamedTemporaryFile(suffix="_tmp_training_data.json") as f:
f.write(data.encode("utf-8"))
f.write(data.encode(io_utils.DEFAULT_ENCODING))
f.flush()
td = training_data.load_data(f.name)
assert len(td.entity_examples) == 1
Expand Down Expand Up @@ -302,7 +302,7 @@ def test_multiword_entities():
}
}"""
with tempfile.NamedTemporaryFile(suffix="_tmp_training_data.json") as f:
f.write(data.encode("utf-8"))
f.write(data.encode(io_utils.DEFAULT_ENCODING))
f.flush()
td = training_data.load_data(f.name)
assert len(td.entity_examples) == 1
Expand Down Expand Up @@ -334,7 +334,7 @@ def test_nonascii_entities():
]
}"""
with tempfile.NamedTemporaryFile(suffix="_tmp_training_data.json") as f:
f.write(data.encode("utf-8"))
f.write(data.encode(io_utils.DEFAULT_ENCODING))
f.flush()
td = training_data.load_data(f.name)
assert len(td.entity_examples) == 1
Expand Down Expand Up @@ -387,7 +387,7 @@ def test_entities_synonyms():
}
}"""
with tempfile.NamedTemporaryFile(suffix="_tmp_training_data.json") as f:
f.write(data.encode("utf-8"))
f.write(data.encode(io_utils.DEFAULT_ENCODING))
f.flush()
td = training_data.load_data(f.name)
assert td.entity_synonyms["New York City"] == "nyc"
Expand Down Expand Up @@ -515,7 +515,9 @@ def test_url_data_format():
}
}"""
fname = io_utils.create_temporary_file(
data.encode("utf-8"), suffix="_tmp_training_data.json", mode="w+b"
data.encode(io_utils.DEFAULT_ENCODING),
suffix="_tmp_training_data.json",
mode="w+b",
)
data = io_utils.read_json_file(fname)
assert data is not None
Expand Down
33 changes: 33 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import glob
import os

from pathlib import Path

from rasa.data import is_story_file
from rasa.utils.io import write_text_file


def test_story_file_can_not_be_yml(tmpdir: Path):
    """A file without the ``.md`` extension must not be reported as a story file.

    Creates an empty ``.yml`` file in the temporary directory and checks that
    ``is_story_file`` rejects it.
    """
    p = tmpdir / "test_non_md.yml"
    Path(p).touch()
    # Pass the path of the file that was just created. The previous version
    # called ``str()`` with no argument, so the check ran against the empty
    # string and never looked at the ``.yml`` file at all.
    assert is_story_file(str(p)) is False


def test_empty_story_file_is_not_story_file(tmpdir: Path):
    """An empty ``.md`` file has no story content and must be rejected."""
    empty_markdown = tmpdir / "test_non_md.md"
    # Create the file on disk without writing anything into it.
    Path(empty_markdown).touch()

    detected = is_story_file(str(empty_markdown))

    assert detected is False


def test_story_file_with_minimal_story_is_story_file(tmpdir: Path):
    """A ``.md`` file holding only a story heading counts as a story file."""
    story_path = tmpdir / "story.md"
    # Smallest story content used by this test: a single heading line.
    minimal_story = "\n## my story\n"
    write_text_file(minimal_story, story_path)

    assert is_story_file(str(story_path))


def test_default_story_files_are_story_files():
    """Every entry under ``data/test_stories`` must be detected as a story file."""
    pattern = os.path.join("data", "test_stories", "*")
    for story_path in glob.glob(pattern):
        assert is_story_file(story_path)
4 changes: 2 additions & 2 deletions tests/utils/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def test_emojis_in_tmp_file():
- two £ (?u)\\b\\w+\\b f\u00fcr
"""
test_file = io_utils.create_temporary_file(test_data)
with open(test_file, mode="r", encoding="utf-8") as f:
with open(test_file, mode="r", encoding=io_utils.DEFAULT_ENCODING) as f:
content = f.read()
content = io_utils.read_yaml(content)

Expand Down Expand Up @@ -280,7 +280,7 @@ def test_list_directory(
sub_sub_directory.mkdir()

sub_sub_file = sub_sub_directory / "sub_file.txt"
sub_sub_file.write_text("", encoding="utf-8")
sub_sub_file.write_text("", encoding=io_utils.DEFAULT_ENCODING)

file1 = subdirectory / "file.txt"
file1.write_text("", encoding="utf-8")
Expand Down

0 comments on commit 01bc984

Please sign in to comment.