Skip to content

Commit

Permalink
refactor how state featurizer is prepared
Browse files Browse the repository at this point in the history
  • Loading branch information
dakshvar22 committed May 17, 2021
1 parent 6dec8db commit 4b06e36
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 9 deletions.
33 changes: 25 additions & 8 deletions rasa/core/featurizers/tracker_featurizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,31 @@ def training_states_and_actions(
)
return trackers_as_states, trackers_as_actions

def prepare_for_training(
self,
domain: Domain,
interpreter: NaturalLanguageInterpreter,
bilou_tagging: bool,
) -> None:
"""Makes sure that the featurizer is ready to be called during training.
State featurizer needs to build its vocabulary from the domain
for it to be ready to be used during training.
Args:
domain: Domain of the assistant.
interpreter: NLU Interpreter for featurizing states.
bilou_tagging: Whether to conside bilou tagging.
"""
if self.state_featurizer is None:
raise ValueError(
f"Instance variable 'state_featurizer' is not set. "
f"During initialization set 'state_featurizer' to an instance of "
f"'{SingleStateFeaturizer.__class__.__name__}' class "
f"to get numerical features for trackers."
)
self.state_featurizer.prepare_for_training(domain, interpreter, bilou_tagging)

def featurize_trackers(
self,
trackers: List[DialogueStateTracker],
Expand Down Expand Up @@ -216,14 +241,6 @@ def featurize_trackers(
containing entity tag ids for text user inputs otherwise empty dict
for all dialogue turns in all training trackers
"""
if self.state_featurizer is None:
raise ValueError(
f"Instance variable 'state_featurizer' is not set. "
f"During initialization set 'state_featurizer' to an instance of "
f"'{SingleStateFeaturizer.__class__.__name__}' class "
f"to get numerical features for trackers."
)

(
trackers_as_states,
trackers_as_actions,
Expand Down
2 changes: 1 addition & 1 deletion rasa/core/policies/ted_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ def _prepare_for_training(
Returns:
Featurized data to be fed to the model and corresponding label ids.
"""
self.featurizer.state_featurizer.prepare_for_training(
self.featurizer.prepare_for_training(
domain, interpreter, bilou_tagging=self.config[BILOU_FLAG]
)

Expand Down

0 comments on commit 4b06e36

Please sign in to comment.