removed unused continue training
tmbo committed Feb 20, 2020
1 parent 8d19242 commit 8e15ae8
Showing 15 changed files with 17 additions and 207 deletions.
2 changes: 2 additions & 0 deletions changelog/4991.removal.rst
@@ -0,0 +1,2 @@
Removed ``Agent.continue_training`` and the ``dump_flattened_stories`` parameter
from ``Agent.persist``.
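
What to do instead, as a rough sketch for Python API callers: train and persist as before, but without the removed flag, and write the flattened stories yourself if they still need to sit next to the model. The helper call below mirrors the `training.persist_data` call the removed ensemble code used to make; the function name and paths are illustrative, not part of this change.

    # Sketch only: `agent` is an already configured rasa.core.agent.Agent and
    # `training_trackers` is the list of DialogueStateTracker used for training.
    import os
    from rasa.core import training

    def train_and_persist(agent, training_trackers, model_path):
        agent.train(training_trackers)
        # `dump_flattened_stories` is gone; persist now only takes the path.
        agent.persist(model_path)
        # Optional: dump the stories separately, as the removed warning suggests.
        training.persist_data(training_trackers, os.path.join(model_path, "stories.md"))
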
2 changes: 0 additions & 2 deletions rasa/cli/arguments/interactive.py
@@ -13,7 +13,6 @@
add_config_param,
add_out_param,
add_debug_plots_param,
add_dump_stories_param,
add_augmentation_param,
add_persist_nlu_data_param,
)
@@ -75,6 +74,5 @@ def _add_training_arguments(parser: argparse.ArgumentParser) -> argparse._Argume
)
add_augmentation_param(train_arguments)
add_debug_plots_param(train_arguments)
add_dump_stories_param(train_arguments)

return train_arguments
13 changes: 0 additions & 13 deletions rasa/cli/arguments/train.py
@@ -19,7 +19,6 @@ def set_train_arguments(parser: argparse.ArgumentParser):

add_augmentation_param(parser)
add_debug_plots_param(parser)
add_dump_stories_param(parser)

add_model_name_param(parser)
add_persist_nlu_data_param(parser)
@@ -34,7 +33,6 @@ def set_train_core_arguments(parser: argparse.ArgumentParser):

add_augmentation_param(parser)
add_debug_plots_param(parser)
add_dump_stories_param(parser)

add_force_param(parser)

@@ -109,17 +107,6 @@ def add_augmentation_param(
)


def add_dump_stories_param(
parser: Union[argparse.ArgumentParser, argparse._ActionsContainer]
):
parser.add_argument(
"--dump-stories",
default=False,
action="store_true",
help="If enabled, save flattened stories to a file.",
)


def add_debug_plots_param(
parser: Union[argparse.ArgumentParser, argparse._ActionsContainer]
):
2 changes: 0 additions & 2 deletions rasa/cli/train.py
@@ -146,8 +146,6 @@ def extract_additional_arguments(args: argparse.Namespace) -> Dict:

if "augmentation" in args:
arguments["augmentation_factor"] = args.augmentation
if "dump_stories" in args:
arguments["dump_stories"] = args.dump_stories
if "debug_plots" in args:
arguments["debug_plots"] = args.debug_plots

27 changes: 2 additions & 25 deletions rasa/core/agent.py
@@ -590,21 +590,6 @@ def toggle_memoization(self, activate: bool) -> None:
if type(p) == MemoizationPolicy:
p.toggle(activate)

def continue_training(
self, trackers: List[DialogueStateTracker], **kwargs: Any
) -> None:
raise_warning(
"Continue training will be removed in the 2.0 release. It won't be "
"possible to continue the training, you should probably retrain instead.",
FutureWarning,
)

if not self.is_core_ready():
raise AgentNotReady("Can't continue training without a policy ensemble.")

self.policy_ensemble.continue_training(trackers, self.domain, **kwargs)
self._set_fingerprint()

def _max_history(self) -> int:
"""Find maximum max_history."""

@@ -794,17 +779,9 @@ def _clear_model_directory(model_path: Text) -> None:
"overwritten.".format(model_path)
)

def persist(self, model_path: Text, dump_flattened_stories: bool = False) -> None:
def persist(self, model_path: Text) -> None:
"""Persists this agent into a directory for later loading and usage."""

if dump_flattened_stories:
raise_warning(
"The `dump_flattened_stories` argument will be removed from "
"`Agent.persist` in the 2.0 release. Please dump your "
"training data separately if you need it to be part of the model.",
FutureWarning,
)

if not self.is_core_ready():
raise AgentNotReady("Can't persist without a policy ensemble.")

@@ -813,7 +790,7 @@ def persist(self, model_path: Text, dump_flattened_stories: bool = False) -> Non

self._clear_model_directory(model_path)

self.policy_ensemble.persist(model_path, dump_flattened_stories)
self.policy_ensemble.persist(model_path)
self.domain.persist(os.path.join(model_path, DEFAULT_DOMAIN_PATH))
self.domain.persist_specification(model_path)

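The deleted `Agent.continue_training` already warned that callers "should probably retrain instead", so the replacement is simply a full retrain over the combined trackers rather than an incremental update. A minimal sketch of that pattern (keeping hold of the previous trackers is assumed to be the caller's job; nothing below is new API):

    # Sketch only: retrain from scratch instead of the removed continue_training.
    def retrain(agent, previous_trackers, new_trackers, model_path):
        all_trackers = previous_trackers + new_trackers
        agent.train(all_trackers)   # full retrain on the combined data
        agent.persist(model_path)   # persist the freshly trained model
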
36 changes: 0 additions & 36 deletions rasa/core/policies/embedding_policy.py
@@ -506,42 +506,6 @@ def train(
self.attention_weights
)

def continue_training(
self,
training_trackers: List["DialogueStateTracker"],
domain: "Domain",
**kwargs: Any,
) -> None:
"""Continue training an already trained policy."""

batch_size = kwargs.get("batch_size", 5)
epochs = kwargs.get("epochs", 50)

with self.graph.as_default():
for _ in range(epochs):
training_data = self._training_data_for_continue_training(
batch_size, training_trackers, domain
)

session_data = self._create_session_data(
training_data.X, training_data.y
)
train_dataset = train_utils.create_tf_dataset(
session_data, batch_size, label_key="action_ids"
)
train_init_op = self._iterator.make_initializer(train_dataset)
self.session.run(train_init_op)

# fit to one extra example using updated trackers
while True:
try:
self.session.run(
self._train_op, feed_dict={self._is_training: True}
)

except tf.errors.OutOfRangeError:
break

def tf_feed_dict_for_prediction(
self, tracker: "DialogueStateTracker", domain: "Domain"
) -> Dict["tf.Tensor", "np.ndarray"]:
35 changes: 8 additions & 27 deletions rasa/core/policies/ensemble.py
@@ -11,7 +11,7 @@
import rasa.utils.io
from rasa.constants import MINIMUM_COMPATIBLE_VERSION, DOCS_BASE_URL, DOCS_URL_POLICIES

from rasa.core import utils, training
from rasa.core import utils
from rasa.core.constants import USER_INTENT_BACK, USER_INTENT_RESTART
from rasa.core.actions.action import (
ACTION_LISTEN_NAME,
@@ -39,7 +39,6 @@ def __init__(
self, policies: List[Policy], action_fingerprints: Optional[Dict] = None
) -> None:
self.policies = policies
self.training_trackers = None
self.date_trained = None

if action_fingerprints:
@@ -123,9 +122,11 @@ def train(
if training_trackers:
for policy in self.policies:
policy.train(training_trackers, domain, **kwargs)

training_events = self._training_events_from_trackers(training_trackers)
self.action_fingerprints = self._create_action_fingerprints(training_events)
else:
logger.info("Skipped training, because there are no training samples.")
self.training_trackers = training_trackers
self.date_trained = datetime.now().strftime("%Y%m%d-%H%M%S")

def probabilities_using_best_policy(
@@ -172,24 +173,17 @@ def _add_package_version_info(self, metadata: Dict[Text, Any]) -> None:
except ImportError:
pass

def _persist_metadata(
self, path: Text, dump_flattened_stories: bool = False
) -> None:
def _persist_metadata(self, path: Text) -> None:
"""Persists the domain specification to storage."""

# make sure the directory we persist exists
domain_spec_path = os.path.join(path, "metadata.json")
training_data_path = os.path.join(path, "stories.md")
rasa.utils.io.create_directory_for_file(domain_spec_path)

policy_names = [utils.module_path_from_instance(p) for p in self.policies]

training_events = self._training_events_from_trackers(self.training_trackers)

action_fingerprints = self._create_action_fingerprints(training_events)

metadata = {
"action_fingerprints": action_fingerprints,
"action_fingerprints": self.action_fingerprints,
"python": ".".join([str(s) for s in sys.version_info[:3]]),
"max_histories": self._max_histories(),
"ensemble_name": self.__module__ + "." + self.__class__.__name__,
@@ -201,15 +195,10 @@ def _persist_metadata(

rasa.utils.io.dump_obj_as_json_to_file(domain_spec_path, metadata)

# if there are lots of stories, saving flattened stories takes a long
# time, so this is turned off by default
if dump_flattened_stories:
training.persist_data(self.training_trackers, training_data_path)

def persist(self, path: Text, dump_flattened_stories: bool = False) -> None:
def persist(self, path: Text) -> None:
"""Persists the policy to storage."""

self._persist_metadata(path, dump_flattened_stories)
self._persist_metadata(path)

for i, policy in enumerate(self.policies):
dir_name = "policy_{}_{}".format(i, type(policy).__name__)
@@ -356,14 +345,6 @@ def get_state_featurizer_from_dict(cls, featurizer_config) -> Tuple[Any, Any]:

return state_featurizer_func, state_featurizer_config

def continue_training(
self, trackers: List[DialogueStateTracker], domain: Domain, **kwargs: Any
) -> None:

self.training_trackers.extend(trackers)
for p in self.policies:
p.continue_training(self.training_trackers, domain, **kwargs)


class SimplePolicyEnsemble(PolicyEnsemble):
@staticmethod
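The net effect in `PolicyEnsemble` is that action fingerprints are now computed once during `train()` and cached on the instance, so neither the trackers nor the flattened stories have to be kept around for `persist()`. A small sketch of the resulting call order (illustrative, not new API):

    # Sketch only: fingerprints come from train(); persist() just reuses them.
    ensemble.train(training_trackers, domain)   # sets ensemble.action_fingerprints
    ensemble.persist(model_path)                # metadata.json uses the cached fingerprints
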
35 changes: 0 additions & 35 deletions rasa/core/policies/keras_policy.py
@@ -206,41 +206,6 @@ def train(
self.current_epoch = self.defaults.get("epochs", 1)
logger.info("Done fitting keras policy model")

def continue_training(
self,
training_trackers: List[DialogueStateTracker],
domain: Domain,
**kwargs: Any,
) -> None:
"""Continues training an already trained policy."""

# takes the new example labelled and learns it
# via taking `epochs` samples of n_batch-1 parts of the training data,
# inserting our new example and learning them. this means that we can
# ask the network to fit the example without overemphasising
# its importance (and therefore throwing off the biases)

batch_size = kwargs.get("batch_size", 5)
epochs = kwargs.get("epochs", 50)

with self.graph.as_default(), self.session.as_default():
for _ in range(epochs):
training_data = self._training_data_for_continue_training(
batch_size, training_trackers, domain
)

# fit to one extra example using updated trackers
self.model.fit(
training_data.X,
training_data.y,
epochs=self.current_epoch + 1,
batch_size=len(training_data.y),
verbose=obtain_verbosity(),
initial_epoch=self.current_epoch,
)

self.current_epoch += 1

def predict_action_probabilities(
self, tracker: DialogueStateTracker, domain: Domain
) -> List[float]:
14 changes: 0 additions & 14 deletions rasa/core/policies/memoization.py
@@ -163,20 +163,6 @@ def train(
self._add_states_to_lookup(trackers_as_states, trackers_as_actions, domain)
logger.debug("Memorized {} unique examples.".format(len(self.lookup)))

def continue_training(
self,
training_trackers: List[DialogueStateTracker],
domain: Domain,
**kwargs: Any,
) -> None:

# add only the last tracker, because it is the only new one
(
trackers_as_states,
trackers_as_actions,
) = self.featurizer.training_states_and_actions(training_trackers[-1:], domain)
self._add_states_to_lookup(trackers_as_states, trackers_as_actions, domain)

def _recall_states(self, states: List[Dict[Text, float]]) -> Optional[int]:

return self.lookup.get(self._create_feature_key(states))
37 changes: 0 additions & 37 deletions rasa/core/policies/policy.py
@@ -87,43 +87,6 @@ def train(

raise NotImplementedError("Policy must have the capacity to train.")

def _training_data_for_continue_training(
self,
batch_size: int,
training_trackers: List[DialogueStateTracker],
domain: Domain,
) -> DialogueTrainingData:
"""Creates training_data for `continue_training` by
taking the new labelled example training_trackers[-1:]
and inserting it in batch_size-1 parts of the old training data,
"""
import numpy as np

num_samples = batch_size - 1
num_prev_examples = len(training_trackers) - 1

sampled_idx = np.random.choice(
range(num_prev_examples),
replace=False,
size=min(num_samples, num_prev_examples),
)
trackers = [training_trackers[i] for i in sampled_idx] + training_trackers[-1:]
return self.featurize_for_training(trackers, domain)

def continue_training(
self,
training_trackers: List[DialogueStateTracker],
domain: Domain,
**kwargs: Any,
) -> None:
"""Continues training an already trained policy.
This doesn't need to be supported by every policy. If it is supported,
the policy can be used for online training and the implementation for
the continued training should be put into this function."""

pass

def predict_action_probabilities(
self, tracker: DialogueStateTracker, domain: Domain
) -> List[float]:
6 changes: 1 addition & 5 deletions rasa/core/train.py
@@ -26,7 +26,6 @@ async def train(
output_path: Text,
interpreter: Optional["NaturalLanguageInterpreter"] = None,
endpoints: "AvailableEndpoints" = None,
dump_stories: bool = False,
policy_config: Optional[Union[Text, Dict]] = None,
exclusion_percentage: int = None,
additional_arguments: Optional[Dict] = None,
@@ -65,7 +64,7 @@
training_resource, exclusion_percentage=exclusion_percentage, **data_load_args
)
agent.train(training_data, **additional_arguments)
agent.persist(output_path, dump_stories)
agent.persist(output_path)

return agent

@@ -77,7 +76,6 @@ async def train_comparison_models(
exclusion_percentages: Optional[List] = None,
policy_configs: Optional[List] = None,
runs: int = 1,
dump_stories: bool = False,
additional_arguments: Optional[Dict] = None,
):
"""Train multiple models for comparison of policies"""
@@ -115,7 +113,6 @@
policy_config=policy_config,
exclusion_percentage=percentage,
additional_arguments=additional_arguments,
dump_stories=dump_stories,
),
model.model_fingerprint(file_importer),
)
@@ -154,7 +151,6 @@ async def do_compare_training(
exclusion_percentages=args.percentages,
policy_configs=args.config,
runs=args.runs,
dump_stories=args.dump_stories,
additional_arguments=additional_arguments,
),
get_no_of_stories(args.stories, args.domain),
1 change: 0 additions & 1 deletion rasa/core/training/interactive.py
@@ -1563,7 +1563,6 @@ async def train_agent_on_start(
model_directory,
_interpreter,
endpoints,
args.get("dump_stories"),
args.get("config")[0],
None,
additional_arguments,