Skip to content

Commit

Permalink
Domain file changes for the 2.0 format
Browse files Browse the repository at this point in the history
- Add possibility to split the domain into separate files
  • Loading branch information
Alexander Khizov committed Jul 8, 2020
1 parent 87f2e9d commit 0b7b3d7
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 93 deletions.
3 changes: 2 additions & 1 deletion .github/scripts/mr_generate_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def generate_json(file, task, data):

return data


def read_results(file):
with open(file) as json_file:
data = json.load(json_file)
Expand All @@ -53,7 +54,7 @@ def read_results(file):
if f not in task_mapping.keys():
continue

data = generate_json(os.path.join(dirpath, f),task_mapping[f], data)
data = generate_json(os.path.join(dirpath, f), task_mapping[f], data)

with open(SUMMARY_FILE, "w") as f:
json.dump(data, f, sort_keys=True, indent=2)
5 changes: 5 additions & 0 deletions changelog/6132.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Added possibility to split the domain into separate files. All YAML files under the path specified with ``--domain`` will be scanned for domain information (e.g. intents, actions, etc) and then combined into a single domain.

The default value for ``--domain`` is still ``domain.yml``.

Also, the default session expiration time is set to 60 minutes now.
4 changes: 3 additions & 1 deletion rasa/cli/arguments/default_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def add_domain_param(
"--domain",
type=str,
default=DEFAULT_DOMAIN_PATH,
help="Domain specification (yml file).",
help="Domain specification. It can be a single 'yaml' file, or a directory "
"that contains several files with domain specification in it. The content "
"of these files will be read and merged together.",
)


Expand Down
2 changes: 1 addition & 1 deletion rasa/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
ENV_SANIC_WORKERS = "SANIC_WORKERS"
ENV_SANIC_BACKLOG = "SANIC_BACKLOG"

DEFAULT_SESSION_EXPIRATION_TIME_IN_MINUTES = 60
DEFAULT_SESSION_EXPIRATION_TIME_IN_MINUTES = 0
DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION = True

ENV_GPU_CONFIG = "TF_GPU_MEMORY_ALLOC"
Expand Down
123 changes: 80 additions & 43 deletions rasa/core/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from pathlib import Path
from typing import Any, Dict, List, NamedTuple, Optional, Set, Text, Tuple, Union

from ruamel.yaml import YAMLError

import rasa.core.constants
from rasa.utils.common import (
raise_warning,
Expand All @@ -19,6 +21,7 @@
DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION,
DOMAIN_SCHEMA_FILE,
DOCS_URL_DOMAINS,
DEFAULT_SESSION_EXPIRATION_TIME_IN_MINUTES,
)
from rasa.core import utils
from rasa.core.actions import action # pytype: disable=pyi-error
Expand Down Expand Up @@ -47,6 +50,23 @@
USE_ENTITIES_KEY = "use_entities"
IGNORE_ENTITIES_KEY = "ignore_entities"

KEY_SLOTS = "slots"
KEY_INTENTS = "intents"
KEY_ENTITIES = "entities"
KEY_RESPONSES = "responses"
KEY_ACTIONS = "actions"
KEY_FORMS = "forms"

ALL_DOMAIN_KEYS = [
KEY_SLOTS,
KEY_FORMS,
KEY_ACTIONS,
KEY_ENTITIES,
KEY_INTENTS,
KEY_RESPONSES,
]


if typing.TYPE_CHECKING:
from rasa.core.trackers import DialogueStateTracker

Expand All @@ -69,7 +89,10 @@ class SessionConfig(NamedTuple):
@staticmethod
def default() -> "SessionConfig":
# TODO: 2.0, reconsider how to apply sessions to old projects
return SessionConfig(0, DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION)
return SessionConfig(
DEFAULT_SESSION_EXPIRATION_TIME_IN_MINUTES,
DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION,
)

def are_sessions_enabled(self) -> bool:
return self.session_expiration_time > 0
Expand Down Expand Up @@ -134,7 +157,7 @@ def from_yaml(cls, yaml: Text) -> "Domain":

@classmethod
def from_dict(cls, data: Dict) -> "Domain":
utter_templates = cls.collect_templates(data.get("responses", {}))
utter_templates = cls.collect_templates(data.get(KEY_RESPONSES, {}))
if "templates" in data:
raise_warning(
"Your domain file contains the key: 'templates'. This has been "
Expand All @@ -146,54 +169,45 @@ def from_dict(cls, data: Dict) -> "Domain":
)
utter_templates = cls.collect_templates(data.get("templates", {}))

slots = cls.collect_slots(data.get("slots", {}))
slots = cls.collect_slots(data.get(KEY_SLOTS, {}))
additional_arguments = data.get("config", {})
session_config = cls._get_session_config(data.get(SESSION_CONFIG_KEY, {}))
intents = data.get("intents", {})
intents = data.get(KEY_INTENTS, {})

return cls(
intents,
data.get("entities", []),
data.get(KEY_ENTITIES, []),
slots,
utter_templates,
data.get("actions", []),
data.get("forms", []),
data.get(KEY_ACTIONS, []),
data.get(KEY_FORMS, []),
session_config=session_config,
**additional_arguments,
)

@staticmethod
def _get_session_config(session_config: Dict) -> SessionConfig:
session_expiration_time = session_config.get(SESSION_EXPIRATION_TIME_KEY)
session_expiration_time_min = session_config.get(SESSION_EXPIRATION_TIME_KEY)

# TODO: 2.0 reconsider how to apply sessions to old projects and legacy trackers
if session_expiration_time is None:
raise_warning(
"No tracker session configuration was found in the loaded domain. "
"Domains without a session config will automatically receive a "
"session expiration time of 60 minutes in Rasa version 2.0 if not "
"configured otherwise.",
FutureWarning,
docs=DOCS_URL_DOMAINS + "#session-configuration",
)
session_expiration_time = 0
if session_expiration_time_min is None:
session_expiration_time_min = DEFAULT_SESSION_EXPIRATION_TIME_IN_MINUTES

carry_over_slots = session_config.get(
CARRY_OVER_SLOTS_KEY, DEFAULT_CARRY_OVER_SLOTS_TO_NEW_SESSION
)

return SessionConfig(session_expiration_time, carry_over_slots)
return SessionConfig(session_expiration_time_min, carry_over_slots)

@classmethod
def from_directory(cls, path: Text) -> "Domain":
"""Loads and merges multiple domain files recursively from a directory tree."""
from rasa import data

domain = Domain.empty()
for root, _, files in os.walk(path, followlinks=True):
for file in files:
full_path = os.path.join(root, file)
if data.is_domain_file(full_path):
if Domain.is_domain_file(full_path):
other = Domain.from_file(full_path)
domain = other.merge(domain)

Expand Down Expand Up @@ -236,20 +250,20 @@ def merge_lists(l1: List[Any], l2: List[Any]) -> List[Any]:
combined[SESSION_CONFIG_KEY] = domain_dict[SESSION_CONFIG_KEY]

# intents is list of dicts
intents_1 = {list(i.keys())[0]: i for i in combined["intents"]}
intents_2 = {list(i.keys())[0]: i for i in domain_dict["intents"]}
intents_1 = {list(i.keys())[0]: i for i in combined[KEY_INTENTS]}
intents_2 = {list(i.keys())[0]: i for i in domain_dict[KEY_INTENTS]}
merged_intents = merge_dicts(intents_1, intents_2, override)
combined["intents"] = list(merged_intents.values())
combined[KEY_INTENTS] = list(merged_intents.values())

# remove existing forms from new actions
for form in combined["forms"]:
if form in domain_dict["actions"]:
domain_dict["actions"].remove(form)
for form in combined[KEY_FORMS]:
if form in domain_dict[KEY_ACTIONS]:
domain_dict[KEY_ACTIONS].remove(form)

for key in ["entities", "actions", "forms"]:
for key in [KEY_ENTITIES, KEY_ACTIONS, KEY_FORMS]:
combined[key] = merge_lists(combined[key], domain_dict[key])

for key in ["responses", "slots"]:
for key in [KEY_RESPONSES, KEY_SLOTS]:
combined[key] = merge_dicts(combined[key], domain_dict[key], override)

return self.__class__.from_dict(combined)
Expand Down Expand Up @@ -431,8 +445,8 @@ def __init__(
def __hash__(self) -> int:

self_as_dict = self.as_dict()
self_as_dict["intents"] = sort_list_of_dicts_by_first_key(
self_as_dict["intents"]
self_as_dict[KEY_INTENTS] = sort_list_of_dicts_by_first_key(
self_as_dict[KEY_INTENTS]
)
self_as_string = json.dumps(self_as_dict, sort_keys=True)
text_hash = utils.get_text_hash(self_as_string)
Expand Down Expand Up @@ -774,12 +788,12 @@ def as_dict(self) -> Dict[Text, Any]:
SESSION_EXPIRATION_TIME_KEY: self.session_config.session_expiration_time,
CARRY_OVER_SLOTS_KEY: self.session_config.carry_over_slots,
},
"intents": self._transform_intents_for_file(),
"entities": self.entities,
"slots": self._slot_definitions(),
"responses": self.templates,
"actions": self.user_actions, # class names of the actions
"forms": self.form_names,
KEY_INTENTS: self._transform_intents_for_file(),
KEY_ENTITIES: self.entities,
KEY_SLOTS: self._slot_definitions(),
KEY_RESPONSES: self.templates,
KEY_ACTIONS: self.user_actions, # class names of the actions
KEY_FORMS: self.form_names,
}

def persist(self, filename: Union[Text, Path]) -> None:
Expand Down Expand Up @@ -827,16 +841,16 @@ def cleaned_domain(self) -> Dict[Text, Any]:
"""
domain_data = self.as_dict()

for idx, intent_info in enumerate(domain_data["intents"]):
for idx, intent_info in enumerate(domain_data[KEY_INTENTS]):
for name, intent in intent_info.items():
if intent.get(USE_ENTITIES_KEY) is True:
del intent[USE_ENTITIES_KEY]
if not intent.get(IGNORE_ENTITIES_KEY):
intent.pop(IGNORE_ENTITIES_KEY, None)
if len(intent) == 0:
domain_data["intents"][idx] = name
domain_data[KEY_INTENTS][idx] = name

for slot in domain_data["slots"].values(): # pytype: disable=attribute-error
for slot in domain_data[KEY_SLOTS].values(): # pytype: disable=attribute-error
if slot["initial_value"] is None:
del slot["initial_value"]
if slot["auto_fill"]:
Expand Down Expand Up @@ -1040,9 +1054,9 @@ def get_duplicate_exception_message(
raise InvalidDomain(
get_exception_message(
[
(duplicate_actions, "actions"),
(duplicate_slots, "slots"),
(duplicate_entities, "entities"),
(duplicate_actions, KEY_ACTIONS),
(duplicate_slots, KEY_SLOTS),
(duplicate_entities, KEY_ENTITIES),
],
incorrect_mappings,
)
Expand Down Expand Up @@ -1074,6 +1088,29 @@ def is_empty(self) -> bool:

return self.as_dict() == Domain.empty().as_dict()

@staticmethod
def is_domain_file(filename: Text) -> bool:
"""Checks whether the given file path is a Rasa domain file.
Args:
filename: Path of the file which should be checked.
Returns:
`True` if it's a domain file, otherwise `False`.
"""
from rasa.data import YAML_FILE_EXTENSIONS

if not Path(filename).suffix in YAML_FILE_EXTENSIONS:
return False
try:
content = rasa.utils.io.read_yaml_file(filename)
if any(key in content for key in ALL_DOMAIN_KEYS):
return True
except YAMLError:
pass

return False


class TemplateDomain(Domain):
pass
15 changes: 0 additions & 15 deletions rasa/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,21 +185,6 @@ def is_end_to_end_conversation_test_file(file_path: Text) -> bool:
)


def is_domain_file(file_path: Text) -> bool:
"""Checks whether the given file path is a Rasa domain file.
Args:
file_path: Path of the file which should be checked.
Returns:
`True` if it's a domain file, otherwise `False`.
"""

file_name = os.path.basename(file_path)

return file_name in ["domain.yml", "domain.yaml"]


def is_config_file(file_path: Text) -> bool:
"""Checks whether the given file path is a Rasa config file.
Expand Down
2 changes: 1 addition & 1 deletion rasa/importers/multi_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _init_from_directory(self, path: Text):

if data.is_end_to_end_conversation_test_file(full_path):
self._e2e_story_paths.append(full_path)
elif data.is_domain_file(full_path):
elif Domain.is_domain_file(full_path):
self._domain_paths.append(full_path)
elif data.is_nlu_file(full_path):
self._nlu_paths.append(full_path)
Expand Down
Loading

0 comments on commit 0b7b3d7

Please sign in to comment.