diff --git a/rasa/core/featurizers/tracker_featurizers.py b/rasa/core/featurizers/tracker_featurizers.py index 66dfbf68ccf4..97d90397fa03 100644 --- a/rasa/core/featurizers/tracker_featurizers.py +++ b/rasa/core/featurizers/tracker_featurizers.py @@ -187,6 +187,31 @@ def training_states_and_actions( ) return trackers_as_states, trackers_as_actions + def prepare_for_training( + self, + domain: Domain, + interpreter: NaturalLanguageInterpreter, + bilou_tagging: bool, + ) -> None: + """Makes sure that the featurizer is ready to be called during training. + + State featurizer needs to build its vocabulary from the domain + for it to be ready to be used during training. + + Args: + domain: Domain of the assistant. + interpreter: NLU Interpreter for featurizing states. + bilou_tagging: Whether to conside bilou tagging. + """ + if self.state_featurizer is None: + raise ValueError( + f"Instance variable 'state_featurizer' is not set. " + f"During initialization set 'state_featurizer' to an instance of " + f"'{SingleStateFeaturizer.__class__.__name__}' class " + f"to get numerical features for trackers." + ) + self.state_featurizer.prepare_for_training(domain, interpreter, bilou_tagging) + def featurize_trackers( self, trackers: List[DialogueStateTracker], @@ -216,14 +241,6 @@ def featurize_trackers( containing entity tag ids for text user inputs otherwise empty dict for all dialogue turns in all training trackers """ - if self.state_featurizer is None: - raise ValueError( - f"Instance variable 'state_featurizer' is not set. " - f"During initialization set 'state_featurizer' to an instance of " - f"'{SingleStateFeaturizer.__class__.__name__}' class " - f"to get numerical features for trackers." - ) - ( trackers_as_states, trackers_as_actions, diff --git a/rasa/core/policies/ted_policy.py b/rasa/core/policies/ted_policy.py index ff9b06e4be2f..bca7c04fbec9 100644 --- a/rasa/core/policies/ted_policy.py +++ b/rasa/core/policies/ted_policy.py @@ -556,7 +556,7 @@ def _prepare_for_training( Returns: Featurized data to be fed to the model and corresponding label ids. """ - self.featurizer.state_featurizer.prepare_for_training( + self.featurizer.prepare_for_training( domain, interpreter, bilou_tagging=self.config[BILOU_FLAG] )