Skip to content

Commit

Permalink
Add possibility to split the domain into separate files
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander Khizov committed Jul 7, 2020
1 parent 87f2e9d commit d50d039
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 38 deletions.
1 change: 1 addition & 0 deletions changelog/6132.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add the possibility to split the domain into separate files. The domain file no longer has to be named `domain.yaml`.
101 changes: 70 additions & 31 deletions rasa/core/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from pathlib import Path
from typing import Any, Dict, List, NamedTuple, Optional, Set, Text, Tuple, Union

from ruamel.yaml import YAMLError

import rasa.core.constants
from rasa.data import YAML_FILE_EXTENSIONS
from rasa.utils.common import (
raise_warning,
lazy_property,
Expand Down Expand Up @@ -47,6 +50,23 @@
USE_ENTITIES_KEY = "use_entities"
IGNORE_ENTITIES_KEY = "ignore_entities"

# Names of the top-level sections of a domain file.
KEY_INTENTS = "intents"
KEY_ENTITIES = "entities"
KEY_SLOTS = "slots"
KEY_RESPONSES = "responses"
KEY_ACTIONS = "actions"
KEY_FORMS = "forms"

# Every section key collected in one place, e.g. for checking whether a
# parsed YAML document contains any domain section at all.
ALL_DOMAIN_KEYS = [
    KEY_SLOTS, KEY_FORMS, KEY_ACTIONS,
    KEY_ENTITIES, KEY_INTENTS, KEY_RESPONSES,
]


if typing.TYPE_CHECKING:
from rasa.core.trackers import DialogueStateTracker

Expand Down Expand Up @@ -133,8 +153,9 @@ def from_yaml(cls, yaml: Text) -> "Domain":
return cls.from_dict(data)

@classmethod
def from_dict(cls, data: Dict) -> "Domain":
utter_templates = cls.collect_templates(data.get("responses", {}))
def from_dict(cls, data: Dict) -> Optional["Domain"]:

utter_templates = cls.collect_templates(data.get(KEY_RESPONSES, {}))
if "templates" in data:
raise_warning(
"Your domain file contains the key: 'templates'. This has been "
Expand All @@ -146,18 +167,18 @@ def from_dict(cls, data: Dict) -> "Domain":
)
utter_templates = cls.collect_templates(data.get("templates", {}))

slots = cls.collect_slots(data.get("slots", {}))
slots = cls.collect_slots(data.get(KEY_SLOTS, {}))
additional_arguments = data.get("config", {})
session_config = cls._get_session_config(data.get(SESSION_CONFIG_KEY, {}))
intents = data.get("intents", {})
intents = data.get(KEY_INTENTS, {})

return cls(
intents,
data.get("entities", []),
data.get(KEY_ENTITIES, []),
slots,
utter_templates,
data.get("actions", []),
data.get("forms", []),
data.get(KEY_ACTIONS, []),
data.get(KEY_FORMS, []),
session_config=session_config,
**additional_arguments,
)
Expand Down Expand Up @@ -187,13 +208,12 @@ def _get_session_config(session_config: Dict) -> SessionConfig:
@classmethod
def from_directory(cls, path: Text) -> "Domain":
"""Loads and merges multiple domain files recursively from a directory tree."""
from rasa import data

domain = Domain.empty()
for root, _, files in os.walk(path, followlinks=True):
for file in files:
full_path = os.path.join(root, file)
if data.is_domain_file(full_path):
if Domain.is_domain_file(full_path):
other = Domain.from_file(full_path)
domain = other.merge(domain)

Expand Down Expand Up @@ -236,20 +256,20 @@ def merge_lists(l1: List[Any], l2: List[Any]) -> List[Any]:
combined[SESSION_CONFIG_KEY] = domain_dict[SESSION_CONFIG_KEY]

# intents is list of dicts
intents_1 = {list(i.keys())[0]: i for i in combined["intents"]}
intents_2 = {list(i.keys())[0]: i for i in domain_dict["intents"]}
intents_1 = {list(i.keys())[0]: i for i in combined[KEY_INTENTS]}
intents_2 = {list(i.keys())[0]: i for i in domain_dict[KEY_INTENTS]}
merged_intents = merge_dicts(intents_1, intents_2, override)
combined["intents"] = list(merged_intents.values())
combined[KEY_INTENTS] = list(merged_intents.values())

# remove existing forms from new actions
for form in combined["forms"]:
if form in domain_dict["actions"]:
domain_dict["actions"].remove(form)
for form in combined[KEY_FORMS]:
if form in domain_dict[KEY_ACTIONS]:
domain_dict[KEY_ACTIONS].remove(form)

for key in ["entities", "actions", "forms"]:
for key in [KEY_ENTITIES, KEY_ACTIONS, KEY_FORMS]:
combined[key] = merge_lists(combined[key], domain_dict[key])

for key in ["responses", "slots"]:
for key in [KEY_RESPONSES, KEY_SLOTS]:
combined[key] = merge_dicts(combined[key], domain_dict[key], override)

return self.__class__.from_dict(combined)
Expand Down Expand Up @@ -431,8 +451,8 @@ def __init__(
def __hash__(self) -> int:

self_as_dict = self.as_dict()
self_as_dict["intents"] = sort_list_of_dicts_by_first_key(
self_as_dict["intents"]
self_as_dict[KEY_INTENTS] = sort_list_of_dicts_by_first_key(
self_as_dict[KEY_INTENTS]
)
self_as_string = json.dumps(self_as_dict, sort_keys=True)
text_hash = utils.get_text_hash(self_as_string)
Expand Down Expand Up @@ -774,12 +794,12 @@ def as_dict(self) -> Dict[Text, Any]:
SESSION_EXPIRATION_TIME_KEY: self.session_config.session_expiration_time,
CARRY_OVER_SLOTS_KEY: self.session_config.carry_over_slots,
},
"intents": self._transform_intents_for_file(),
"entities": self.entities,
"slots": self._slot_definitions(),
"responses": self.templates,
"actions": self.user_actions, # class names of the actions
"forms": self.form_names,
KEY_INTENTS: self._transform_intents_for_file(),
KEY_ENTITIES: self.entities,
KEY_SLOTS: self._slot_definitions(),
KEY_RESPONSES: self.templates,
KEY_ACTIONS: self.user_actions, # class names of the actions
KEY_FORMS: self.form_names,
}

def persist(self, filename: Union[Text, Path]) -> None:
Expand Down Expand Up @@ -827,16 +847,16 @@ def cleaned_domain(self) -> Dict[Text, Any]:
"""
domain_data = self.as_dict()

for idx, intent_info in enumerate(domain_data["intents"]):
for idx, intent_info in enumerate(domain_data[KEY_INTENTS]):
for name, intent in intent_info.items():
if intent.get(USE_ENTITIES_KEY) is True:
del intent[USE_ENTITIES_KEY]
if not intent.get(IGNORE_ENTITIES_KEY):
intent.pop(IGNORE_ENTITIES_KEY, None)
if len(intent) == 0:
domain_data["intents"][idx] = name
domain_data[KEY_INTENTS][idx] = name

for slot in domain_data["slots"].values(): # pytype: disable=attribute-error
for slot in domain_data[KEY_SLOTS].values(): # pytype: disable=attribute-error
if slot["initial_value"] is None:
del slot["initial_value"]
if slot["auto_fill"]:
Expand Down Expand Up @@ -1040,9 +1060,9 @@ def get_duplicate_exception_message(
raise InvalidDomain(
get_exception_message(
[
(duplicate_actions, "actions"),
(duplicate_slots, "slots"),
(duplicate_entities, "entities"),
(duplicate_actions, KEY_ACTIONS),
(duplicate_slots, KEY_SLOTS),
(duplicate_entities, KEY_ENTITIES),
],
incorrect_mappings,
)
Expand Down Expand Up @@ -1074,6 +1094,25 @@ def is_empty(self) -> bool:

return self.as_dict() == Domain.empty().as_dict()

@staticmethod
def is_domain_file(filename: Text) -> bool:
    """Checks whether the given file path is a Rasa domain file.

    A file is considered a domain file when its suffix is one of
    `YAML_FILE_EXTENSIONS` and its parsed top-level mapping contains at
    least one of the keys listed in `ALL_DOMAIN_KEYS`.

    Args:
        filename: Path of the file which should be checked.

    Returns:
        `True` if it's a domain file, otherwise `False`.
    """
    if Path(filename).suffix not in YAML_FILE_EXTENSIONS:
        return False

    try:
        content = rasa.utils.io.read_yaml_file(filename)
    except YAMLError:
        # A file that cannot be parsed as YAML is not a domain file.
        return False

    # The parsed document may be something other than a mapping (e.g. `None`
    # for an empty file, or a list / scalar). Only a mapping can carry domain
    # sections, and `key in content` would raise or give a misleading
    # substring / element match on the other types.
    if not isinstance(content, dict):
        return False

    # Always return an explicit bool instead of falling through to `None`.
    return any(key in content for key in ALL_DOMAIN_KEYS)


class TemplateDomain(Domain):
pass
7 changes: 0 additions & 7 deletions tests/core/test_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,13 +448,6 @@ def test_load_domain_from_directory_tree(tmpdir_factory: TempdirFactory):
# Check if loading from `.yaml` also works
utils.dump_obj_as_yaml_to_file(subsubdirectory / "domain.yaml", skill_2_1_domain)

subsubdirectory_2 = subdirectory_2 / "Skill 2-2"
subsubdirectory_2.mkdir()
excluded_domain = {"actions": ["should not be loaded"]}
utils.dump_obj_as_yaml_to_file(
subsubdirectory_2 / "other_name.yaml", excluded_domain
)

actual = Domain.load(str(root))
expected = [
"utter_root",
Expand Down

0 comments on commit d50d039

Please sign in to comment.