Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue with stop words in DeterministicIntentParser #137

Merged
merged 3 commits into from
Apr 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ All notable changes to this project will be documented in this file.
- Make the `WrongModelVersion` error message intelligible [#133](https://github.com/snipsco/snips-nlu-rs/pull/133)
- Fix error handling in Python wrapper [#134](https://github.com/snipsco/snips-nlu-rs/pull/134)
- Return an error when using unknown intents in whitelist or blacklist [#136](https://github.com/snipsco/snips-nlu-rs/pull/136)
- Fix issue with stop words in `DeterministicIntentParser` [#137](https://github.com/snipsco/snips-nlu-rs/pull/137)

## [0.64.2] - 2019-04-09
### Fixed
Expand Down
144 changes: 103 additions & 41 deletions src/intent_parser/deterministic_intent_parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub struct DeterministicIntentParser {
entity_scopes: HashMap<IntentName, (Vec<BuiltinEntityKind>, Vec<EntityName>)>,
ignore_stop_words: bool,
shared_resources: Arc<SharedResources>,
stop_words_whitelist: HashMap<IntentName, HashSet<String>>,
}

impl DeterministicIntentParser {
Expand Down Expand Up @@ -97,6 +98,11 @@ impl DeterministicIntentParser {
slot_names_to_entities: model.slot_names_to_entities,
entity_scopes,
ignore_stop_words: model.config.ignore_stop_words,
stop_words_whitelist: model
.stop_words_whitelist
.into_iter()
.map(|(intent, whitelist)| (intent, whitelist.into_iter().collect()))
.collect(),
shared_resources,
})
}
Expand Down Expand Up @@ -171,13 +177,13 @@ impl IntentParser for DeterministicIntentParser {
}

impl DeterministicIntentParser {
#[allow(clippy::map_clone)]
fn parse_top_intents(
&self,
input: &str,
top_n: usize,
intents: Option<&[&str]>,
) -> Result<Vec<InternalParsingResult>> {
let cleaned_input = self.preprocess_text(input);
let mut results = vec![];

let intents_set: HashSet<&str> = intents
Expand Down Expand Up @@ -214,8 +220,10 @@ impl DeterministicIntentParser {

let (ranges_mapping, formatted_input) =
replace_entities(input, matched_entities, get_entity_placeholder);
let cleaned_formatted_input = self.preprocess_text(&*formatted_input);
self.regexes_per_intent
let cleaned_input = self.preprocess_text(input, &**intent);
let cleaned_formatted_input = self.preprocess_text(&*formatted_input, &**intent);
if let Some(matching_result_formatted) = self
.regexes_per_intent
.get(intent)
.ok_or_else(|| format_err!("No associated regexes for intent '{}'", intent))?
.iter()
Expand All @@ -231,7 +239,9 @@ impl DeterministicIntentParser {
self.get_matching_result(input, &*cleaned_input, regex, intent, None)
})
})
.map(|matching_result_formatted| results.push(matching_result_formatted));
{
results.push(matching_result_formatted);
}
}

let confidence_score = if results.is_empty() {
Expand All @@ -250,17 +260,13 @@ impl DeterministicIntentParser {
.collect())
}

fn preprocess_text(&self, string: &str) -> String {
fn preprocess_text(&self, string: &str, intent: &str) -> String {
let stop_words = self.get_intent_stop_words(intent);
let tokens = tokenize(string, NluUtilsLanguage::from_language(self.language));
let mut current_idx = 0;
let mut cleaned_string = "".to_string();
for mut token in tokens {
if self.ignore_stop_words
&& self
.shared_resources
.stop_words
.contains(&token.normalized_value())
{
if self.ignore_stop_words && stop_words.contains(&token.normalized_value()) {
token.value = (0..token.value.chars().count()).map(|_| " ").collect();
}
let prefix_length = token.char_range.start - current_idx;
Expand All @@ -274,6 +280,18 @@ impl DeterministicIntentParser {
cleaned_string
}

fn get_intent_stop_words(&self, intent: &str) -> HashSet<&String> {
self.stop_words_whitelist
.get(intent)
.map(|whitelist| {
self.shared_resources
.stop_words
.difference(whitelist)
.collect()
})
.unwrap_or_else(|| self.shared_resources.stop_words.iter().collect())
}

fn get_matching_result(
&self,
input: &str,
Expand Down Expand Up @@ -414,7 +432,7 @@ mod tests {
use crate::entity_parser::builtin_entity_parser::BuiltinEntityParser;
use crate::entity_parser::custom_entity_parser::CustomEntityParser;

fn test_configuration() -> DeterministicParserModel {
fn sample_model() -> DeterministicParserModel {
DeterministicParserModel {
language_code: "en".to_string(),
patterns: hashmap![
Expand All @@ -438,6 +456,10 @@ mod tests {
"dummy_intent_5".to_string() => vec![
r"^\s*Send\s*5\s*dollars\s*to\s*john\s*$".to_string(),
],
"dummy_intent_6".to_string() => vec![
r"^\s*search\s*$".to_string(),
r"^\s*search\s*(?P<group9>%SEARCH_OBJECT%)\s*$".to_string(),
],
],
group_names_to_slot_names: hashmap![
"group0".to_string() => "dummy_slot_name".to_string(),
Expand All @@ -449,6 +471,7 @@ mod tests {
"group6".to_string() => "dummy_slot_name4".to_string(),
"group7".to_string() => "dummy_slot_name2".to_string(),
"group8".to_string() => "dummy_slot_name5".to_string(),
"group9".to_string() => "dummy_slot_name6".to_string(),
],
slot_names_to_entities: hashmap![
"dummy_intent_1".to_string() => hashmap![
Expand All @@ -467,10 +490,14 @@ mod tests {
"dummy_slot_name5".to_string() => "snips/number".to_string(),
],
"dummy_intent_5".to_string() => hashmap![],
"dummy_intent_6".to_string() => hashmap![
"dummy_slot_name6".to_string() => "search_object".to_string(),
],
],
config: DeterministicParserConfig {
ignore_stop_words: true,
},
stop_words_whitelist: HashMap::new(),
}
}

Expand Down Expand Up @@ -528,8 +555,7 @@ mod tests {
.custom_entity_parser(mocked_custom_entity_parser)
.build();
let parser =
DeterministicIntentParser::new(test_configuration(), Arc::new(shared_resources))
.unwrap();
DeterministicIntentParser::new(sample_model(), Arc::new(shared_resources)).unwrap();

// When
let intent = parser.parse(text, None).unwrap().intent;
Expand Down Expand Up @@ -564,8 +590,7 @@ mod tests {
.builtin_entity_parser(mocked_builtin_entity_parser)
.build();
let parser =
DeterministicIntentParser::new(test_configuration(), Arc::new(shared_resources))
.unwrap();
DeterministicIntentParser::new(sample_model(), Arc::new(shared_resources)).unwrap();

// When
let intent = parser.parse(text, None).unwrap().intent;
Expand All @@ -580,7 +605,7 @@ mod tests {
}

#[test]
fn test_parse_intent_with_whitelist() {
fn test_parse_intent_with_intents_whitelist() {
// Given
let text = "this is a dummy_a query with another dummy_c";
let mocked_custom_entity_parser = MockedCustomEntityParser::from_iter(vec![(
Expand All @@ -604,8 +629,7 @@ mod tests {
.custom_entity_parser(mocked_custom_entity_parser)
.build();
let parser =
DeterministicIntentParser::new(test_configuration(), Arc::new(shared_resources))
.unwrap();
DeterministicIntentParser::new(sample_model(), Arc::new(shared_resources)).unwrap();

// When
let intent = parser
Expand Down Expand Up @@ -643,8 +667,7 @@ mod tests {
.builtin_entity_parser(mocked_builtin_entity_parser)
.build();
let parser =
DeterministicIntentParser::new(test_configuration(), Arc::new(shared_resources))
.unwrap();
DeterministicIntentParser::new(sample_model(), Arc::new(shared_resources)).unwrap();

// When
let intents = parser.get_intents(text).unwrap();
Expand All @@ -654,17 +677,22 @@ mod tests {
.iter()
.map(|res| res.confidence_score)
.collect::<Vec<_>>();
let expected_scores = vec![0.5, 0.5, 0.0, 0.0, 0.0, 0.0];
let expected_scores = vec![0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0];
let intent_names = intents
.into_iter()
.skip(2)
.map(|res| res.intent_name.unwrap_or("null".to_string()).to_string())
.map(|res| {
res.intent_name
.unwrap_or_else(|| "null".to_string())
.to_string()
})
.sorted()
.collect::<Vec<_>>();
let expected_intent_names = vec![
"dummy_intent_1".to_string(),
"dummy_intent_2".to_string(),
"dummy_intent_4".to_string(),
"dummy_intent_6".to_string(),
"null".to_string(),
];
assert_eq!(expected_scores, scores);
Expand Down Expand Up @@ -692,8 +720,7 @@ mod tests {
.builtin_entity_parser(mocked_builtin_entity_parser)
.build();
let parser =
DeterministicIntentParser::new(test_configuration(), Arc::new(shared_resources))
.unwrap();
DeterministicIntentParser::new(sample_model(), Arc::new(shared_resources)).unwrap();

// When
let intent = parser.parse(text, None).unwrap().intent;
Expand Down Expand Up @@ -805,8 +832,7 @@ mod tests {
.custom_entity_parser(my_mocked_custom_entity_parser)
.build();
let parser =
DeterministicIntentParser::new(test_configuration(), Arc::new(shared_resources))
.unwrap();
DeterministicIntentParser::new(sample_model(), Arc::new(shared_resources)).unwrap();

// When
let intent = parser.parse(text, None).unwrap().intent;
Expand Down Expand Up @@ -845,8 +871,7 @@ mod tests {
.builtin_entity_parser(mocked_builtin_entity_parser)
.build();
let parser =
DeterministicIntentParser::new(test_configuration(), Arc::new(shared_resources))
.unwrap();
DeterministicIntentParser::new(sample_model(), Arc::new(shared_resources)).unwrap();

// When
let parsing_result = parser.parse(text, None).unwrap();
Expand Down Expand Up @@ -906,8 +931,7 @@ mod tests {
.stop_words(stop_words)
.build();
let parser =
DeterministicIntentParser::new(test_configuration(), Arc::new(shared_resources))
.unwrap();
DeterministicIntentParser::new(sample_model(), Arc::new(shared_resources)).unwrap();

// When
let intent = parser.parse(text, None).unwrap().intent;
Expand Down Expand Up @@ -947,8 +971,7 @@ mod tests {
.custom_entity_parser(mocked_custom_entity_parser)
.build(),
);
let parser =
DeterministicIntentParser::new(test_configuration(), shared_resources).unwrap();
let parser = DeterministicIntentParser::new(sample_model(), shared_resources).unwrap();

// When
let slots = parser.parse(text, None).unwrap().slots;
Expand Down Expand Up @@ -989,8 +1012,7 @@ mod tests {
.custom_entity_parser(mocked_custom_entity_parser)
.build(),
);
let parser =
DeterministicIntentParser::new(test_configuration(), shared_resources).unwrap();
let parser = DeterministicIntentParser::new(sample_model(), shared_resources).unwrap();

// When
let slots = parser.parse(text, None).unwrap().slots;
Expand Down Expand Up @@ -1036,8 +1058,7 @@ mod tests {
.custom_entity_parser(mocked_custom_entity_parser)
.build();
let parser =
DeterministicIntentParser::new(test_configuration(), Arc::new(shared_resources))
.unwrap();
DeterministicIntentParser::new(sample_model(), Arc::new(shared_resources)).unwrap();

// When
let slots = parser.parse(text, None).unwrap().slots;
Expand Down Expand Up @@ -1078,8 +1099,7 @@ mod tests {
.custom_entity_parser(mocked_custom_entity_parser)
.build(),
);
let parser =
DeterministicIntentParser::new(test_configuration(), shared_resources).unwrap();
let parser = DeterministicIntentParser::new(sample_model(), shared_resources).unwrap();

// When
let slots = parser.parse(text, None).unwrap().slots;
Expand All @@ -1094,6 +1114,49 @@ mod tests {
assert_eq!(slots, expected_slots);
}

#[test]
fn test_parse_slots_with_stop_word_entity_value() {
// Given
let text = "search this please";
let mocked_custom_entity_parser = MockedCustomEntityParser::from_iter(vec![(
text.to_string(),
vec![CustomEntity {
value: "this".to_string(),
resolved_value: "this".to_string(),
range: 7..11,
entity_identifier: "search_object".to_string(),
}],
)]);
let stop_words = vec!["this".to_string(), "that".to_string(), "please".to_string()]
.into_iter()
.collect();
let shared_resources = Arc::new(
SharedResourcesBuilder::default()
.custom_entity_parser(mocked_custom_entity_parser)
.stop_words(stop_words)
.build(),
);
let mut parser_model = sample_model();
parser_model.stop_words_whitelist = hashmap! {
"dummy_intent_6".to_string() => vec!["this".to_string()].into_iter().collect()
};
let parser = DeterministicIntentParser::new(parser_model, shared_resources).unwrap();

// When
let result = parser.parse(text, None).unwrap();

// Then
let expected_slots = vec![InternalSlot {
value: "this".to_string(),
char_range: 7..11,
entity: "search_object".to_string(),
slot_name: "dummy_slot_name6".to_string(),
}];
let expected_intent = Some("dummy_intent_6".to_string());
assert_eq!(expected_intent, result.intent.intent_name);
assert_eq!(expected_slots, result.slots);
}

#[test]
fn test_get_slots() {
// Given
Expand Down Expand Up @@ -1125,8 +1188,7 @@ mod tests {
.custom_entity_parser(mocked_custom_entity_parser)
.build();
let parser =
DeterministicIntentParser::new(test_configuration(), Arc::new(shared_resources))
.unwrap();
DeterministicIntentParser::new(sample_model(), Arc::new(shared_resources)).unwrap();

// When
let slots = parser.get_slots(text, "dummy_intent_3").unwrap();
Expand Down
2 changes: 2 additions & 0 deletions src/models/intent_parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ pub struct DeterministicParserModel {
pub patterns: HashMap<IntentName, Vec<String>>,
pub group_names_to_slot_names: HashMap<String, SlotName>,
pub slot_names_to_entities: HashMap<IntentName, HashMap<SlotName, EntityName>>,
#[serde(default)]
pub stop_words_whitelist: HashMap<IntentName, Vec<String>>,
pub config: DeterministicParserConfig,
}

Expand Down