Skip to content

Commit

Permalink
Merge pull request #5200 from RasaHQ/fix-check-test-features
Browse files Browse the repository at this point in the history
Add check for text features
  • Loading branch information
tabergma authored Feb 10, 2020
2 parents a3549d1 + cc85153 commit 7ae4f4f
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 0 deletions.
1 change: 1 addition & 0 deletions changelog/5199.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
If no text features are present in ``EmbeddingIntentClassifier`` return the intent ``None``.
5 changes: 5 additions & 0 deletions data/test/many_intents.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
- you are an idiot
- You lack understanding.

## intent:grett
- Hello
- Hi
- Welcome

## intent:thank
- Thanks
- Thank you
Expand Down
16 changes: 16 additions & 0 deletions rasa/nlu/classifiers/embedding_intent_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,16 @@ def _calculate_message_sim(
# transform sim to python list for JSON serializing
return label_ids, message_sim.tolist()

@staticmethod
def _text_features_present(session_data: SessionDataType) -> bool:
return np.array(
[
f.nnz != 0 if isinstance(f, scipy.sparse.spmatrix) else f.any()
for features in session_data["text_features"]
for f in features
]
).any()

def predict_label(
self, message: "Message"
) -> Tuple[Dict[Text, Any], List[Dict[Text, Any]]]:
Expand All @@ -835,6 +845,12 @@ def predict_label(

# create session data from message and convert it into a batch of 1
session_data = self._create_session_data([message])

# if no text-features are present (e.g. incoming message is not in the
# vocab), do not predict a random intent
if not self._text_features_present(session_data):
return label, label_ranking

batch = train_utils.prepare_batch(
session_data, tuple_sizes=self.batch_tuple_sizes
)
Expand Down
26 changes: 26 additions & 0 deletions tests/nlu/classifiers/test_embedding_intent_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,29 @@ async def test_margin_loss_is_not_normalized(

# make sure top ranking is reflected in intent prediction
assert parse_data.get("intent") == intent_ranking[0]


@pytest.mark.parametrize(
"session_data, expected",
[
(
{
"text_features": [
np.array(
[
np.random.rand(5, 14),
np.random.rand(2, 14),
np.random.rand(3, 14),
]
)
]
},
True,
),
({"text_features": [np.array([0, 0, 0])]}, False),
({"text_features": [scipy.sparse.csr_matrix([0, 0, 0])]}, False),
({"text_features": [scipy.sparse.csr_matrix([0, 31, 0])]}, True),
],
)
def test_text_features_present(session_data, expected):
assert EmbeddingIntentClassifier._text_features_present(session_data) == expected

0 comments on commit 7ae4f4f

Please sign in to comment.