
Merge pull request #5673 from RasaHQ/johannes-73
Attention weight logging
rasabot authored Jan 26, 2021
2 parents e311c1f + b20a5e0 commit 8d7e71e
Showing 24 changed files with 806 additions and 458 deletions.
8 changes: 8 additions & 0 deletions changelog/5673.improvement.md
@@ -0,0 +1,8 @@
Expose diagnostic data for action and NLU predictions.

Add a `diagnostic_data` field to the [Message](./reference/rasa/shared/nlu/training_data/message.md#message-objects)
and [Prediction](./reference/rasa/core/policies/policy.md#policyprediction-objects) objects. This field contains
information about attention weights and other intermediate results of the inference computation.
This information can be used for debugging and fine-tuning, e.g. with [RasaLit](https://github.com/RasaHQ/rasalit).

For examples of how to access the diagnostic data, see [here](https://gist.github.com/JEM-Mosig/c6e15b81ee70561cb72e361aff310d7e).
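
For example, to see which components attached diagnostic data to a processed message (a minimal sketch; `message` is
assumed to be a `Message` that has already gone through a trained NLU pipeline):

```python
from rasa.shared.constants import DIAGNOSTIC_DATA

# `message` is assumed to be a Message already processed by a trained NLU pipeline.
diagnostics = message.as_dict().get(DIAGNOSTIC_DATA, {})

# The keys are the unique names of the components that attached diagnostic data.
print(list(diagnostics.keys()))
```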
37 changes: 36 additions & 1 deletion docs/docs/tuning-your-model.mdx
@@ -293,7 +293,9 @@ Here is a summary of the available extractors and what they are best used for:
|`MitieEntityExtractor` |MITIE |structured SVM |good for training custom entities |
|`EntitySynonymMapper` |existing entities |N/A |maps known synonyms |

## Handling Class Imbalance
## Improving Performance

### Handling Class Imbalance

Classification algorithms often do not perform well if there is a large class imbalance,
for example if you have a lot of training data for some intents and very little training data for others.
@@ -312,6 +314,39 @@ pipeline:
batch_strategy: sequence
```

### Accessing Diagnostic Data

To gain a better understanding of what your models do, you can access intermediate results of the prediction process.
To do this, you need to access the `diagnostic_data` field of the [Message](./reference/rasa/shared/nlu/training_data/message.md#message-objects)
and [Prediction](./reference/rasa/core/policies/policy.md#policyprediction-objects) objects. This field contains
information about attention weights and other intermediate results of the inference computation.
You can use this information for debugging and fine-tuning, e.g. with [RasaLit](https://github.com/RasaHQ/rasalit).

After you've [trained a model](./command-line-interface.mdx#rasa-train), you can access diagnostic data for DIET,
given a processed message, like this:

```python
from rasa.shared.constants import DIAGNOSTIC_DATA

# `message` is a Message that has already been processed by the trained NLU pipeline.
nlu_diagnostic_data = message.as_dict()[DIAGNOSTIC_DATA]

for component_name, diagnostic_data in nlu_diagnostic_data.items():
    attention_weights = diagnostic_data["attention_weights"]
    print(f"attention_weights for {component_name}:")
    print(attention_weights)

    text_transformed = diagnostic_data["text_transformed"]
    print(f"\ntext_transformed for {component_name}:")
    print(text_transformed)
```

And you can access diagnostic data for TED like this:

```python
from rasa.shared.nlu.interpreter import RegexInterpreter

# `policy` is a trained TEDPolicy, `domain` is the loaded Domain, and `GREET_RULE`
# stands in for the DialogueStateTracker to predict for (defined elsewhere, e.g.
# in the gist of examples linked from the changelog).
prediction = policy.predict_action_probabilities(
    GREET_RULE, domain, RegexInterpreter()
)
print(prediction.diagnostic_data.get("attention_weights"))
```
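
The diagnostic values are returned as plain numpy arrays, so a quick sanity check could look like this
(a minimal sketch reusing `prediction` from the example above):

```python
weights = prediction.diagnostic_data.get("attention_weights")
if weights is not None:
    # Inspect the type and shape rather than dumping the full array.
    print(type(weights), getattr(weights, "shape", None))
```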


## Configuring Tensorflow

9 changes: 9 additions & 0 deletions rasa/core/policies/policy.py
@@ -236,6 +236,7 @@ def _prediction(
events: Optional[List[Event]] = None,
optional_events: Optional[List[Event]] = None,
is_end_to_end_prediction: bool = False,
diagnostic_data: Optional[Dict[Text, Any]] = None,
) -> "PolicyPrediction":
return PolicyPrediction(
probabilities,
@@ -244,6 +245,7 @@ def _prediction(
events,
optional_events,
is_end_to_end_prediction,
diagnostic_data,
)

def _metadata(self) -> Optional[Dict[Text, Any]]:
@@ -400,6 +402,7 @@ def __init__(
events: Optional[List[Event]] = None,
optional_events: Optional[List[Event]] = None,
is_end_to_end_prediction: bool = False,
diagnostic_data: Optional[Dict[Text, Any]] = None,
) -> None:
"""Creates a `PolicyPrediction`.
@@ -417,13 +420,17 @@
you return as they can potentially influence the conversation flow.
is_end_to_end_prediction: `True` if the prediction used the text of the
user message instead of the intent.
diagnostic_data: Intermediate results or other information that is not
necessary for Rasa to function, but intended for debugging and
fine-tuning purposes.
"""
self.probabilities = probabilities
self.policy_name = policy_name
self.policy_priority = (policy_priority,)
self.events = events or []
self.optional_events = optional_events or []
self.is_end_to_end_prediction = is_end_to_end_prediction
self.diagnostic_data = diagnostic_data or {}

@staticmethod
def for_action_name(
@@ -466,6 +473,8 @@ def __eq__(self, other: Any) -> bool:
and self.events == other.events
and self.optional_events == other.optional_events
and self.is_end_to_end_prediction == other.is_end_to_end_prediction
# We do not compare `diagnostic_data`, because it has no effect on the
# action prediction.
)

@property
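A minimal sketch (not part of the diff) of how a policy hands debugging information to the new
`diagnostic_data` parameter, assuming a custom `Policy` subclass with a placeholder payload; training
logic is omitted:

```python
from typing import Any

from rasa.core.policies.policy import Policy, PolicyPrediction
from rasa.shared.core.domain import Domain
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.nlu.interpreter import NaturalLanguageInterpreter


class DiagnosticDemoPolicy(Policy):
    """Illustrative only: shows how `diagnostic_data` flows into PolicyPrediction."""

    def predict_action_probabilities(
        self,
        tracker: DialogueStateTracker,
        domain: Domain,
        interpreter: NaturalLanguageInterpreter,
        **kwargs: Any,
    ) -> PolicyPrediction:
        probabilities = self._default_predictions(domain)  # placeholder scores
        return self._prediction(
            probabilities,
            # The payload is carried on the prediction for debugging only; as the
            # diff above notes, it does not affect action selection or equality.
            diagnostic_data={"attention_weights": None},
        )
```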
34 changes: 27 additions & 7 deletions rasa/core/policies/ted_policy.py
@@ -38,6 +38,7 @@
from rasa.shared.nlu.interpreter import NaturalLanguageInterpreter
from rasa.core.policies.policy import Policy, PolicyPrediction
from rasa.core.constants import DEFAULT_POLICY_PRIORITY, DIALOGUE
from rasa.shared.constants import DIAGNOSTIC_DATA
from rasa.shared.core.constants import ACTIVE_LOOP, SLOTS, ACTION_LISTEN_NAME
from rasa.shared.core.trackers import DialogueStateTracker
from rasa.shared.core.generator import TrackerWithCachedStates
@@ -50,6 +51,7 @@
Data,
)
from rasa.utils.tensorflow.model_data_utils import convert_to_data_format
import rasa.utils.tensorflow.numpy
from rasa.utils.tensorflow.constants import (
LABEL,
IDS,
@@ -632,6 +634,9 @@ def predict_action_probabilities(
confidence.tolist(),
is_end_to_end_prediction=is_e2e_prediction,
optional_events=optional_events,
diagnostic_data=rasa.utils.tensorflow.numpy.values_to_numpy(
output.get(DIAGNOSTIC_DATA)
),
)

def _create_optional_event_for_entities(
@@ -1050,14 +1055,23 @@ def _embed_dialogue(
self,
dialogue_in: tf.Tensor,
tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]],
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
"""Create dialogue level embedding and mask."""
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Optional[tf.Tensor]]:
"""Creates dialogue level embedding and mask.
Args:
dialogue_in: The encoded dialogue.
tf_batch_data: Batch in model data format.
Returns:
The dialogue embedding, the mask, the transformed dialogue, and
(for diagnostic purposes) also the attention weights.
"""
dialogue_lengths = tf.cast(tf_batch_data[DIALOGUE][LENGTH][0], tf.int32)
mask = self._compute_mask(dialogue_lengths)

dialogue_transformed = self._tf_layers[f"transformer.{DIALOGUE}"](
dialogue_in, 1 - mask, self._training
)
dialogue_transformed, attention_weights = self._tf_layers[
f"transformer.{DIALOGUE}"
](dialogue_in, 1 - mask, self._training)
dialogue_transformed = tfa.activations.gelu(dialogue_transformed)

if self.use_only_last_dialogue_turns:
@@ -1069,7 +1083,7 @@

dialogue_embed = self._tf_layers[f"embed.{DIALOGUE}"](dialogue_transformed)

return dialogue_embed, mask, dialogue_transformed
return dialogue_embed, mask, dialogue_transformed, attention_weights

def _encode_features_per_attribute(
self, tf_batch_data: Dict[Text, Dict[Text, List[tf.Tensor]]], attribute: Text
@@ -1615,6 +1629,7 @@ def batch_loss(
dialogue_embed,
dialogue_mask,
dialogue_transformer_output,
_,
) = self._embed_dialogue(dialogue_in, tf_batch_data)
dialogue_mask = tf.squeeze(dialogue_mask, axis=-1)

@@ -1686,6 +1701,7 @@ def batch_predict(
dialogue_embed,
dialogue_mask,
dialogue_transformer_output,
attention_weights,
) = self._embed_dialogue(dialogue_in, tf_batch_data)
dialogue_mask = tf.squeeze(dialogue_mask, axis=-1)

@@ -1698,7 +1714,11 @@
scores = self._tf_layers[f"loss.{LABEL}"].confidence_from_sim(
sim_all, self.config[SIMILARITY_TYPE]
)
predictions = {"action_scores": scores, "similarities": sim_all}
predictions = {
"action_scores": scores,
"similarities": sim_all,
DIAGNOSTIC_DATA: {"attention_weights": attention_weights},
}

if (
self.config[ENTITY_RECOGNITION]
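The `rasa.utils.tensorflow.numpy` helper used in `predict_action_probabilities` above is not expanded in
this view; conceptually, `values_to_numpy` turns the TensorFlow tensors in the diagnostic dict into plain
numpy arrays so they can be inspected and serialized. A rough sketch of that idea (a hypothetical
re-implementation, not the actual module):

```python
from typing import Any, Dict, Optional, Text

import tensorflow as tf


def values_to_numpy_sketch(
    data: Optional[Dict[Text, Any]]
) -> Optional[Dict[Text, Any]]:
    """Converts `tf.Tensor` values of a dict to numpy arrays (illustrative only)."""
    if data is None:
        return None
    return {
        key: value.numpy() if isinstance(value, tf.Tensor) else value
        for key, value in data.items()
    }
```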
19 changes: 16 additions & 3 deletions rasa/nlu/classifiers/diet_classifier.py
@@ -13,6 +13,8 @@
import rasa.shared.utils.io
import rasa.utils.io as io_utils
import rasa.nlu.utils.bilou_utils as bilou_utils
import rasa.utils.tensorflow.numpy
from rasa.shared.constants import DIAGNOSTIC_DATA
from rasa.nlu.featurizers.featurizer import Featurizer
from rasa.nlu.components import Component
from rasa.nlu.classifiers.classifier import IntentClassifier
@@ -914,7 +916,7 @@ def _predict_entities(
return entities

def process(self, message: Message, **kwargs: Any) -> None:
"""Return the most likely label and its similarity to the input."""
"""Augments the message with intents, entities, and diagnostic data."""
out = self._predict(message)

if self.component_config[INTENT_CLASSIFICATION]:
@@ -928,12 +930,17 @@ def process(self, message: Message, **kwargs: Any) -> None:

message.set(ENTITIES, entities, add_to_output=True)

if out and DIAGNOSTIC_DATA in out:
message.add_diagnostic_data(
self.unique_name,
rasa.utils.tensorflow.numpy.values_to_numpy(out.get(DIAGNOSTIC_DATA)),
)

def persist(self, file_name: Text, model_dir: Text) -> Dict[Text, Any]:
"""Persist this model into the passed directory.
Return the metadata necessary to load the model again.
"""

if self.model is None:
return {"file": None}

@@ -1420,6 +1427,7 @@ def batch_loss(
text_in,
text_seq_ids,
lm_mask_bool_text,
_,
) = self._create_sequence(
tf_batch_data[TEXT][SEQUENCE],
tf_batch_data[TEXT][SENTENCE],
Expand Down Expand Up @@ -1569,7 +1577,7 @@ def batch_predict(

mask = self._compute_mask(sequence_lengths)

text_transformed, _, _, _ = self._create_sequence(
text_transformed, _, _, _, attention_weights = self._create_sequence(
tf_batch_data[TEXT][SEQUENCE],
tf_batch_data[TEXT][SENTENCE],
mask_sequence_text,
@@ -1579,6 +1587,11 @@

predictions: Dict[Text, tf.Tensor] = {}

predictions[DIAGNOSTIC_DATA] = {
"attention_weights": attention_weights,
"text_transformed": text_transformed,
}

if self.config[INTENT_CLASSIFICATION]:
predictions.update(
self._batch_predict_intents(sequence_lengths, text_transformed)
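`Message.add_diagnostic_data`, used in `process` above, is added elsewhere in this pull request and is
not expanded in this view. Conceptually it files the payload produced by `batch_predict` under the
calling component's unique name, so entries from several components can coexist; a rough, hypothetical
sketch of that storage pattern (not the actual `Message` implementation):

```python
from typing import Any, Dict, Text

DIAGNOSTIC_DATA = "diagnostic_data"  # assumed value of the shared constant


class MessageSketch:
    """Illustrative stand-in for the storage pattern only, not the real Message."""

    def __init__(self) -> None:
        self.data: Dict[Text, Any] = {}

    def add_diagnostic_data(self, origin: Text, data: Dict[Text, Any]) -> None:
        # File the payload under the originating component's unique name so that
        # entries from several components can sit side by side.
        self.data.setdefault(DIAGNOSTIC_DATA, {})[origin] = data
```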
