diff --git a/changelog/5673.improvement.md b/changelog/5673.improvement.md new file mode 100644 index 000000000000..1afa952b6674 --- /dev/null +++ b/changelog/5673.improvement.md @@ -0,0 +1,8 @@ +Expose diagnostic data for action and NLU predictions. + +Add `diagnostic_data` field to the [Message](./reference/rasa/shared/nlu/training_data/message.md#message-objects) +and [Prediction](./reference/rasa/core/policies/policy.md#policyprediction-objects) objects, which contain +information about attention weights and other intermediate results of the inference computation. +This information can be used for debugging and fine-tuning, e.g. with [RasaLit](https://github.com/RasaHQ/rasalit). + +For examples of how to access the diagnostic data, see [here](https://gist.github.com/JEM-Mosig/c6e15b81ee70561cb72e361aff310d7e). diff --git a/docs/docs/tuning-your-model.mdx b/docs/docs/tuning-your-model.mdx index 63dcde2ad299..89283a95a459 100644 --- a/docs/docs/tuning-your-model.mdx +++ b/docs/docs/tuning-your-model.mdx @@ -293,7 +293,9 @@ Here is a summary of the available extractors and what they are best used for: |`MitieEntityExtractor` |MITIE |structured SVM |good for training custom entities | |`EntitySynonymMapper` |existing entities |N/A |maps known synonyms | -## Handling Class Imbalance +## Improving Performance + +### Handling Class Imbalance Classification algorithms often do not perform well if there is a large class imbalance, for example if you have a lot of training data for some intents and very little training data for others. @@ -312,6 +314,39 @@ pipeline: batch_strategy: sequence ``` +### Accessing Diagnostic Data + +To gain a better understanding of what your models do, you can access intermediate results of the prediction process. 
+To do this, you need to access the `diagnostic_data` field of the [Message](./reference/rasa/shared/nlu/training_data/message.md#message-objects) +and [Prediction](./reference/rasa/core/policies/policy.md#policyprediction-objects) objects, which contain +information about attention weights and other intermediate results of the inference computation. +You can use this information for debugging and fine-tuning, e.g. with [RasaLit](https://github.com/RasaHQ/rasalit). + +After you've [trained a model](./command-line-interface.mdx#rasa-train), you can access diagnostic data for DIET, +given a processed message, like this: + +```python +nlu_diagnostic_data = message.as_dict()[DIAGNOSTIC_DATA] + +for component_name, diagnostic_data in nlu_diagnostic_data.items(): + attention_weights = diagnostic_data["attention_weights"] + print(f"attention_weights for {component_name}:") + print(attention_weights) + + text_transformed = diagnostic_data["text_transformed"] + print(f"\ntext_transformed for {component_name}:") + print(text_transformed) +``` + +And you can access diagnostic data for TED like this: + +```python +prediction = policy.predict_action_probabilities( + GREET_RULE, domain, RegexInterpreter() +) +print(f"{prediction.diagnostic_data.get('attention_weights')}") +``` + ## Configuring Tensorflow diff --git a/rasa/core/policies/policy.py b/rasa/core/policies/policy.py index 0294705baa33..261230dc0b46 100644 --- a/rasa/core/policies/policy.py +++ b/rasa/core/policies/policy.py @@ -236,6 +236,7 @@ def _prediction( events: Optional[List[Event]] = None, optional_events: Optional[List[Event]] = None, is_end_to_end_prediction: bool = False, + diagnostic_data: Optional[Dict[Text, Any]] = None, ) -> "PolicyPrediction": return PolicyPrediction( probabilities, @@ -244,6 +245,7 @@ def _prediction( events, optional_events, is_end_to_end_prediction, + diagnostic_data, ) def _metadata(self) -> Optional[Dict[Text, Any]]: @@ -400,6 +402,7 @@ def __init__( events: Optional[List[Event]] 
= None, optional_events: Optional[List[Event]] = None, is_end_to_end_prediction: bool = False, + diagnostic_data: Optional[Dict[Text, Any]] = None, ) -> None: """Creates a `PolicyPrediction`. @@ -417,6 +420,9 @@ def __init__( you return as they can potentially influence the conversation flow. is_end_to_end_prediction: `True` if the prediction used the text of the user message instead of the intent. + diagnostic_data: Intermediate results or other information that is not + necessary for Rasa to function, but intended for debugging and + fine-tuning purposes. """ self.probabilities = probabilities self.policy_name = policy_name @@ -424,6 +430,7 @@ def __init__( self.events = events or [] self.optional_events = optional_events or [] self.is_end_to_end_prediction = is_end_to_end_prediction + self.diagnostic_data = diagnostic_data or {} @staticmethod def for_action_name( @@ -466,6 +473,8 @@ def __eq__(self, other: Any) -> bool: and self.events == other.events and self.optional_events == other.events and self.is_end_to_end_prediction == other.is_end_to_end_prediction + # We do not compare `diagnostic_data`, because it has no effect on the + # action prediction. 
+ ) @property diff --git a/rasa/core/policies/ted_policy.py b/rasa/core/policies/ted_policy.py index 38b3088922e6..aa026e0f1c09 100644 --- a/rasa/core/policies/ted_policy.py +++ b/rasa/core/policies/ted_policy.py @@ -38,6 +38,7 @@ from rasa.shared.nlu.interpreter import NaturalLanguageInterpreter from rasa.core.policies.policy import Policy, PolicyPrediction from rasa.core.constants import DEFAULT_POLICY_PRIORITY, DIALOGUE +from rasa.shared.constants import DIAGNOSTIC_DATA from rasa.shared.core.constants import ACTIVE_LOOP, SLOTS, ACTION_LISTEN_NAME from rasa.shared.core.trackers import DialogueStateTracker from rasa.shared.core.generator import TrackerWithCachedStates @@ -50,6 +51,7 @@ Data, ) from rasa.utils.tensorflow.model_data_utils import convert_to_data_format +import rasa.utils.tensorflow.numpy from rasa.utils.tensorflow.constants import ( LABEL, IDS, @@ -632,6 +634,9 @@ def predict_action_probabilities( confidence.tolist(), is_end_to_end_prediction=is_e2e_prediction, optional_events=optional_events, + diagnostic_data=rasa.utils.tensorflow.numpy.values_to_numpy( + output.get(DIAGNOSTIC_DATA) + ), ) def _create_optional_event_for_entities( @@ -1050,14 +1055,23 @@ def _embed_dialogue( self, dialogue_in: tf.Tensor, tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]], - ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: - """Create dialogue level embedding and mask.""" + ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Optional[tf.Tensor]]: + """Creates dialogue level embedding and mask. + + Args: + dialogue_in: The encoded dialogue. + tf_batch_data: Batch in model data format. + + Returns: + The dialogue embedding, the mask, the transformed dialogue, and (for + diagnostic purposes) also the attention weights. 
+ """ dialogue_lengths = tf.cast(tf_batch_data[DIALOGUE][LENGTH][0], tf.int32) mask = self._compute_mask(dialogue_lengths) - dialogue_transformed = self._tf_layers[f"transformer.{DIALOGUE}"]( - dialogue_in, 1 - mask, self._training - ) + dialogue_transformed, attention_weights = self._tf_layers[ + f"transformer.{DIALOGUE}" + ](dialogue_in, 1 - mask, self._training) dialogue_transformed = tfa.activations.gelu(dialogue_transformed) if self.use_only_last_dialogue_turns: @@ -1069,7 +1083,7 @@ def _embed_dialogue( dialogue_embed = self._tf_layers[f"embed.{DIALOGUE}"](dialogue_transformed) - return dialogue_embed, mask, dialogue_transformed + return dialogue_embed, mask, dialogue_transformed, attention_weights def _encode_features_per_attribute( self, tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]], attribute: Text @@ -1615,6 +1629,7 @@ def batch_loss( dialogue_embed, dialogue_mask, dialogue_transformer_output, + _, ) = self._embed_dialogue(dialogue_in, tf_batch_data) dialogue_mask = tf.squeeze(dialogue_mask, axis=-1) @@ -1686,6 +1701,7 @@ def batch_predict( dialogue_embed, dialogue_mask, dialogue_transformer_output, + attention_weights, ) = self._embed_dialogue(dialogue_in, tf_batch_data) dialogue_mask = tf.squeeze(dialogue_mask, axis=-1) @@ -1698,7 +1714,11 @@ def batch_predict( scores = self._tf_layers[f"loss.{LABEL}"].confidence_from_sim( sim_all, self.config[SIMILARITY_TYPE] ) - predictions = {"action_scores": scores, "similarities": sim_all} + predictions = { + "action_scores": scores, + "similarities": sim_all, + DIAGNOSTIC_DATA: {"attention_weights": attention_weights}, + } if ( self.config[ENTITY_RECOGNITION] diff --git a/rasa/nlu/classifiers/diet_classifier.py b/rasa/nlu/classifiers/diet_classifier.py index d1f26fec25fd..67eb8764516a 100644 --- a/rasa/nlu/classifiers/diet_classifier.py +++ b/rasa/nlu/classifiers/diet_classifier.py @@ -13,6 +13,8 @@ import rasa.shared.utils.io import rasa.utils.io as io_utils import rasa.nlu.utils.bilou_utils as 
bilou_utils +import rasa.utils.tensorflow.numpy +from rasa.shared.constants import DIAGNOSTIC_DATA from rasa.nlu.featurizers.featurizer import Featurizer from rasa.nlu.components import Component from rasa.nlu.classifiers.classifier import IntentClassifier @@ -914,7 +916,7 @@ def _predict_entities( return entities def process(self, message: Message, **kwargs: Any) -> None: - """Return the most likely label and its similarity to the input.""" + """Augments the message with intents, entities, and diagnostic data.""" out = self._predict(message) if self.component_config[INTENT_CLASSIFICATION]: @@ -928,12 +930,17 @@ def process(self, message: Message, **kwargs: Any) -> None: message.set(ENTITIES, entities, add_to_output=True) + if out and DIAGNOSTIC_DATA in out: + message.add_diagnostic_data( + self.unique_name, + rasa.utils.tensorflow.numpy.values_to_numpy(out.get(DIAGNOSTIC_DATA)), + ) + def persist(self, file_name: Text, model_dir: Text) -> Dict[Text, Any]: """Persist this model into the passed directory. Return the metadata necessary to load the model again. 
""" - if self.model is None: return {"file": None} @@ -1420,6 +1427,7 @@ def batch_loss( text_in, text_seq_ids, lm_mask_bool_text, + _, ) = self._create_sequence( tf_batch_data[TEXT][SEQUENCE], tf_batch_data[TEXT][SENTENCE], @@ -1569,7 +1577,7 @@ def batch_predict( mask = self._compute_mask(sequence_lengths) - text_transformed, _, _, _ = self._create_sequence( + text_transformed, _, _, _, attention_weights = self._create_sequence( tf_batch_data[TEXT][SEQUENCE], tf_batch_data[TEXT][SENTENCE], mask_sequence_text, @@ -1579,6 +1587,11 @@ def batch_predict( predictions: Dict[Text, tf.Tensor] = {} + predictions[DIAGNOSTIC_DATA] = { + "attention_weights": attention_weights, + "text_transformed": text_transformed, + } + if self.config[INTENT_CLASSIFICATION]: predictions.update( self._batch_predict_intents(sequence_lengths, text_transformed) diff --git a/rasa/nlu/components.py b/rasa/nlu/components.py index 73a31751ea14..0122aeae24c3 100644 --- a/rasa/nlu/components.py +++ b/rasa/nlu/components.py @@ -12,6 +12,7 @@ from rasa.shared.exceptions import InvalidConfigException from rasa.shared.nlu.training_data.training_data import TrainingData from rasa.shared.nlu.training_data.message import Message +from rasa.nlu.constants import COMPONENT_INDEX import rasa.shared.utils.io if typing.TYPE_CHECKING: @@ -388,26 +389,38 @@ class Component(metaclass=ComponentMetaclass): the pipeline to do intent classification. """ - # Component class name is used when integrating it in a - # pipeline. E.g. ``[ComponentA, ComponentB]`` - # will be a proper pipeline definition where ``ComponentA`` - # is the name of the first component of the pipeline. @property def name(self) -> Text: - """Access the class's property name from an instance.""" + """Returns the name of the component to be used in the model configuration. + Component class name is used when integrating it in a + pipeline. E.g. 
`[ComponentA, ComponentB]` + will be a proper pipeline definition where `ComponentA` + is the name of the first component of the pipeline. + """ return type(self).name - # Which components are required by this component. - # Listed components should appear before the component itself in the pipeline. + @property + def unique_name(self) -> Text: + """Gets a unique name for the component in the pipeline. + + The unique name can be used to distinguish components in + a pipeline, e.g. when the pipeline contains multiple + featurizers of the same type. + """ + index = self.component_config.get(COMPONENT_INDEX) + return self.name if index is None else f"component_{index}_{self.name}" + @classmethod def required_components(cls) -> List[Type["Component"]]: - """Specify which components need to be present in the pipeline. + """Specifies which components need to be present in the pipeline. + + Which components are required by this component. + Listed components should appear before the component itself in the pipeline. Returns: - The list of class names of required components. + The class names of the required components. """ - return [] # Defines the default configuration parameters of a component @@ -452,7 +465,7 @@ def __init__(self, component_config: Optional[Dict[Text, Any]] = None) -> None: @classmethod def required_packages(cls) -> List[Text]: - """Specify which python packages need to be installed. + """Specifies which python packages need to be installed. E.g. ``["spacy"]``. More specifically, these should be importable python package names e.g. `sklearn` and not package @@ -464,7 +477,6 @@ def required_packages(cls) -> List[Text]: Returns: The list of required package names. """ - return [] @classmethod @@ -476,7 +488,7 @@ def load( cached_component: Optional["Component"] = None, **kwargs: Any, ) -> "Component": - """Load this component from file. + """Loads this component from file. After a component has been trained, it will be persisted by calling `persist`. 
When the pipeline gets loaded again, @@ -494,7 +506,6 @@ def load( Returns: the loaded component """ - if cached_component: return cached_component @@ -515,7 +526,6 @@ def create( Returns: The created component. """ - # Check language supporting language = config.language if not cls.can_handle_language(language): @@ -525,7 +535,7 @@ def create( return cls(component_config) def provide_context(self) -> Optional[Dict[Text, Any]]: - """Initialize this component for a new pipeline. + """Initializes this component for a new pipeline. This function will be called before the training is started and before the first message is processed using @@ -540,7 +550,6 @@ def provide_context(self) -> Optional[Dict[Text, Any]]: Returns: The updated component configuration. """ - pass def train( @@ -549,7 +558,7 @@ def train( config: Optional[RasaNLUModelConfig] = None, **kwargs: Any, ) -> None: - """Train this component. + """Trains this component. This is the components chance to train itself provided with the training data. The component can rely on @@ -561,16 +570,13 @@ def train( of components previous to this one. Args: - training_data: - The :class:`rasa.shared.nlu.training_data.training_data.TrainingData`. + training_data: The :class:`rasa.shared.nlu.training_data.training_data.TrainingData`. config: The model configuration parameters. - """ - pass def process(self, message: Message, **kwargs: Any) -> None: - """Process an incoming message. + """Processes an incoming message. This is the components chance to process an incoming message. The component can rely on @@ -583,13 +589,11 @@ def process(self, message: Message, **kwargs: Any) -> None: Args: message: The :class:`rasa.shared.nlu.training_data.message.Message` to process. - """ - pass def persist(self, file_name: Text, model_dir: Text) -> Optional[Dict[Text, Any]]: - """Persist this component to disk for future loading. + """Persists this component to disk for future loading. Args: file_name: The file name of the model. 
@@ -598,7 +602,6 @@ def persist(self, file_name: Text, model_dir: Text) -> Optional[Dict[Text, Any]] Returns: An optional dictionary with any information about the stored model. """ - pass @classmethod @@ -619,10 +622,10 @@ def cache_key( Returns: A unique caching key. """ - return None def __getstate__(self) -> Any: + """Gets a copy of picklable parts of the component.""" d = self.__dict__.copy() # these properties should not be pickled if "partial_processing_context" in d: diff --git a/rasa/nlu/config.py b/rasa/nlu/config.py index afdb39af1e17..ba4653de101d 100644 --- a/rasa/nlu/config.py +++ b/rasa/nlu/config.py @@ -1,3 +1,4 @@ +import copy import logging import os from typing import Any, Dict, List, Optional, Text, Union @@ -10,6 +11,7 @@ DEFAULT_CONFIG_PATH, ) from rasa.shared.utils.io import json_to_string +from rasa.nlu.constants import COMPONENT_INDEX import rasa.utils.train_utils logger = logging.getLogger(__name__) @@ -56,19 +58,21 @@ def component_config_from_pipeline( pipeline: List[Dict[Text, Any]], defaults: Optional[Dict[Text, Any]] = None, ) -> Dict[Text, Any]: - """Get config of the component with the given index in the pipeline. + """Gets the configuration of the `index`th component. Args: - index: index the component in the pipeline - pipeline: a list of component configs in the NLU pipeline - defaults: default config of the component + index: Index of the component in the pipeline. + pipeline: Configurations of the components in the pipeline. + defaults: Default configuration. Returns: - config of the component + The `index`th component configuration, expanded + by the given defaults. 
""" try: - c = pipeline[index] - return rasa.utils.train_utils.override_defaults(defaults, c) + configuration = copy.deepcopy(pipeline[index]) + configuration[COMPONENT_INDEX] = index + return rasa.utils.train_utils.override_defaults(defaults, configuration) except IndexError: rasa.shared.utils.io.raise_warning( f"Tried to get configuration value for component " @@ -76,7 +80,9 @@ def component_config_from_pipeline( f"Returning `defaults`.", docs=DOCS_URL_PIPELINE, ) - return rasa.utils.train_utils.override_defaults(defaults, {}) + return rasa.utils.train_utils.override_defaults( + defaults, {COMPONENT_INDEX: index} + ) class RasaNLUModelConfig: @@ -157,6 +163,11 @@ def set_component_attr(self, index, **kwargs) -> None: docs=DOCS_URL_PIPELINE, ) - def override(self, config) -> None: + def override(self, config: Optional[Dict[Text, Any]] = None) -> None: + """Overrides default config with given values. + + Args: + config: New values for the configuration. + """ if config: self.__dict__.update(config) diff --git a/rasa/nlu/constants.py b/rasa/nlu/constants.py index 90dc50e07520..b62c1e9e6fe0 100644 --- a/rasa/nlu/constants.py +++ b/rasa/nlu/constants.py @@ -78,5 +78,7 @@ FEATURIZER_CLASS_ALIAS = "alias" NO_LENGTH_RESTRICTION = -1 + +COMPONENT_INDEX = "index" MIN_ADDITIONAL_REGEX_PATTERNS = 10 MIN_ADDITIONAL_CVF_VOCABULARY = 1000 diff --git a/rasa/nlu/featurizers/featurizer.py b/rasa/nlu/featurizers/featurizer.py index 3fdb63303ee7..9c3bfa9a5790 100644 --- a/rasa/nlu/featurizers/featurizer.py +++ b/rasa/nlu/featurizers/featurizer.py @@ -12,7 +12,9 @@ def __init__(self, component_config: Optional[Dict[Text, Any]] = None) -> None: component_config = {} # makes sure the alias name is set - component_config.setdefault(FEATURIZER_CLASS_ALIAS, self.name) + # Necessary for `unique_name` to be defined + self.component_config = component_config + component_config.setdefault(FEATURIZER_CLASS_ALIAS, self.unique_name) super().__init__(component_config) diff --git 
a/rasa/nlu/selectors/response_selector.py b/rasa/nlu/selectors/response_selector.py index ac78b6d3964a..2bc8aa792a23 100644 --- a/rasa/nlu/selectors/response_selector.py +++ b/rasa/nlu/selectors/response_selector.py @@ -6,6 +6,8 @@ from typing import Any, Dict, Optional, Text, Tuple, Union, List, Type +from rasa.shared.constants import DIAGNOSTIC_DATA +import rasa.utils.tensorflow.numpy from rasa.shared.nlu.training_data import util import rasa.shared.utils.io from rasa.shared.exceptions import InvalidConfigException @@ -381,7 +383,6 @@ def _resolve_intent_response_key( def process(self, message: Message, **kwargs: Any) -> None: """Return the most likely response, the associated intent_response_key and its similarity to the input.""" - out = self._predict(message) top_label, label_ranking = self._predict_label(out) @@ -439,6 +440,12 @@ def process(self, message: Message, **kwargs: Any) -> None: self._set_message_property(message, prediction_dict, selector_key) + if out and DIAGNOSTIC_DATA in out: + message.add_diagnostic_data( + self.unique_name, + rasa.utils.tensorflow.numpy.values_to_numpy(out.get(DIAGNOSTIC_DATA)), + ) + def persist(self, file_name: Text, model_dir: Text) -> Dict[Text, Any]: """Persist this model into the passed directory. 
@@ -623,7 +630,7 @@ def _create_all_labels(self) -> Tuple[tf.Tensor, tf.Tensor]: ) mask_label = self._compute_mask(sequence_lengths_label) - label_transformed, _, _, _ = self._create_sequence( + label_transformed, _, _, _, _ = self._create_sequence( self.tf_label_data[LABEL][SEQUENCE], self.tf_label_data[LABEL][SENTENCE], sequence_mask_label, @@ -653,6 +660,7 @@ def batch_loss( text_in, text_seq_ids, lm_mask_bool_text, + _, ) = self._create_sequence( tf_batch_data[TEXT][SEQUENCE], tf_batch_data[TEXT][SENTENCE], @@ -673,7 +681,7 @@ def batch_loss( ) mask_label = self._compute_mask(sequence_lengths_label) - label_transformed, _, _, _ = self._create_sequence( + label_transformed, _, _, _, _ = self._create_sequence( tf_batch_data[LABEL][SEQUENCE], tf_batch_data[LABEL][SENTENCE], sequence_mask_label, @@ -725,7 +733,7 @@ def batch_predict( ) mask_text = self._compute_mask(sequence_lengths_text) - text_transformed, _, _, _ = self._create_sequence( + text_transformed, _, _, _, attention_weights = self._create_sequence( tf_batch_data[TEXT][SEQUENCE], tf_batch_data[TEXT][SENTENCE], sequence_mask_text, @@ -735,6 +743,11 @@ def batch_predict( out = {} + out[DIAGNOSTIC_DATA] = { + "attention_weights": attention_weights, + "text_transformed": text_transformed, + } + if self.all_labels_embed is None: _, self.all_labels_embed = self._create_all_labels() diff --git a/rasa/shared/constants.py b/rasa/shared/constants.py index 1c618a37459a..33e0a7e19e47 100644 --- a/rasa/shared/constants.py +++ b/rasa/shared/constants.py @@ -77,3 +77,5 @@ DEFAULT_NLU_RESULTS_PATH = "nlu_comparison_results" DEFAULT_CORE_SUBDIRECTORY_NAME = "core" DEFAULT_NLU_SUBDIRECTORY_NAME = "nlu" + +DIAGNOSTIC_DATA = "diagnostic_data" diff --git a/rasa/shared/nlu/training_data/message.py b/rasa/shared/nlu/training_data/message.py index 5cf5596f8fa3..e2cb697124fb 100644 --- a/rasa/shared/nlu/training_data/message.py +++ b/rasa/shared/nlu/training_data/message.py @@ -21,6 +21,7 @@ ACTION_TEXT, ACTION_NAME, ) +from 
rasa.shared.constants import DIAGNOSTIC_DATA if typing.TYPE_CHECKING: from rasa.shared.nlu.training_data.features import Features @@ -51,7 +52,30 @@ def add_features(self, features: Optional["Features"]) -> None: if features is not None: self.features.append(features) - def set(self, prop, info, add_to_output=False) -> None: + def add_diagnostic_data(self, origin: Text, data: Dict[Text, Any]) -> None: + """Adds diagnostic data from the `origin` component. + + Args: + origin: Name of the component that created the data. + data: The diagnostic data. + """ + if origin in self.get(DIAGNOSTIC_DATA, {}): + rasa.shared.utils.io.raise_warning( + f"Please make sure every pipeline component has a distinct name. " + f"The name '{origin}' appears at least twice and diagnostic " + f"data will be overwritten." + ) + self.data.setdefault(DIAGNOSTIC_DATA, {}) + self.data[DIAGNOSTIC_DATA][origin] = data + + def set(self, prop: Text, info: Any, add_to_output: bool = False) -> None: + """Sets the message's property to the given value. + + Args: + prop: Name of the property to be set. + info: Value to be assigned to that property. + add_to_output: Decides whether to add `prop` to the `output_properties`. + """ self.data[prop] = info if add_to_output: self.output_properties.add(prop) diff --git a/rasa/utils/tensorflow/models.py b/rasa/utils/tensorflow/models.py index f4ff88562645..b29a3be35d82 100644 --- a/rasa/utils/tensorflow/models.py +++ b/rasa/utils/tensorflow/models.py @@ -570,9 +570,9 @@ def batch_to_model_data_format( """Convert input batch tensors into batch data format. Batch contains any number of batch data. The order is equal to the - key-value pairs in session data. As sparse data were converted into indices, - data, shape before, this methods converts them into sparse tensors. Dense data - is kept. + key-value pairs in session data. As sparse data were converted into (indices, + data, shape) before, this method converts them into sparse tensors. Dense + data is kept. 
""" batch_data = defaultdict(lambda: defaultdict(list)) @@ -775,7 +775,7 @@ def _prepare_transformer_layer( ) else: # create lambda so that it can be used later without the check - self._tf_layers[f"{prefix}.{name}"] = lambda x, mask, training: x + self._tf_layers[f"{prefix}.{name}"] = lambda x, mask, training: (x, None) def _prepare_dot_product_loss( self, name: Text, scale_loss: bool, prefix: Text = "loss" @@ -1032,7 +1032,13 @@ def _create_sequence( dense_dropout: bool = False, masked_lm_loss: bool = False, sequence_ids: bool = False, - ) -> Tuple[tf.Tensor, tf.Tensor, Optional[tf.Tensor], Optional[tf.Tensor]]: + ) -> Tuple[ + tf.Tensor, + tf.Tensor, + Optional[tf.Tensor], + Optional[tf.Tensor], + Optional[tf.Tensor], + ]: if sequence_ids: seq_ids = self._features_as_seq_ids(sequence_features, f"{name}_{SEQUENCE}") else: @@ -1057,7 +1063,7 @@ def _create_sequence( transformer_inputs = inputs lm_mask_bool = None - outputs = self._tf_layers[f"transformer.{name}"]( + outputs, attention_weights = self._tf_layers[f"transformer.{name}"]( transformer_inputs, 1 - mask, self._training ) @@ -1070,7 +1076,7 @@ def _create_sequence( # apply activation outputs = tfa.activations.gelu(outputs) - return outputs, inputs, seq_ids, lm_mask_bool + return outputs, inputs, seq_ids, lm_mask_bool, attention_weights @staticmethod def _compute_mask(sequence_lengths: tf.Tensor) -> tf.Tensor: diff --git a/rasa/utils/tensorflow/numpy.py b/rasa/utils/tensorflow/numpy.py new file mode 100644 index 000000000000..762faa59ab74 --- /dev/null +++ b/rasa/utils/tensorflow/numpy.py @@ -0,0 +1,25 @@ +from typing import Any, Dict, Optional +from tensorflow import Tensor + + +def values_to_numpy(data: Optional[Dict[Any, Any]]) -> Optional[Dict[Any, Any]]: + """Replaces all tensorflow-tensor values with their numpy versions. + + Args: + data: Any dictionary for which values should be converted. 
+ + Returns: + A dictionary identical to `data` except that tensor values are + replaced by their corresponding numpy arrays. + """ + if not data: + return data + + return {key: _to_numpy_if_tensor(value) for key, value in data.items()} + + +def _to_numpy_if_tensor(value: Any) -> Any: + if isinstance(value, Tensor): + return value.numpy() + else: + return value diff --git a/rasa/utils/tensorflow/transformer.py b/rasa/utils/tensorflow/transformer.py index 222cfb89d16c..4c0332a08aff 100644 --- a/rasa/utils/tensorflow/transformer.py +++ b/rasa/utils/tensorflow/transformer.py @@ -453,7 +453,7 @@ def call( x: tf.Tensor, pad_mask: Optional[tf.Tensor] = None, training: Optional[Union[tf.Tensor, bool]] = None, - ) -> tf.Tensor: + ) -> Tuple[tf.Tensor, tf.Tensor]: """Apply transformer encoder layer. Arguments: @@ -469,7 +469,9 @@ def call( training = K.learning_phase() x_norm = self._layer_norm(x) # (batch_size, length, units) - attn_out, _ = self._mha(x_norm, x_norm, pad_mask=pad_mask, training=training) + attn_out, attn_weights = self._mha( + x_norm, x_norm, pad_mask=pad_mask, training=training + ) attn_out = self._dropout(attn_out, training=training) x += attn_out @@ -478,7 +480,8 @@ def call( ffn_out = layer(ffn_out, training=training) x += ffn_out - return x # (batch_size, length, units) + # (batch_size, length, units), (batch_size, num_heads, length, length) + return x, attn_weights class TransformerEncoder(tf.keras.layers.Layer): @@ -591,7 +594,7 @@ def call( x: tf.Tensor, pad_mask: Optional[tf.Tensor] = None, training: Optional[Union[tf.Tensor, bool]] = None, - ) -> tf.Tensor: + ) -> Tuple[tf.Tensor, tf.Tensor]: """Apply transformer encoder. Arguments: @@ -603,7 +606,6 @@ def call( Returns: Transformer encoder output with shape [batch_size, length, units] """ - # adding embedding and position encoding. 
x = self._embedding(x) # (batch_size, length, units) x *= tf.math.sqrt(tf.cast(self.units, tf.float32)) @@ -620,10 +622,16 @@ def call( 1.0, pad_mask + self._look_ahead_pad_mask(tf.shape(pad_mask)[-1]) ) # (batch_size, 1, length, length) + layer_attention_weights = [] + for layer in self._enc_layers: - x = layer(x, pad_mask=pad_mask, training=training) + x, attn_weights = layer(x, pad_mask=pad_mask, training=training) + layer_attention_weights.append(attn_weights) # if normalization is done in encoding layers, then it should also be done # on the output, since the output can grow very large, being the sum of # a whole stack of unnormalized layer outputs. - return self._layer_norm(x) # (batch_size, length, units) + x = self._layer_norm(x) # (batch_size, length, units) + + # (batch_size, length, units), (num_layers, batch_size, num_heads, length, length) + return x, tf.stack(layer_attention_weights) diff --git a/tests/conftest.py b/tests/conftest.py index 7af0f8f27ba6..b17c782f9973 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,6 +45,7 @@ INCORRECT_NLU_DATA, SIMPLE_STORIES_FILE, ) +from rasa.shared.exceptions import RasaException DEFAULT_CONFIG_PATH = "rasa/cli/default_config.yml" @@ -384,6 +385,31 @@ def blank_config() -> RasaNLUModelConfig: return RasaNLUModelConfig({"language": "en", "pipeline": []}) +@pytest.fixture(scope="session") +async def trained_response_selector_bot(trained_async: Callable) -> Path: + zipped_model = await trained_async( + domain="examples/responseselectorbot/domain.yml", + config="examples/responseselectorbot/config.yml", + training_files=[ + "examples/responseselectorbot/data/rules.yml", + "examples/responseselectorbot/data/stories.yml", + "examples/responseselectorbot/data/nlu.yml", + ], + ) + + if not zipped_model: + raise RasaException("Model training for responseselectorbot failed.") + + return Path(zipped_model) + + +@pytest.fixture(scope="session") +async def response_selector_agent( + trained_response_selector_bot: 
Optional[Path], +) -> Agent: + return Agent.load_local_model(trained_response_selector_bot) + + def write_endpoint_config_to_yaml( path: Path, data: Dict[Text, Any], endpoints_filename: Text = "endpoints.yml" ) -> Path: diff --git a/tests/core/conftest.py b/tests/core/conftest.py index 0d3c844b9ac0..194c789a7368 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -216,18 +216,3 @@ async def form_bot_agent(trained_async: Callable) -> Agent: ) return Agent.load_local_model(zipped_model, action_endpoint=endpoint) - - -@pytest.fixture(scope="session") -async def response_selector_agent(trained_async: Callable) -> Agent: - zipped_model = await trained_async( - domain="examples/responseselectorbot/domain.yml", - config="examples/responseselectorbot/config.yml", - training_files=[ - "examples/responseselectorbot/data/rules.yml", - "examples/responseselectorbot/data/stories.yml", - "examples/responseselectorbot/data/nlu.yml", - ], - ) - - return Agent.load_local_model(zipped_model) diff --git a/tests/core/policies/test_ted_policy.py b/tests/core/policies/test_ted_policy.py new file mode 100644 index 000000000000..ea790d422127 --- /dev/null +++ b/tests/core/policies/test_ted_policy.py @@ -0,0 +1,429 @@ +from pathlib import Path +from typing import Optional +from unittest.mock import Mock + +import numpy as np +import pytest +import tests.core.test_policies +from _pytest.monkeypatch import MonkeyPatch +from rasa.core.featurizers.single_state_featurizer import SingleStateFeaturizer +from rasa.core.featurizers.tracker_featurizers import ( + MaxHistoryTrackerFeaturizer, + TrackerFeaturizer, +) +from rasa.core.policies.policy import Policy +from rasa.core.policies.ted_policy import TEDPolicy +from rasa.shared.core.constants import ACTION_LISTEN_NAME +from rasa.shared.core.domain import Domain +from rasa.shared.core.events import ( + ActionExecuted, + UserUttered, +) +from rasa.shared.core.trackers import DialogueStateTracker +from rasa.shared.nlu.interpreter 
import RegexInterpreter +from rasa.train import train_core +from rasa.utils import train_utils +from rasa.utils.tensorflow.constants import ( + EVAL_NUM_EXAMPLES, + KEY_RELATIVE_ATTENTION, + LOSS_TYPE, + MAX_RELATIVE_POSITION, + RANKING_LENGTH, + SCALE_LOSS, + SIMILARITY_TYPE, + VALUE_RELATIVE_ATTENTION, +) +from tests.core.test_policies import PolicyTestCollection + +UTTER_GREET_ACTION = "utter_greet" +GREET_INTENT_NAME = "greet" +DOMAIN_YAML = f""" +intents: +- {GREET_INTENT_NAME} +actions: +- {UTTER_GREET_ACTION} +""" + + +def test_diagnostics(): + domain = Domain.from_yaml(DOMAIN_YAML) + policy = TEDPolicy() + GREET_RULE = DialogueStateTracker.from_events( + "greet rule", + evts=[ + UserUttered(intent={"name": GREET_INTENT_NAME}), + ActionExecuted(UTTER_GREET_ACTION), + ActionExecuted(ACTION_LISTEN_NAME), + UserUttered(intent={"name": GREET_INTENT_NAME}), + ActionExecuted(ACTION_LISTEN_NAME), + ], + ) + policy.train([GREET_RULE], domain, RegexInterpreter()) + prediction = policy.predict_action_probabilities( + GREET_RULE, domain, RegexInterpreter() + ) + + assert prediction.diagnostic_data + assert "attention_weights" in prediction.diagnostic_data + assert isinstance(prediction.diagnostic_data.get("attention_weights"), np.ndarray) + + +class TestTEDPolicy(PolicyTestCollection): + def test_train_model_checkpointing(self, tmp_path: Path): + model_name = "core-checkpointed-model" + best_model_file = tmp_path / (model_name + ".tar.gz") + assert not best_model_file.exists() + + train_core( + domain="data/test_domains/default.yml", + stories="data/test_stories/stories_defaultdomain.md", + output=str(tmp_path), + fixed_model_name=model_name, + config="data/test_config/config_ted_policy_model_checkpointing.yml", + ) + + assert best_model_file.exists() + + def create_policy( + self, featurizer: Optional[TrackerFeaturizer], priority: int + ) -> Policy: + return TEDPolicy(featurizer=featurizer, priority=priority) + + def test_similarity_type(self, trained_policy: 
TEDPolicy): + assert trained_policy.config[SIMILARITY_TYPE] == "inner" + + def test_ranking_length(self, trained_policy: TEDPolicy): + assert trained_policy.config[RANKING_LENGTH] == 10 + + def test_normalization( + self, + trained_policy: TEDPolicy, + tracker: DialogueStateTracker, + default_domain: Domain, + monkeypatch: MonkeyPatch, + ): + # first check the output is what we expect + prediction = trained_policy.predict_action_probabilities( + tracker, default_domain, RegexInterpreter() + ) + assert not prediction.is_end_to_end_prediction + # count number of non-zero confidences + assert ( + sum([confidence > 0 for confidence in prediction.probabilities]) + == trained_policy.config[RANKING_LENGTH] + ) + # check that the norm is still 1 + assert sum(prediction.probabilities) == pytest.approx(1) + + # also check our function is called + mock = Mock() + monkeypatch.setattr(train_utils, "normalize", mock.normalize) + trained_policy.predict_action_probabilities( + tracker, default_domain, RegexInterpreter() + ) + + mock.normalize.assert_called_once() + + async def test_gen_batch(self, trained_policy: TEDPolicy, default_domain: Domain): + training_trackers = await tests.core.test_policies.train_trackers( + default_domain, augmentation_factor=0 + ) + interpreter = RegexInterpreter() + training_data, label_ids, entity_tags = trained_policy.featurize_for_training( + training_trackers, default_domain, interpreter + ) + label_data, all_labels = trained_policy._create_label_data( + default_domain, interpreter + ) + model_data = trained_policy._create_model_data( + training_data, label_ids, entity_tags, all_labels + ) + batch_size = 2 + + # model data keys were sorted, so the order is alphabetical + ( + batch_action_name_mask, + batch_action_name_sentence_indices, + batch_action_name_sentence_data, + batch_action_name_sentence_shape, + batch_dialogue_length, + batch_entities_mask, + batch_entities_sentence_indices, + batch_entities_sentence_data, + 
batch_entities_sentence_shape, + batch_intent_mask, + batch_intent_sentence_indices, + batch_intent_sentence_data, + batch_intent_sentence_shape, + batch_label_ids, + batch_slots_mask, + batch_slots_sentence_indices, + batch_slots_sentence_data, + batch_slots_sentence_shape, + ) = next(model_data._gen_batch(batch_size=batch_size)) + + assert ( + batch_label_ids.shape[0] == batch_size + and batch_dialogue_length.shape[0] == batch_size + ) + # batch and dialogue dimensions are NOT combined for masks + assert ( + batch_slots_mask.shape[0] == batch_size + and batch_intent_mask.shape[0] == batch_size + and batch_entities_mask.shape[0] == batch_size + and batch_action_name_mask.shape[0] == batch_size + ) + # some features might be "fake" so there sequence is `0` + seq_len = max( + [ + batch_intent_sentence_shape[1], + batch_action_name_sentence_shape[1], + batch_entities_sentence_shape[1], + batch_slots_sentence_shape[1], + ] + ) + assert ( + batch_intent_sentence_shape[1] == seq_len + or batch_intent_sentence_shape[1] == 0 + ) + assert ( + batch_action_name_sentence_shape[1] == seq_len + or batch_action_name_sentence_shape[1] == 0 + ) + assert ( + batch_entities_sentence_shape[1] == seq_len + or batch_entities_sentence_shape[1] == 0 + ) + assert ( + batch_slots_sentence_shape[1] == seq_len + or batch_slots_sentence_shape[1] == 0 + ) + + ( + batch_action_name_mask, + batch_action_name_sentence_indices, + batch_action_name_sentence_data, + batch_action_name_sentence_shape, + batch_dialogue_length, + batch_entities_mask, + batch_entities_sentence_indices, + batch_entities_sentence_data, + batch_entities_sentence_shape, + batch_intent_mask, + batch_intent_sentence_indices, + batch_intent_sentence_data, + batch_intent_sentence_shape, + batch_label_ids, + batch_slots_mask, + batch_slots_sentence_indices, + batch_slots_sentence_data, + batch_slots_sentence_shape, + ) = next( + model_data._gen_batch( + batch_size=batch_size, batch_strategy="balanced", shuffle=True + ) + ) + + 
assert ( + batch_label_ids.shape[0] == batch_size + and batch_dialogue_length.shape[0] == batch_size + ) + # some features might be "fake" so there sequence is `0` + seq_len = max( + [ + batch_intent_sentence_shape[1], + batch_action_name_sentence_shape[1], + batch_entities_sentence_shape[1], + batch_slots_sentence_shape[1], + ] + ) + assert ( + batch_intent_sentence_shape[1] == seq_len + or batch_intent_sentence_shape[1] == 0 + ) + assert ( + batch_action_name_sentence_shape[1] == seq_len + or batch_action_name_sentence_shape[1] == 0 + ) + assert ( + batch_entities_sentence_shape[1] == seq_len + or batch_entities_sentence_shape[1] == 0 + ) + assert ( + batch_slots_sentence_shape[1] == seq_len + or batch_slots_sentence_shape[1] == 0 + ) + + +class TestTEDPolicyMargin(TestTEDPolicy): + def create_policy( + self, featurizer: Optional[TrackerFeaturizer], priority: int + ) -> Policy: + return TEDPolicy( + featurizer=featurizer, priority=priority, **{LOSS_TYPE: "margin"} + ) + + def test_similarity_type(self, trained_policy: TEDPolicy): + assert trained_policy.config[SIMILARITY_TYPE] == "cosine" + + def test_normalization( + self, + trained_policy: Policy, + tracker: DialogueStateTracker, + default_domain: Domain, + monkeypatch: MonkeyPatch, + ): + # Mock actual normalization method + mock = Mock() + monkeypatch.setattr(train_utils, "normalize", mock.normalize) + trained_policy.predict_action_probabilities( + tracker, default_domain, RegexInterpreter() + ) + + # function should not get called for margin loss_type + mock.normalize.assert_not_called() + + +class TestTEDPolicyWithEval(TestTEDPolicy): + def create_policy( + self, featurizer: Optional[TrackerFeaturizer], priority: int + ) -> Policy: + return TEDPolicy( + featurizer=featurizer, + priority=priority, + **{SCALE_LOSS: False, EVAL_NUM_EXAMPLES: 4}, + ) + + +class TestTEDPolicyNoNormalization(TestTEDPolicy): + def create_policy( + self, featurizer: Optional[TrackerFeaturizer], priority: int + ) -> Policy: + return 
TEDPolicy( + featurizer=featurizer, priority=priority, **{RANKING_LENGTH: 0} + ) + + def test_ranking_length(self, trained_policy: TEDPolicy): + assert trained_policy.config[RANKING_LENGTH] == 0 + + def test_normalization( + self, + trained_policy: Policy, + tracker: DialogueStateTracker, + default_domain: Domain, + monkeypatch: MonkeyPatch, + ): + # first check the output is what we expect + predicted_probabilities = trained_policy.predict_action_probabilities( + tracker, default_domain, RegexInterpreter() + ).probabilities + # there should be no normalization + assert all([confidence > 0 for confidence in predicted_probabilities]) + + # also check our function is not called + mock = Mock() + monkeypatch.setattr(train_utils, "normalize", mock.normalize) + trained_policy.predict_action_probabilities( + tracker, default_domain, RegexInterpreter() + ) + + mock.normalize.assert_not_called() + + +class TestTEDPolicyLowRankingLength(TestTEDPolicy): + def create_policy( + self, featurizer: Optional[TrackerFeaturizer], priority: int + ) -> Policy: + return TEDPolicy( + featurizer=featurizer, priority=priority, **{RANKING_LENGTH: 3} + ) + + def test_ranking_length(self, trained_policy: TEDPolicy): + assert trained_policy.config[RANKING_LENGTH] == 3 + + +class TestTEDPolicyHighRankingLength(TestTEDPolicy): + def create_policy( + self, featurizer: Optional[TrackerFeaturizer], priority: int + ) -> Policy: + return TEDPolicy( + featurizer=featurizer, priority=priority, **{RANKING_LENGTH: 11} + ) + + def test_ranking_length(self, trained_policy: TEDPolicy): + assert trained_policy.config[RANKING_LENGTH] == 11 + + +class TestTEDPolicyWithStandardFeaturizer(TestTEDPolicy): + def create_policy( + self, featurizer: Optional[TrackerFeaturizer], priority: int + ) -> Policy: + # use standard featurizer from TEDPolicy, + # since it is using MaxHistoryTrackerFeaturizer + # if max_history is not specified + return TEDPolicy(priority=priority) + + def test_featurizer(self, trained_policy: 
Policy, tmp_path: Path): + assert isinstance(trained_policy.featurizer, MaxHistoryTrackerFeaturizer) + assert isinstance( + trained_policy.featurizer.state_featurizer, SingleStateFeaturizer + ) + trained_policy.persist(str(tmp_path)) + loaded = trained_policy.__class__.load(str(tmp_path)) + assert isinstance(loaded.featurizer, MaxHistoryTrackerFeaturizer) + assert isinstance(loaded.featurizer.state_featurizer, SingleStateFeaturizer) + + +class TestTEDPolicyWithMaxHistory(TestTEDPolicy): + def create_policy( + self, featurizer: Optional[TrackerFeaturizer], priority: int + ) -> Policy: + # use standard featurizer from TEDPolicy, + # since it is using MaxHistoryTrackerFeaturizer + # if max_history is specified + return TEDPolicy(priority=priority, max_history=self.max_history) + + def test_featurizer(self, trained_policy: Policy, tmp_path: Path): + assert isinstance(trained_policy.featurizer, MaxHistoryTrackerFeaturizer) + assert trained_policy.featurizer.max_history == self.max_history + assert isinstance( + trained_policy.featurizer.state_featurizer, SingleStateFeaturizer + ) + trained_policy.persist(str(tmp_path)) + loaded = trained_policy.__class__.load(str(tmp_path)) + assert isinstance(loaded.featurizer, MaxHistoryTrackerFeaturizer) + assert loaded.featurizer.max_history == self.max_history + assert isinstance(loaded.featurizer.state_featurizer, SingleStateFeaturizer) + + +class TestTEDPolicyWithRelativeAttention(TestTEDPolicy): + def create_policy( + self, featurizer: Optional[TrackerFeaturizer], priority: int + ) -> Policy: + return TEDPolicy( + featurizer=featurizer, + priority=priority, + **{ + KEY_RELATIVE_ATTENTION: True, + VALUE_RELATIVE_ATTENTION: True, + MAX_RELATIVE_POSITION: 5, + }, + ) + + +class TestTEDPolicyWithRelativeAttentionMaxHistoryOne(TestTEDPolicy): + + max_history = 1 + + def create_policy( + self, featurizer: Optional[TrackerFeaturizer], priority: int + ) -> Policy: + return TEDPolicy( + featurizer=featurizer, + priority=priority, + **{ + 
KEY_RELATIVE_ATTENTION: True, + VALUE_RELATIVE_ATTENTION: True, + MAX_RELATIVE_POSITION: 5, + }, + ) diff --git a/tests/core/test_policies.py b/tests/core/test_policies.py index 997eb5228352..22d9be98dbd3 100644 --- a/tests/core/test_policies.py +++ b/tests/core/test_policies.py @@ -1,10 +1,9 @@ from pathlib import Path from typing import Type, List, Text, Tuple, Optional, Any -from unittest.mock import Mock, patch +from unittest.mock import patch import numpy as np import pytest -from _pytest.monkeypatch import MonkeyPatch from rasa.core.channels import OutputChannel from rasa.core.exceptions import UnsupportedDialogueModelError @@ -55,18 +54,6 @@ from rasa.core.policies.sklearn_policy import SklearnPolicy from rasa.shared.core.trackers import DialogueStateTracker from rasa.shared.nlu.training_data.formats.markdown import INTENT -from rasa.utils.tensorflow.constants import ( - SIMILARITY_TYPE, - RANKING_LENGTH, - LOSS_TYPE, - SCALE_LOSS, - EVAL_NUM_EXAMPLES, - KEY_RELATIVE_ATTENTION, - VALUE_RELATIVE_ATTENTION, - MAX_RELATIVE_POSITION, -) -from rasa.train import train_core -from rasa.utils import train_utils from tests.core.conftest import ( DEFAULT_DOMAIN_PATH_WITH_MAPPING, DEFAULT_DOMAIN_PATH_WITH_SLOTS, @@ -372,365 +359,6 @@ def test_finetune_after_load( assert loaded_policy.model -class TestTEDPolicy(PolicyTestCollection): - def test_train_model_checkpointing(self, tmp_path: Path): - model_name = "core-checkpointed-model" - best_model_file = tmp_path / (model_name + ".tar.gz") - assert not best_model_file.exists() - - train_core( - domain="data/test_domains/default.yml", - stories="data/test_stories/stories_defaultdomain.md", - output=str(tmp_path), - fixed_model_name=model_name, - config="data/test_config/config_ted_policy_model_checkpointing.yml", - ) - - assert best_model_file.exists() - - def create_policy( - self, featurizer: Optional[TrackerFeaturizer], priority: int - ) -> Policy: - return TEDPolicy(featurizer=featurizer, priority=priority) - - def 
test_similarity_type(self, trained_policy: TEDPolicy): - assert trained_policy.config[SIMILARITY_TYPE] == "inner" - - def test_ranking_length(self, trained_policy: TEDPolicy): - assert trained_policy.config[RANKING_LENGTH] == 10 - - def test_normalization( - self, - trained_policy: TEDPolicy, - tracker: DialogueStateTracker, - default_domain: Domain, - monkeypatch: MonkeyPatch, - ): - # first check the output is what we expect - prediction = trained_policy.predict_action_probabilities( - tracker, default_domain, RegexInterpreter() - ) - assert not prediction.is_end_to_end_prediction - # count number of non-zero confidences - assert ( - sum([confidence > 0 for confidence in prediction.probabilities]) - == trained_policy.config[RANKING_LENGTH] - ) - # check that the norm is still 1 - assert sum(prediction.probabilities) == pytest.approx(1) - - # also check our function is called - mock = Mock() - monkeypatch.setattr(train_utils, "normalize", mock.normalize) - trained_policy.predict_action_probabilities( - tracker, default_domain, RegexInterpreter() - ) - - mock.normalize.assert_called_once() - - async def test_gen_batch(self, trained_policy: TEDPolicy, default_domain: Domain): - training_trackers = await train_trackers(default_domain, augmentation_factor=0) - interpreter = RegexInterpreter() - training_data, label_ids, entity_tags = trained_policy.featurize_for_training( - training_trackers, default_domain, interpreter - ) - label_data, all_labels = trained_policy._create_label_data( - default_domain, interpreter - ) - model_data = trained_policy._create_model_data( - training_data, label_ids, entity_tags, all_labels - ) - batch_size = 2 - - # model data keys were sorted, so the order is alphabetical - ( - batch_action_name_mask, - batch_action_name_sentence_indices, - batch_action_name_sentence_data, - batch_action_name_sentence_shape, - batch_dialogue_length, - batch_entities_mask, - batch_entities_sentence_indices, - batch_entities_sentence_data, - 
batch_entities_sentence_shape, - batch_intent_mask, - batch_intent_sentence_indices, - batch_intent_sentence_data, - batch_intent_sentence_shape, - batch_label_ids, - batch_slots_mask, - batch_slots_sentence_indices, - batch_slots_sentence_data, - batch_slots_sentence_shape, - ) = next(model_data._gen_batch(batch_size=batch_size)) - - assert ( - batch_label_ids.shape[0] == batch_size - and batch_dialogue_length.shape[0] == batch_size - ) - # batch and dialogue dimensions are NOT combined for masks - assert ( - batch_slots_mask.shape[0] == batch_size - and batch_intent_mask.shape[0] == batch_size - and batch_entities_mask.shape[0] == batch_size - and batch_action_name_mask.shape[0] == batch_size - ) - # some features might be "fake" so there sequence is `0` - seq_len = max( - [ - batch_intent_sentence_shape[1], - batch_action_name_sentence_shape[1], - batch_entities_sentence_shape[1], - batch_slots_sentence_shape[1], - ] - ) - assert ( - batch_intent_sentence_shape[1] == seq_len - or batch_intent_sentence_shape[1] == 0 - ) - assert ( - batch_action_name_sentence_shape[1] == seq_len - or batch_action_name_sentence_shape[1] == 0 - ) - assert ( - batch_entities_sentence_shape[1] == seq_len - or batch_entities_sentence_shape[1] == 0 - ) - assert ( - batch_slots_sentence_shape[1] == seq_len - or batch_slots_sentence_shape[1] == 0 - ) - - ( - batch_action_name_mask, - batch_action_name_sentence_indices, - batch_action_name_sentence_data, - batch_action_name_sentence_shape, - batch_dialogue_length, - batch_entities_mask, - batch_entities_sentence_indices, - batch_entities_sentence_data, - batch_entities_sentence_shape, - batch_intent_mask, - batch_intent_sentence_indices, - batch_intent_sentence_data, - batch_intent_sentence_shape, - batch_label_ids, - batch_slots_mask, - batch_slots_sentence_indices, - batch_slots_sentence_data, - batch_slots_sentence_shape, - ) = next( - model_data._gen_batch( - batch_size=batch_size, batch_strategy="balanced", shuffle=True - ) - ) - - 
assert ( - batch_label_ids.shape[0] == batch_size - and batch_dialogue_length.shape[0] == batch_size - ) - # some features might be "fake" so there sequence is `0` - seq_len = max( - [ - batch_intent_sentence_shape[1], - batch_action_name_sentence_shape[1], - batch_entities_sentence_shape[1], - batch_slots_sentence_shape[1], - ] - ) - assert ( - batch_intent_sentence_shape[1] == seq_len - or batch_intent_sentence_shape[1] == 0 - ) - assert ( - batch_action_name_sentence_shape[1] == seq_len - or batch_action_name_sentence_shape[1] == 0 - ) - assert ( - batch_entities_sentence_shape[1] == seq_len - or batch_entities_sentence_shape[1] == 0 - ) - assert ( - batch_slots_sentence_shape[1] == seq_len - or batch_slots_sentence_shape[1] == 0 - ) - - -class TestTEDPolicyMargin(TestTEDPolicy): - def create_policy( - self, featurizer: Optional[TrackerFeaturizer], priority: int - ) -> Policy: - return TEDPolicy( - featurizer=featurizer, priority=priority, **{LOSS_TYPE: "margin"} - ) - - def test_similarity_type(self, trained_policy: TEDPolicy): - assert trained_policy.config[SIMILARITY_TYPE] == "cosine" - - def test_normalization( - self, - trained_policy: Policy, - tracker: DialogueStateTracker, - default_domain: Domain, - monkeypatch: MonkeyPatch, - ): - # Mock actual normalization method - mock = Mock() - monkeypatch.setattr(train_utils, "normalize", mock.normalize) - trained_policy.predict_action_probabilities( - tracker, default_domain, RegexInterpreter() - ) - - # function should not get called for margin loss_type - mock.normalize.assert_not_called() - - -class TestTEDPolicyWithEval(TestTEDPolicy): - def create_policy( - self, featurizer: Optional[TrackerFeaturizer], priority: int - ) -> Policy: - return TEDPolicy( - featurizer=featurizer, - priority=priority, - **{SCALE_LOSS: False, EVAL_NUM_EXAMPLES: 4}, - ) - - -class TestTEDPolicyNoNormalization(TestTEDPolicy): - def create_policy( - self, featurizer: Optional[TrackerFeaturizer], priority: int - ) -> Policy: - return 
TEDPolicy( - featurizer=featurizer, priority=priority, **{RANKING_LENGTH: 0} - ) - - def test_ranking_length(self, trained_policy: TEDPolicy): - assert trained_policy.config[RANKING_LENGTH] == 0 - - def test_normalization( - self, - trained_policy: Policy, - tracker: DialogueStateTracker, - default_domain: Domain, - monkeypatch: MonkeyPatch, - ): - # first check the output is what we expect - predicted_probabilities = trained_policy.predict_action_probabilities( - tracker, default_domain, RegexInterpreter() - ).probabilities - # there should be no normalization - assert all([confidence > 0 for confidence in predicted_probabilities]) - - # also check our function is not called - mock = Mock() - monkeypatch.setattr(train_utils, "normalize", mock.normalize) - trained_policy.predict_action_probabilities( - tracker, default_domain, RegexInterpreter() - ) - - mock.normalize.assert_not_called() - - -class TestTEDPolicyLowRankingLength(TestTEDPolicy): - def create_policy( - self, featurizer: Optional[TrackerFeaturizer], priority: int - ) -> Policy: - return TEDPolicy( - featurizer=featurizer, priority=priority, **{RANKING_LENGTH: 3} - ) - - def test_ranking_length(self, trained_policy: TEDPolicy): - assert trained_policy.config[RANKING_LENGTH] == 3 - - -class TestTEDPolicyHighRankingLength(TestTEDPolicy): - def create_policy( - self, featurizer: Optional[TrackerFeaturizer], priority: int - ) -> Policy: - return TEDPolicy( - featurizer=featurizer, priority=priority, **{RANKING_LENGTH: 11} - ) - - def test_ranking_length(self, trained_policy: TEDPolicy): - assert trained_policy.config[RANKING_LENGTH] == 11 - - -class TestTEDPolicyWithStandardFeaturizer(TestTEDPolicy): - def create_policy( - self, featurizer: Optional[TrackerFeaturizer], priority: int - ) -> Policy: - # use standard featurizer from TEDPolicy, - # since it is using MaxHistoryTrackerFeaturizer - # if max_history is not specified - return TEDPolicy(priority=priority) - - def test_featurizer(self, trained_policy: 
Policy, tmp_path: Path): - assert isinstance(trained_policy.featurizer, MaxHistoryTrackerFeaturizer) - assert isinstance( - trained_policy.featurizer.state_featurizer, SingleStateFeaturizer - ) - trained_policy.persist(str(tmp_path)) - loaded = trained_policy.__class__.load(str(tmp_path)) - assert isinstance(loaded.featurizer, MaxHistoryTrackerFeaturizer) - assert isinstance(loaded.featurizer.state_featurizer, SingleStateFeaturizer) - - -class TestTEDPolicyWithMaxHistory(TestTEDPolicy): - def create_policy( - self, featurizer: Optional[TrackerFeaturizer], priority: int - ) -> Policy: - # use standard featurizer from TEDPolicy, - # since it is using MaxHistoryTrackerFeaturizer - # if max_history is specified - return TEDPolicy(priority=priority, max_history=self.max_history) - - def test_featurizer(self, trained_policy: Policy, tmp_path: Path): - assert isinstance(trained_policy.featurizer, MaxHistoryTrackerFeaturizer) - assert trained_policy.featurizer.max_history == self.max_history - assert isinstance( - trained_policy.featurizer.state_featurizer, SingleStateFeaturizer - ) - trained_policy.persist(str(tmp_path)) - loaded = trained_policy.__class__.load(str(tmp_path)) - assert isinstance(loaded.featurizer, MaxHistoryTrackerFeaturizer) - assert loaded.featurizer.max_history == self.max_history - assert isinstance(loaded.featurizer.state_featurizer, SingleStateFeaturizer) - - -class TestTEDPolicyWithRelativeAttention(TestTEDPolicy): - def create_policy( - self, featurizer: Optional[TrackerFeaturizer], priority: int - ) -> Policy: - return TEDPolicy( - featurizer=featurizer, - priority=priority, - **{ - KEY_RELATIVE_ATTENTION: True, - VALUE_RELATIVE_ATTENTION: True, - MAX_RELATIVE_POSITION: 5, - }, - ) - - -class TestTEDPolicyWithRelativeAttentionMaxHistoryOne(TestTEDPolicy): - - max_history = 1 - - def create_policy( - self, featurizer: Optional[TrackerFeaturizer], priority: int - ) -> Policy: - return TEDPolicy( - featurizer=featurizer, - priority=priority, - **{ - 
KEY_RELATIVE_ATTENTION: True, - VALUE_RELATIVE_ATTENTION: True, - MAX_RELATIVE_POSITION: 5, - }, - ) - - class TestMemoizationPolicy(PolicyTestCollection): def create_policy( self, featurizer: Optional[TrackerFeaturizer], priority: int diff --git a/tests/nlu/classifiers/test_diet_classifier.py b/tests/nlu/classifiers/test_diet_classifier.py index 90f20a61039e..456f5c76292c 100644 --- a/tests/nlu/classifiers/test_diet_classifier.py +++ b/tests/nlu/classifiers/test_diet_classifier.py @@ -5,6 +5,7 @@ from unittest.mock import Mock from typing import List, Text, Dict, Any +import rasa.model from rasa.shared.nlu.training_data.features import Features from rasa.nlu import train from rasa.nlu.classifiers import LABEL_RANKING_LENGTH @@ -35,8 +36,10 @@ from rasa.nlu.model import Interpreter from rasa.shared.nlu.training_data.message import Message from rasa.utils import train_utils +from rasa.shared.constants import DIAGNOSTIC_DATA from tests.conftest import DEFAULT_NLU_DATA from tests.nlu.conftest import DEFAULT_DATA_PATH +from rasa.core.agent import Agent def test_compute_default_label_features(): @@ -330,7 +333,7 @@ async def test_margin_loss_is_not_normalized( _config = RasaNLUModelConfig({"pipeline": pipeline}) (trained_model, _, persisted_path) = await train( _config, - path=tmpdir.strpath, + path=str(tmpdir), data="data/test/many_intents.md", component_builder=component_builder, ) @@ -503,3 +506,27 @@ async def test_train_persist_load_with_composite_entities( assert loaded.pipeline text = "I am looking for an italian restaurant" assert loaded.parse(text) == trained.parse(text) + + +async def test_process_gives_diagnostic_data(trained_nlu_moodbot_path: Text,): + """Tests if processing a message returns attention weights as numpy array.""" + with rasa.model.unpack_model(trained_nlu_moodbot_path) as unpacked_model_directory: + _, nlu_model_directory = rasa.model.get_model_subdirectories( + unpacked_model_directory + ) + interpreter = 
Interpreter.load(nlu_model_directory) + + message = Message(data={TEXT: "hello"}) + for component in interpreter.pipeline: + component.process(message) + + diagnostic_data = message.get(DIAGNOSTIC_DATA) + + # The last component is DIETClassifier, which should add attention weights + name = f"component_{len(interpreter.pipeline) - 1}_DIETClassifier" + assert isinstance(diagnostic_data, dict) + assert name in diagnostic_data + assert "attention_weights" in diagnostic_data[name] + assert isinstance(diagnostic_data[name].get("attention_weights"), np.ndarray) + assert "text_transformed" in diagnostic_data[name] + assert isinstance(diagnostic_data[name].get("text_transformed"), np.ndarray) diff --git a/tests/nlu/selectors/test_selectors.py b/tests/nlu/selectors/test_selectors.py index 560bf9f9b67c..6d5c4aabea9d 100644 --- a/tests/nlu/selectors/test_selectors.py +++ b/tests/nlu/selectors/test_selectors.py @@ -1,8 +1,10 @@ from pathlib import Path import pytest +import numpy as np from typing import List, Dict, Text, Any +import rasa.model from rasa.nlu import train from rasa.nlu.components import ComponentBuilder from rasa.shared.nlu.training_data import util @@ -18,6 +20,8 @@ EVAL_NUM_EXAMPLES, CHECKPOINT_MODEL, ) +from rasa.shared.nlu.constants import TEXT +from rasa.shared.constants import DIAGNOSTIC_DATA from rasa.nlu.selectors.response_selector import ResponseSelector from rasa.shared.nlu.training_data.message import Message from rasa.shared.nlu.training_data.training_data import TrainingData @@ -282,3 +286,32 @@ async def test_train_persist_load(component_builder: ComponentBuilder, tmpdir: P await _train_persist_load_with_different_settings( pipeline, component_builder, tmpdir, True ) + + +async def test_process_gives_diagnostic_data(trained_response_selector_bot: Path): + """Tests if processing a message returns attention weights as numpy array.""" + + with rasa.model.unpack_model( + trained_response_selector_bot + ) as unpacked_model_directory: + _, 
nlu_model_directory = rasa.model.get_model_subdirectories( + unpacked_model_directory + ) + interpreter = Interpreter.load(nlu_model_directory) + + message = Message(data={TEXT: "hello"}) + for component in interpreter.pipeline: + component.process(message) + + diagnostic_data = message.get(DIAGNOSTIC_DATA) + + # The last component is ResponseSelector, which should add diagnostic data + name = f"component_{len(interpreter.pipeline) - 1}_ResponseSelector" + assert isinstance(diagnostic_data, dict) + assert name in diagnostic_data + assert "text_transformed" in diagnostic_data[name] + assert isinstance(diagnostic_data[name].get("text_transformed"), np.ndarray) + # The `attention_weights` key should exist, regardless of there being a transformer + assert "attention_weights" in diagnostic_data[name] + # By default, ResponseSelector has `number_of_transformer_layers = 0` + assert diagnostic_data[name].get("attention_weights") is None diff --git a/tests/nlu/test_config.py b/tests/nlu/test_config.py index e11c9a473c6a..9e9e40251a2e 100644 --- a/tests/nlu/test_config.py +++ b/tests/nlu/test_config.py @@ -13,6 +13,7 @@ import rasa.shared.nlu.training_data.loading from rasa.nlu import components from rasa.nlu.components import ComponentBuilder +from rasa.nlu.constants import COMPONENT_INDEX from rasa.shared.nlu.constants import TRAINABLE_EXTRACTORS from rasa.nlu.model import Trainer from tests.nlu.utilities import write_file_config @@ -98,10 +99,14 @@ def test_set_attr_on_component(): _config.set_component_attr(idx_classifier, epochs=10) - assert _config.for_component(idx_tokenizer) == {"name": "SpacyTokenizer"} + assert _config.for_component(idx_tokenizer) == { + "name": "SpacyTokenizer", + COMPONENT_INDEX: idx_tokenizer, + } assert _config.for_component(idx_classifier) == { "name": "DIETClassifier", "epochs": 10, + COMPONENT_INDEX: idx_classifier, } diff --git a/tests/shared/nlu/training_data/test_message.py b/tests/shared/nlu/training_data/test_message.py index 
25ad054668f9..cb26a0dc022c 100644 --- a/tests/shared/nlu/training_data/test_message.py +++ b/tests/shared/nlu/training_data/test_message.py @@ -266,3 +266,10 @@ def test_is_core_or_domain_message( message: Message, result: bool, ): assert result == message.is_core_or_domain_message() + + +def test_add_diagnostic_data_with_repeated_component_raises_warning(): + message = Message() + message.add_diagnostic_data("a", {}) + with pytest.warns(UserWarning): + message.add_diagnostic_data("a", {}) diff --git a/tests/utils/tensorflow/test_numpy.py b/tests/utils/tensorflow/test_numpy.py new file mode 100644 index 000000000000..4a0544d7ff32 --- /dev/null +++ b/tests/utils/tensorflow/test_numpy.py @@ -0,0 +1,27 @@ +import pytest +import tensorflow as tf +import numpy as np +import rasa.utils.tensorflow.numpy +from typing import Optional, Dict, Any + + +@pytest.mark.parametrize( + "value, expected_result", + [ + ({}, {}), + ({"a": 1}, {"a": 1}), + ({"a": tf.zeros((2, 3))}, {"a": np.zeros((2, 3))}), + ], +) +def test_values_to_numpy( + value: Optional[Dict[Any, Any]], expected_result: Optional[Dict[Any, Any]] +): + actual_result = rasa.utils.tensorflow.numpy.values_to_numpy(value) + actual_result_value_types = [ + type(value) for value in sorted(actual_result.values()) + ] + expected_result_value_types = [ + type(value) for value in sorted(expected_result.values()) + ] + assert actual_result_value_types == expected_result_value_types + np.testing.assert_equal(actual_result, expected_result)