#7329 load models in finetune mode core #7458

Merged
merged 18 commits on Dec 9, 2020
12 changes: 11 additions & 1 deletion rasa/core/agent.py
@@ -411,6 +411,8 @@ def load(
model_server: Optional[EndpointConfig] = None,
remote_storage: Optional[Text] = None,
path_to_model_archive: Optional[Text] = None,
new_config: Optional[Dict] = None,
finetuning_epoch_fraction: float = 1.0,
) -> "Agent":
"""Load a persisted model from the passed path."""
try:
@@ -441,7 +443,15 @@ def load(

if core_model:
domain = Domain.load(os.path.join(core_model, DEFAULT_DOMAIN_PATH))
ensemble = PolicyEnsemble.load(core_model) if core_model else None
ensemble = (
PolicyEnsemble.load(
core_model,
new_config=new_config,
finetuning_epoch_fraction=finetuning_epoch_fraction,
)
if core_model
else None
)

# ensures the domain hasn't changed between test and train
domain.compare_with_specification(core_model)
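For reference, a minimal usage sketch (not part of this diff) of how a caller might use the two new arguments; the model archive path and the config dict below are hypothetical examples:

# Minimal usage sketch, not part of the diff: load a trained Core model in
# fine-tuning mode. The archive path and the config dict are hypothetical.
from rasa.core.agent import Agent

new_config = {
    "policies": [
        {"name": "MemoizationPolicy"},
        {"name": "TEDPolicy", "epochs": 40},
        {"name": "RulePolicy"},
    ]
}

agent = Agent.load(
    "models/20201209-core.tar.gz",   # hypothetical model archive
    new_config=new_config,           # forwarded to PolicyEnsemble.load
    finetuning_epoch_fraction=0.5,   # fine-tune with half the configured epochs
)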
60 changes: 55 additions & 5 deletions rasa/core/policies/ensemble.py
@@ -1,12 +1,13 @@
import importlib
import json
import logging
import math
import os
import sys
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Text, Optional, Any, List, Dict, Tuple, NamedTuple, Union
from typing import Text, Optional, Any, List, Dict, Tuple, Type, Union

import rasa.core
import rasa.core.training.training
@@ -41,6 +42,7 @@
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.core.generator import TrackerWithCachedStates
from rasa.core import registry
from rasa.utils.tensorflow.constants import EPOCHS

logger = logging.getLogger(__name__)

@@ -302,20 +304,68 @@ def _ensure_loaded_policy(cls, policy, policy_cls, policy_name: Text):
"".format(policy_name)
)

@classmethod
def load(cls, path: Union[Text, Path]) -> "PolicyEnsemble":
"""Loads policy and domain specification from storage"""
@staticmethod
def _get_updated_epochs(
policy_cls: Type[Policy],
config_for_policy: Dict[Text, Any],
finetuning_epoch_fraction: float,
) -> Optional[int]:
if EPOCHS in config_for_policy:
epochs = config_for_policy[EPOCHS]
else:
try:
epochs = policy_cls.defaults[EPOCHS]
except (KeyError, AttributeError):
return None
return math.ceil(epochs * finetuning_epoch_fraction)

@classmethod
def load(
cls,
path: Union[Text, Path],
new_config: Optional[Dict] = None,
finetuning_epoch_fraction: float = 1.0,
) -> "PolicyEnsemble":
"""Loads policy and domain specification from disk."""
metadata = cls.load_metadata(path)
cls.ensure_model_compatibility(metadata)
policies = []
for i, policy_name in enumerate(metadata["policy_names"]):
policy_cls = registry.policy_from_module_path(policy_name)
dir_name = f"policy_{i}_{policy_cls.__name__}"
policy_path = os.path.join(path, dir_name)
policy = policy_cls.load(policy_path)

context = {}
if new_config:
context["should_finetune"] = True

config_for_policy = new_config["policies"][i]
epochs = cls._get_updated_epochs(
policy_cls, config_for_policy, finetuning_epoch_fraction
)
if epochs:
context["epoch_override"] = epochs

if "kwargs" not in rasa.shared.utils.common.arguments_of(policy_cls.load):
if context:
raise UnsupportedDialogueModelError(
f"`{policy_cls.__name__}.{policy_cls.load.__name__}` does not "
f"accept `**kwargs`. Attempting to pass {context} to the "
f"policy. `**kwargs` should be added to all policies by "
f"Rasa Open Source 3.0.0."
)
else:
rasa.shared.utils.io.raise_deprecation_warning(
f"`{policy_cls.__name__}.{policy_cls.load.__name__}` does not "
f"accept `**kwargs`. `**kwargs` are required for contextual "
f"information e.g. the flag `should_finetune`.",
warn_until_version="3.0.0",
)

policy = policy_cls.load(policy_path, **context)
cls._ensure_loaded_policy(policy, policy_cls, policy_name)
policies.append(policy)

ensemble_cls = rasa.shared.utils.common.class_from_module_path(
metadata["ensemble_name"]
)
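The override itself is simple arithmetic: take the epochs configured for the policy (falling back to the policy class defaults) and multiply by `finetuning_epoch_fraction`, rounding up. A standalone sketch of that logic; `updated_epochs` below is an illustrative stand-in, not the Rasa helper:

import math
from typing import Any, Dict, Optional

EPOCHS = "epochs"  # mirrors rasa.utils.tensorflow.constants.EPOCHS


def updated_epochs(
    config_for_policy: Dict[str, Any],
    policy_defaults: Dict[str, Any],
    finetuning_epoch_fraction: float,
) -> Optional[int]:
    """Standalone stand-in for PolicyEnsemble._get_updated_epochs."""
    # Prefer the epochs configured for this policy, fall back to the policy
    # class defaults, and return None if neither defines an epoch count.
    epochs = config_for_policy.get(EPOCHS, policy_defaults.get(EPOCHS))
    if epochs is None:
        return None
    return math.ceil(epochs * finetuning_epoch_fraction)


# 100 configured epochs with a fraction of 0.25 -> 25 fine-tuning epochs.
assert updated_epochs({EPOCHS: 100}, {}, 0.25) == 25
# MemoizationPolicy-style configs define no epochs -> no override is set.
assert updated_epochs({"name": "MemoizationPolicy"}, {}, 0.25) is None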
4 changes: 3 additions & 1 deletion rasa/core/policies/fallback.py
@@ -39,10 +39,12 @@ def __init__(
ambiguity_threshold: float = DEFAULT_NLU_FALLBACK_AMBIGUITY_THRESHOLD,
core_threshold: float = DEFAULT_CORE_FALLBACK_THRESHOLD,
fallback_action_name: Text = ACTION_DEFAULT_FALLBACK_NAME,
**kwargs: Any,
) -> None:
"""Create a new Fallback policy.

Args:
priority: Fallback policy priority.
core_threshold: if NLU confidence threshold is met,
predict fallback action with confidence `core_threshold`.
If this is the highest confidence in the ensemble,
@@ -54,7 +56,7 @@ def __init__(
between confidences of the top two predictions
fallback_action_name: name of the action to execute as a fallback
"""
super().__init__(priority=priority)
super().__init__(priority=priority, **kwargs)

self.nlu_threshold = nlu_threshold
self.ambiguity_threshold = ambiguity_threshold
7 changes: 6 additions & 1 deletion rasa/core/policies/form_policy.py
@@ -35,12 +35,17 @@ def __init__(
featurizer: Optional[TrackerFeaturizer] = None,
priority: int = FORM_POLICY_PRIORITY,
lookup: Optional[Dict] = None,
**kwargs: Any,
) -> None:

# max history is set to 2 in order to capture
# previous meaningful action before action listen
super().__init__(
featurizer=featurizer, priority=priority, max_history=2, lookup=lookup
featurizer=featurizer,
priority=priority,
max_history=2,
lookup=lookup,
**kwargs,
)

rasa.shared.utils.io.raise_deprecation_warning(
5 changes: 2 additions & 3 deletions rasa/core/policies/mapping_policy.py
@@ -43,10 +43,9 @@ class MappingPolicy(Policy):
def _standard_featurizer() -> None:
return None

def __init__(self, priority: int = MAPPING_POLICY_PRIORITY) -> None:
def __init__(self, priority: int = MAPPING_POLICY_PRIORITY, **kwargs: Any) -> None:
"""Create a new Mapping policy."""

super().__init__(priority=priority)
super().__init__(priority=priority, **kwargs)

rasa.shared.utils.io.raise_deprecation_warning(
f"'{MappingPolicy.__name__}' is deprecated and will be removed in "
4 changes: 2 additions & 2 deletions rasa/core/policies/memoization.py
@@ -71,6 +71,7 @@ def __init__(
priority: int = MEMOIZATION_POLICY_PRIORITY,
max_history: Optional[int] = MAX_HISTORY_NOT_SET,
lookup: Optional[Dict] = None,
**kwargs: Any,
) -> None:
"""Initialize the policy.

@@ -81,7 +82,6 @@ def __init__(
lookup: a dictionary that stores featurized tracker states and
predicted actions for them
"""

if max_history == MAX_HISTORY_NOT_SET:
max_history = OLD_DEFAULT_MAX_HISTORY # old default value
rasa.shared.utils.io.raise_warning(
@@ -97,7 +97,7 @@ def __init__(
if not featurizer:
featurizer = self._standard_featurizer(max_history)

super().__init__(featurizer, priority)
super().__init__(featurizer, priority, **kwargs)

self.max_history = self.featurizer.max_history
self.lookup = lookup if lookup is not None else {}
28 changes: 27 additions & 1 deletion rasa/core/policies/policy.py
@@ -16,6 +16,8 @@
TYPE_CHECKING,
)
import numpy as np

from rasa.core.exceptions import UnsupportedDialogueModelError
from rasa.shared.core.events import Event

import rasa.shared.utils.common
@@ -34,6 +36,7 @@
from rasa.core.constants import DEFAULT_POLICY_PRIORITY
from rasa.shared.core.constants import USER, SLOTS, PREVIOUS_ACTION, ACTIVE_LOOP
from rasa.shared.nlu.constants import ENTITIES, INTENT, TEXT, ACTION_TEXT, ACTION_NAME
from rasa.utils.tensorflow.constants import EPOCHS

if TYPE_CHECKING:
from rasa.shared.nlu.training_data.features import Features
@@ -110,12 +113,15 @@ def __init__(
self,
featurizer: Optional[TrackerFeaturizer] = None,
priority: int = DEFAULT_POLICY_PRIORITY,
**kwargs: Any,
) -> None:
"""Constructs a new Policy object."""
self.__featurizer = self._create_featurizer(featurizer)
self.priority = priority

@property
def featurizer(self):
"""Returns the policy's featurizer."""
return self.__featurizer

@staticmethod
Expand Down Expand Up @@ -272,7 +278,7 @@ def persist(self, path: Union[Text, Path]) -> None:
rasa.shared.utils.io.dump_obj_as_json_to_file(file, self._metadata())

@classmethod
def load(cls, path: Union[Text, Path]) -> "Policy":
def load(cls, path: Union[Text, Path], **kwargs: Any) -> "Policy":
"""Loads a policy from path.

Args:
@@ -290,6 +296,26 @@ def load(cls, path: Union[Text, Path]) -> "Policy":
featurizer = TrackerFeaturizer.load(path)
data["featurizer"] = featurizer

if "should_finetune" in kwargs:
data["should_finetune"] = kwargs["should_finetune"]

constructor_args = rasa.shared.utils.common.arguments_of(cls)
if "kwargs" not in constructor_args:
if set(data.keys()).issubset(set(constructor_args)):
rasa.shared.utils.io.raise_deprecation_warning(
f"`{cls.__name__}.__init__` does not accept `**kwargs` "
f"This is required for contextual information e.g. the flag "
f"`should_finetune`.",
warn_until_version="3.0.0",
)
else:
raise UnsupportedDialogueModelError(
f"`{cls.__name__}.__init__` does not accept `**kwargs`. "
f"Attempting to pass {data} to the policy. "
f"This argument should be added to all policies by "
f"Rasa Open Source 3.0.0."
)

return cls(**data)

logger.info(
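The practical consequence is that custom policies should accept `**kwargs` in `__init__` (and in `load`) so that load-time context such as `should_finetune` can always be passed through; otherwise loading either warns or, when the extra keys cannot be dropped, raises `UnsupportedDialogueModelError`. An illustrative, standalone sketch of why the signature matters (the two classes and the `arguments_of` stand-in below are examples, not Rasa code):

# Illustrative sketch, not Rasa code: why `**kwargs` in a policy constructor
# matters for the new load-time check. `arguments_of` is a stand-in for
# rasa.shared.utils.common.arguments_of.
import inspect
from typing import Any, List


def arguments_of(func) -> List[str]:
    """Return the parameter names of a callable."""
    return list(inspect.signature(func).parameters.keys())


class LegacyPolicy:
    # No **kwargs: Policy.load can only warn (or raise if extra keys
    # such as `should_finetune` cannot be dropped).
    def __init__(self, priority: int = 1) -> None:
        self.priority = priority


class FinetunablePolicy:
    # Accepts **kwargs, so load-time context is always passed through.
    def __init__(self, priority: int = 1, **kwargs: Any) -> None:
        self.priority = priority
        self.should_finetune = kwargs.get("should_finetune", False)


print("kwargs" in arguments_of(LegacyPolicy.__init__))       # False -> deprecation path
print("kwargs" in arguments_of(FinetunablePolicy.__init__))  # True  -> fine-tuning context accepted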
11 changes: 10 additions & 1 deletion rasa/core/policies/rule_policy.py
@@ -108,6 +108,7 @@ def __init__(
enable_fallback_prediction: bool = True,
restrict_rules: bool = True,
check_for_contradictions: bool = True,
**kwargs: Any,
) -> None:
"""Create a `RulePolicy` object.

@@ -124,6 +125,10 @@
if no rule matched.
enable_fallback_prediction: If `True` `core_fallback_action_name` is
predicted in case no rule matched.
restrict_rules: If `True`, rules are restricted to contain a maximum of 1
user message. This is used to prevent users from building a state machine
using the rules.
check_for_contradictions: Check for contradictions.
"""
self._core_fallback_threshold = core_fallback_threshold
self._fallback_action_name = core_fallback_action_name
@@ -136,7 +141,11 @@

# max history is set to `None` in order to capture any lengths of rule stories
super().__init__(
featurizer=featurizer, priority=priority, max_history=None, lookup=lookup
featurizer=featurizer,
priority=priority,
max_history=None,
lookup=lookup,
**kwargs,
)

@classmethod
13 changes: 9 additions & 4 deletions rasa/core/policies/sklearn_policy.py
@@ -27,7 +27,7 @@
from sklearn.preprocessing import LabelEncoder
from rasa.shared.nlu.constants import ACTION_TEXT, TEXT
from rasa.shared.nlu.training_data.features import Features
from rasa.utils.tensorflow.constants import SENTENCE
from rasa.utils.tensorflow.constants import EPOCHS, SENTENCE
from rasa.utils.tensorflow.model_data import Data

# noinspection PyProtectedMember
@@ -72,6 +72,8 @@ def __init__(
Args:
featurizer: Featurizer used to convert the training data into
vector format.
priority: Policy priority
max_history: Maximum history of the dialogs.
model: The sklearn model or model pipeline.
param_grid: If *param_grid* is not None and *cv* is given,
a grid search on the given *param_grid* is performed
@@ -85,7 +87,6 @@ def __init__(
shuffle: Whether to shuffle training data.
zero_state_features: Contains default feature values for attributes
"""

if featurizer:
if not isinstance(featurizer, MaxHistoryTrackerFeaturizer):
raise TypeError(
@@ -104,7 +105,7 @@ def __init__(
)
featurizer = self._standard_featurizer(max_history)

super().__init__(featurizer, priority)
super().__init__(featurizer, priority, **kwargs)

self.model = model or self._default_model()
self.cv = cv
@@ -302,7 +303,8 @@ def persist(self, path: Union[Text, Path]) -> None:
)

@classmethod
def load(cls, path: Union[Text, Path]) -> Policy:
def load(cls, path: Union[Text, Path], **kwargs: Any) -> Policy:
"""See the docstring for `Policy.load`."""
filename = Path(path) / "sklearn_model.pkl"
zero_features_filename = Path(path) / "zero_state_features.pkl"
if not Path(path).exists():
@@ -321,10 +323,13 @@ def load(cls, path: Union[Text, Path]) -> Policy:
meta = json.loads(rasa.shared.utils.io.read_file(meta_file))
zero_state_features = io_utils.pickle_load(zero_features_filename)

data = {"should_finetune": kwargs.get("should_finetune", False)}

policy = cls(
featurizer=featurizer,
priority=meta["priority"],
zero_state_features=zero_state_features,
**data,
)

state = io_utils.pickle_load(filename)
10 changes: 7 additions & 3 deletions rasa/core/policies/ted_policy.py
@@ -220,11 +220,10 @@ def __init__(
**kwargs: Any,
) -> None:
"""Declare instance variables with default values."""

if not featurizer:
featurizer = self._standard_featurizer(max_history)

super().__init__(featurizer, priority)
super().__init__(featurizer, priority, **kwargs)
if isinstance(featurizer, FullDialogueTrackerFeaturizer):
self.is_full_dialogue_featurizer_used = True
else:
@@ -437,8 +436,9 @@ def persist(self, path: Union[Text, Path]) -> None:
)

@classmethod
def load(cls, path: Union[Text, Path]) -> "TEDPolicy":
def load(cls, path: Union[Text, Path], **kwargs: Any) -> "TEDPolicy":
"""Loads a policy from the storage.

**Needs to load its featurizer**
"""
model_path = Path(path)
@@ -500,6 +500,10 @@ def load(cls, path: Union[Text, Path]) -> "TEDPolicy":
)
model.build_for_predict(predict_data_example)

meta["should_finetune"] = kwargs.get("should_finetune", False)
if "epoch_override" in kwargs:
meta[EPOCHS] = kwargs["epoch_override"]

Contributor comment on lines +503 to +506:

This needs a larger change which I'll address in my PR as part of making core policies finetunable. It's okay to keep as it is here.

return cls(
featurizer=featurizer,
priority=priority,
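For TED, the fine-tuning context arrives as plain keyword arguments to `load`: `should_finetune` is stored in the loaded model's meta, and `epoch_override`, when present, replaces `meta[EPOCHS]`. A usage sketch, assuming a hypothetical path to a persisted TED policy directory:

# Usage sketch, assuming a hypothetical persisted policy directory: the
# fine-tuning context reaches TEDPolicy.load as plain keyword arguments.
from rasa.core.policies.ted_policy import TEDPolicy

policy = TEDPolicy.load(
    "models/core/policy_1_TEDPolicy",  # hypothetical path, e.g. from an unpacked model
    should_finetune=True,              # stored in the loaded model's meta
    epoch_override=20,                 # replaces meta["epochs"] before fine-tuning
)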