From 582987f01cabad9fa6ad2d50e46fdca7d2a05971 Mon Sep 17 00:00:00 2001 From: Benjamin Minixhofer Date: Mon, 1 Mar 2021 16:31:36 +0100 Subject: [PATCH 01/16] load words from spelling dictionary --- build/README.md | 4 ++++ build/make_build_dir.py | 25 +++++++++++++++++++++-- nlprule/src/compile/impls.rs | 27 +++++++++++++------------ nlprule/src/compile/mod.rs | 39 ++++++++++++++++++++++++++---------- nlprule/src/tokenizer/tag.rs | 8 ++++---- nlprule/src/types.rs | 25 ++++++++++++++++++++++- 6 files changed, 97 insertions(+), 31 deletions(-) diff --git a/build/README.md b/build/README.md index c826acc..98d48a1 100644 --- a/build/README.md +++ b/build/README.md @@ -79,6 +79,8 @@ python build/make_build_dir.py \ --chunker_token_model=$HOME/Downloads/nlprule/en-token.bin \ --chunker_pos_model=$HOME/Downloads/nlprule/en-pos-maxent.bin \ --chunker_chunk_model=$HOME/Downloads/nlprule/en-chunker.bin \ + --spell_dict_path=$LT_PATH/org/languagetool/resource/en/hunspell/en_GB.dict \ + --spell_info_path=$LT_PATH/org/languagetool/resource/en/hunspell/en_GB.info \ --out_dir=data/en ``` @@ -92,6 +94,8 @@ python build/make_build_dir.py \ --lang_code=de \ --tag_dict_path=$HOME/Downloads/nlprule/german-pos-dict/src/main/resources/org/languagetool/resource/de/german.dict \ --tag_info_path=$HOME/Downloads/nlprule/german-pos-dict/src/main/resources/org/languagetool/resource/de/german.info \ + --spell_dict_path=$LT_PATH/org/languagetool/resource/de/hunspell/de_DE.dict \ + --spell_info_path=$LT_PATH/org/languagetool/resource/de/hunspell/de_DE.info \ --out_dir=data/de ``` diff --git a/build/make_build_dir.py b/build/make_build_dir.py index 04fbfae..3b287f2 100644 --- a/build/make_build_dir.py +++ b/build/make_build_dir.py @@ -59,7 +59,7 @@ def copy_lt_files(out_dir, lt_dir, lang_code): canonicalize(out_dir / xmlfile) -def dump_dictionary(out_path, lt_dir, tag_dict_path, tag_info_path): +def dump_dict(out_path, lt_dir, tag_dict_path, tag_info_path): # dump dictionary, see https://dev.languagetool.org/developing-a-tagger-dictionary os.system( f"java -cp {lt_dir / 'languagetool.jar'} org.languagetool.tools.DictionaryExporter " @@ -119,6 +119,16 @@ def dump_dictionary(out_path, lt_dir, tag_dict_path, tag_info_path): type=lambda p: Path(p).absolute(), help="Path to the accompanying tagger dictionary .info file.", ) + parser.add_argument( + "--spell_dict_path", + type=lambda p: Path(p).absolute(), + help="Path to a spell dictionary .dict file.", + ) + parser.add_argument( + "--spell_info_path", + type=lambda p: Path(p).absolute(), + help="Path to the accompanying spell dictionary .info file.", + ) parser.add_argument( "--chunker_token_model", default=None, @@ -149,12 +159,23 @@ def dump_dictionary(out_path, lt_dir, tag_dict_path, tag_info_path): write_freqlist(open(args.out_dir / "common.txt", "w"), args.lang_code) copy_lt_files(args.out_dir, args.lt_dir, args.lang_code) - dump_dictionary( + + # tagger dictionary + dump_dict( args.out_dir / "tags" / "output.dump", args.lt_dir, args.tag_dict_path, args.tag_info_path, ) + + # spell dictionary + dump_dict( + args.out_dir / "spell.dump", + args.lt_dir, + args.spell_dict_path, + args.spell_info_path, + ) + if ( args.chunker_token_model is not None and args.chunker_pos_model is not None diff --git a/nlprule/src/compile/impls.rs b/nlprule/src/compile/impls.rs index c5226a2..2c943db 100644 --- a/nlprule/src/compile/impls.rs +++ b/nlprule/src/compile/impls.rs @@ -92,14 +92,14 @@ impl Tagger { pub(in crate::compile) fn from_dumps, S2: AsRef>( paths: &[S1], remove_paths: &[S2], - 
common_words: &HashSet, + freq_words: HashMap, lang_options: TaggerLangOptions, ) -> std::io::Result { let mut tags = DefaultHashMap::default(); let mut groups = DefaultHashMap::default(); let mut tag_store = HashSet::new(); - let mut word_store = HashSet::new(); + let mut word_store = HashMap::new(); // hardcoded special tags tag_store.insert(""); @@ -114,29 +114,30 @@ impl Tagger { let punct = "!\"#$%&\\'()*+,-./:;<=>?@[\\]^_`{|}~"; for i in 0..punct.len() { - word_store.insert(&punct[i..(i + 1)]); + word_store.insert(&punct[i..(i + 1)], 0); } - word_store.extend(common_words.iter().map(|x| x.as_str())); - for (word, inflection, tag) in lines.iter() { - word_store.insert(word); - word_store.insert(inflection); + word_store.insert(word, 0); + word_store.insert(inflection, 0); tag_store.insert(tag); } + // extend with freq words at the end to make sure we overwrite words which existed but have 0 frequency + word_store.extend(freq_words.iter().map(|(word, freq)| (word.as_str(), *freq))); + // word store ids should be consistent across runs - let mut word_store: Vec<_> = word_store.iter().collect(); - word_store.sort(); + let mut word_store: Vec<_> = word_store.into_iter().collect(); + word_store.sort_unstable(); - // tag store ids should be consistent across runs - let mut tag_store: Vec<_> = tag_store.iter().collect(); - tag_store.sort(); + // tag store ids should be consistent across runs + let mut tag_store: Vec<_> = tag_store.into_iter().collect(); + tag_store.sort_unstable(); let word_store: BiMap<_, _> = word_store .iter() .enumerate() - .map(|(i, x)| (x.to_string(), WordIdInt(i as u32))) + .map(|(i, (word, freq))| ((*word).to_owned(), WordIdInt::new(i as u32, *freq))) .collect(); let tag_store: BiMap<_, _> = tag_store .iter() diff --git a/nlprule/src/compile/mod.rs b/nlprule/src/compile/mod.rs index 95a8be6..f6e3ab1 100644 --- a/nlprule/src/compile/mod.rs +++ b/nlprule/src/compile/mod.rs @@ -5,7 +5,7 @@ use fs_err as fs; use std::{ hash::{Hash, Hasher}, - io::{self, BufReader, BufWriter}, + io::{self, BufRead, BufReader, BufWriter}, num::ParseIntError, path::{Path, PathBuf}, str::FromStr, @@ -15,7 +15,7 @@ use std::{ use crate::{ rules::Rules, tokenizer::{chunk::Chunker, multiword::MultiwordTagger, tag::Tagger, Tokenizer}, - types::DefaultHasher, + types::*, }; use log::info; @@ -35,12 +35,13 @@ struct BuildFilePaths { disambiguation_path: PathBuf, grammar_path: PathBuf, multiword_tag_path: PathBuf, - common_words_path: PathBuf, regex_cache_path: PathBuf, srx_path: PathBuf, + spell_path: PathBuf, } impl BuildFilePaths { + // this has to be kept in sync with the paths the builder in build/make_build_dir.py stores the resources at fn new>(build_dir: P) -> Self { let p = build_dir.as_ref(); BuildFilePaths { @@ -51,9 +52,9 @@ impl BuildFilePaths { disambiguation_path: p.join("disambiguation.xml"), grammar_path: p.join("grammar.xml"), multiword_tag_path: p.join("tags/multiwords.txt"), - common_words_path: p.join("common.txt"), regex_cache_path: p.join("regex_cache.bin"), srx_path: p.join("segment.srx"), + spell_path: p.join("spell.dump"), } } } @@ -96,13 +97,29 @@ pub fn compile( let lang_code = fs::read_to_string(paths.lang_code_path)?; info!( - "Reading common words from {}.", - paths.common_words_path.display() + "Reading spelling words with frequency from {}.", + paths.spell_path.display() ); - let common_words = fs::read_to_string(paths.common_words_path)? 
- .lines() - .map(|x| x.to_string()) - .collect(); + let mut freq_words = DefaultHashMap::new(); + let reader = BufReader::new(File::open(paths.spell_path)?); + + for line in reader.lines() { + match line? + .trim() + .split_whitespace() + .collect::>() + .as_slice() + { + [freq, word] => { + // frequency is denoted as letters from A to Z in LanguageTool where A is the least frequent. + // we start from 1 because 0 is reserved for words we do not know the frequency of + let freq = 1 + freq.chars().next().expect("freq must have one char - would not have been yielded by split_whitespace otherwise.") as usize - 'A' as usize; + assert!(freq < u8::MAX as usize); + freq_words.insert(word.to_string(), freq as u8); + } + _ => continue, + } + } let tokenizer_lang_options = utils::tokenizer_lang_options(&lang_code).ok_or_else(|| { Error::LanguageOptionsDoNotExist { @@ -124,7 +141,7 @@ pub fn compile( let tagger = Tagger::from_dumps( &paths.tag_paths, &paths.tag_remove_paths, - &common_words, + freq_words, tagger_lang_options, )?; diff --git a/nlprule/src/tokenizer/tag.rs b/nlprule/src/tokenizer/tag.rs index 568aa6e..433fa05 100644 --- a/nlprule/src/tokenizer/tag.rs +++ b/nlprule/src/tokenizer/tag.rs @@ -52,7 +52,7 @@ impl From for TaggerFields { let key: Vec = word.as_bytes().iter().chain(once(&i)).copied().collect(); let pos_bytes = pos_id.0.to_be_bytes(); - let inflect_bytes = inflect_id.0.to_be_bytes(); + let inflect_bytes = inflect_id.raw_value().to_be_bytes(); let value = u64::from_be_bytes([ inflect_bytes[0], @@ -74,7 +74,7 @@ impl From for TaggerFields { let mut word_store_items: Vec<_> = tagger .word_store .iter() - .map(|(key, value)| (key.clone(), value.0 as u64)) + .map(|(key, value)| (key.clone(), value.raw_value() as u64)) .collect(); word_store_items.sort_by(|(a, _), (b, _)| a.cmp(b)); @@ -106,7 +106,7 @@ impl From for Tagger { .into_str_vec() .unwrap() .into_iter() - .map(|(key, value)| (key, WordIdInt(value as u32))) + .map(|(key, value)| (key, WordIdInt::from_raw_value(value as u32))) .collect(); let mut tags = DefaultHashMap::new(); @@ -120,7 +120,7 @@ impl From for Tagger { let word_id = *word_store.get_by_left(word).unwrap(); let value_bytes = value.to_be_bytes(); - let inflection_id = WordIdInt(u32::from_be_bytes([ + let inflection_id = WordIdInt::from_raw_value(u32::from_be_bytes([ value_bytes[0], value_bytes[1], value_bytes[2], diff --git a/nlprule/src/types.rs b/nlprule/src/types.rs index a49572b..7e0586c 100644 --- a/nlprule/src/types.rs +++ b/nlprule/src/types.rs @@ -15,7 +15,30 @@ pub(crate) type DefaultHasher = hash_map::DefaultHasher; #[derive(Debug, Copy, Clone, Serialize, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)] #[serde(transparent)] -pub(crate) struct WordIdInt(pub u32); +pub(crate) struct WordIdInt(u32); + +impl WordIdInt { + pub fn new(index: u32, freq: u8) -> Self { + assert!(index < 2u32.pow(24)); + + let mut id = index << 8; + id |= freq as u32; + WordIdInt(id) + } + + pub fn freq(&self) -> u8 { + (self.0 & 255) as u8 + } + + pub fn raw_value(&self) -> u32 { + self.0 + } + + pub fn from_raw_value(id: u32) -> Self { + WordIdInt(id) + } +} + #[derive(Debug, Copy, Clone, Serialize, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)] #[serde(transparent)] pub(crate) struct PosIdInt(pub u16); From b2ea80a4dc851c263286bf3cc6796cee7203bf9e Mon Sep 17 00:00:00 2001 From: Benjamin Minixhofer Date: Mon, 1 Mar 2021 19:09:13 +0100 Subject: [PATCH 02/16] add spellcheck module --- nlprule/Cargo.toml | 2 + nlprule/src/compile/impls.rs | 2 + nlprule/src/lib.rs | 1 + 
nlprule/src/rules.rs | 44 +++++++- nlprule/src/spellcheck/mod.rs | 198 ++++++++++++++++++++++++++++++++++ nlprule/src/tokenizer/tag.rs | 5 +- nlprule/src/types.rs | 4 +- 7 files changed, 249 insertions(+), 7 deletions(-) create mode 100644 nlprule/src/spellcheck/mod.rs diff --git a/nlprule/Cargo.toml b/nlprule/Cargo.toml index 0f49004..f397430 100644 --- a/nlprule/Cargo.toml +++ b/nlprule/Cargo.toml @@ -30,6 +30,8 @@ half = { version = "1.7", features = ["serde"] } srx = { version = "^0.1.2", features = ["serde"] } lazycell = "1" cfg-if = "1" +triple_accel = "0.3" +appendlist = "1.4" rayon-cond = "0.1" rayon = "1.5" diff --git a/nlprule/src/compile/impls.rs b/nlprule/src/compile/impls.rs index 2c943db..c3607df 100644 --- a/nlprule/src/compile/impls.rs +++ b/nlprule/src/compile/impls.rs @@ -21,6 +21,7 @@ use crate::{ DisambiguationRule, MatchGraph, Rule, }, rules::{Rules, RulesLangOptions, RulesOptions}, + spellcheck, tokenizer::{ chunk, multiword::{MultiwordTagger, MultiwordTaggerFields}, @@ -359,6 +360,7 @@ impl Rules { Rules { rules, options: RulesOptions::default(), + spellchecker: None, } } } diff --git a/nlprule/src/lib.rs b/nlprule/src/lib.rs index 0a3fd5e..79215a5 100644 --- a/nlprule/src/lib.rs +++ b/nlprule/src/lib.rs @@ -63,6 +63,7 @@ pub mod compile; mod filter; pub mod rule; pub mod rules; +mod spellcheck; pub mod tokenizer; pub mod types; pub(crate) mod utils; diff --git a/nlprule/src/rules.rs b/nlprule/src/rules.rs index b7b3203..424ad48 100644 --- a/nlprule/src/rules.rs +++ b/nlprule/src/rules.rs @@ -1,19 +1,47 @@ //! Sets of grammatical error correction rules. -use crate::types::*; -use crate::utils::parallelism::MaybeParallelRefIterator; use crate::{rule::id::Selector, tokenizer::Tokenizer}; use crate::{rule::Rule, Error}; +use crate::{spellcheck::SpellcheckOptions, utils::parallelism::MaybeParallelRefIterator}; +use crate::{spellcheck::Spellchecker, types::*}; use fs_err::File; use serde::{Deserialize, Serialize}; use std::{ io::{BufReader, Read}, + ops::{Deref, DerefMut}, path::Path, }; /// Options for a rule set. #[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct RulesOptions {} +pub struct RulesOptions { + spellcheck: bool, + spellcheck_options: SpellcheckOptions, +} + +pub struct RulesOptionsGuard<'a> { + rules: &'a mut Rules, +} + +impl<'a> Deref for RulesOptionsGuard<'a> { + type Target = RulesOptions; + + fn deref(&self) -> &Self::Target { + &self.rules.options + } +} + +impl<'a> DerefMut for RulesOptionsGuard<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.rules.options + } +} + +impl<'a> Drop for RulesOptionsGuard<'a> { + fn drop(&mut self) { + self.rules.ingest_options() + } +} /// Language-dependent options for a rule set. #[derive(Debug, Clone, Serialize, Deserialize)] @@ -43,9 +71,19 @@ impl Default for RulesLangOptions { pub struct Rules { pub(crate) rules: Vec, pub(crate) options: RulesOptions, + pub(crate) spellchecker: Option, } impl Rules { + fn ingest_options(&mut self) { + if self.options.spellcheck && self.spellchecker.is_none() { + self.spellchecker = Some(Spellchecker::new( + &self.tagger, + self.options.spellcheck_options.clone(), + )); + } + } + /// Creates a new rule set from a path to a binary. 
/// /// # Errors diff --git a/nlprule/src/spellcheck/mod.rs b/nlprule/src/spellcheck/mod.rs new file mode 100644 index 0000000..72a70c5 --- /dev/null +++ b/nlprule/src/spellcheck/mod.rs @@ -0,0 +1,198 @@ +use std::{ + cmp, + hash::{Hash, Hasher}, +}; + +use appendlist::AppendList; +use lazy_static::lazy_static; +use serde::{Deserialize, Serialize}; +use triple_accel::levenshtein; + +use crate::{tokenizer::tag::Tagger, types::*}; + +fn hash(string: H) -> u64 { + let mut hasher = DefaultHasher::new(); + string.hash(&mut hasher); + hasher.finish() +} + +fn distance(a: usize, b: usize) -> usize { + if a > b { + a - b + } else { + b - a + } +} + +#[derive(Debug, Clone, Default, PartialEq, PartialOrd, Ord, Eq)] +pub struct Candidate<'a> { + pub distance: usize, + pub freq: u8, + pub term: &'a str, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SpellcheckOptions { + max_dictionary_distance: usize, + prefix_length: usize, +} + +impl Default for SpellcheckOptions { + fn default() -> Self { + SpellcheckOptions { + max_dictionary_distance: 2, + prefix_length: 7, + } + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct Spellchecker { + deletes: DefaultHashMap>, + max_length: usize, + options: SpellcheckOptions, +} + +lazy_static! { + static ref EMPTY_HASH: u64 = { + let empty: &[u8] = &[]; + hash(empty) + }; +} + +impl Spellchecker { + fn deletes_prefix(key: &[u8], options: &SpellcheckOptions) -> DefaultHashSet { + let mut out = DefaultHashSet::new(); + + if key.len() <= options.max_dictionary_distance { + out.insert(*EMPTY_HASH); + } + + Self::deletes_recurse(&key[..options.prefix_length], 0, &mut out, options); + out.insert(hash(key)); + + out + } + + fn deletes_recurse( + bytes: &[u8], + distance: usize, + out: &mut DefaultHashSet, + options: &SpellcheckOptions, + ) { + if bytes.len() > 1 { + for i in 0..bytes.len() { + let mut delete = bytes.to_vec(); + delete.remove(i); + + if distance + 1 < options.max_dictionary_distance { + Self::deletes_recurse(&delete, distance + 1, out, options); + } + + out.insert(hash(delete)); + } + } + } + + pub fn new(tagger: &Tagger, options: SpellcheckOptions) -> Self { + let mut deletes = DefaultHashMap::new(); + let mut max_length = 0; + + for (word, id) in tagger.word_store() { + for delete in Self::deletes_prefix(word.as_bytes(), &options) { + deletes.entry(delete).or_insert_with(Vec::new).push(*id); + } + max_length = cmp::max(word.len(), max_length); + } + + Spellchecker { + deletes, + max_length, + options, + } + } + + pub fn lookup<'t>(&self, token: &'t Token, max_distance: usize) -> Option>> { + if token.word.text.id().is_some() { + return None; + } + + let word = token.word.text.0.as_ref(); + let input_length = word.len(); + + if input_length - self.options.max_dictionary_distance > self.max_length { + return Some(Vec::new()); + } + + let mut candidates = Vec::new(); + + // deletes we've considered already + let mut known_deletes: DefaultHashSet = DefaultHashSet::new(); + // suggestions we've considered already + let mut known_suggestions: DefaultHashSet = DefaultHashSet::new(); + + let mut candidate_index = 0; + let deletes: AppendList> = AppendList::new(); + + let input_prefix_length = cmp::min(input_length, self.options.prefix_length); + deletes.push(word.as_bytes()[..input_prefix_length].to_vec()); + + while candidate_index < deletes.len() { + let candidate = deletes[candidate_index].as_slice(); + let candidate_length = candidate.len(); + + candidate_index += 1; + + let length_diff = input_prefix_length - 
candidate_length; + + if let Some(suggestions) = self.deletes.get(&hash(candidate)) { + for suggestion_id in suggestions { + let suggestion = token.tagger.str_for_word_id(suggestion_id); + let suggestion_length = suggestion.len(); + + if distance(suggestion_length, input_length) > max_distance + // suggestion must be for a different delete string, in same bin only because of hash collision + || suggestion_length < candidate_length + // in the same bin only because of hash collision, a valid suggestion is always longer than the delete + || suggestion_length == candidate_length + // we already added the suggestion + || known_suggestions.contains(suggestion_id) + { + continue; + } + + // SymSpell.cs covers some additional cases here where it is not necessary to compute the edit distance + // would have to be benchmarked if they are worth it considering `triple_accel` is presumably faster than + // the C# implementation of edit distance + + let distance = levenshtein(suggestion.as_bytes(), word.as_bytes()) as usize; + let freq = suggestion_id.freq(); + + candidates.push(Candidate { + term: suggestion, + distance, + freq, + }); + known_suggestions.insert(*suggestion_id); + } + } + + if length_diff < max_distance && candidate_length <= self.options.prefix_length { + for i in 0..candidate.len() { + let mut delete = candidate.to_owned(); + delete.remove(i); + + let delete_hash = hash(&delete); + + if !known_deletes.contains(&delete_hash) { + deletes.push(delete); + known_deletes.insert(delete_hash); + } + } + } + } + + candidates.sort(); + Some(candidates) + } +} diff --git a/nlprule/src/tokenizer/tag.rs b/nlprule/src/tokenizer/tag.rs index 433fa05..87b6c32 100644 --- a/nlprule/src/tokenizer/tag.rs +++ b/nlprule/src/tokenizer/tag.rs @@ -209,18 +209,17 @@ impl Tagger { &self.tag_store } - #[allow(dead_code)] // used by compile module pub(crate) fn word_store(&self) -> &BiMap { &self.word_store } - fn str_for_word_id(&self, id: &WordIdInt) -> &str { + pub(crate) fn str_for_word_id(&self, id: &WordIdInt) -> &str { self.word_store .get_by_right(id) .expect("only valid word ids are created") } - fn str_for_pos_id(&self, id: &PosIdInt) -> &str { + pub(crate) fn str_for_pos_id(&self, id: &PosIdInt) -> &str { self.tag_store .get_by_right(id) .expect("only valid pos ids are created") diff --git a/nlprule/src/types.rs b/nlprule/src/types.rs index 7e0586c..db6dab9 100644 --- a/nlprule/src/types.rs +++ b/nlprule/src/types.rs @@ -13,7 +13,9 @@ pub(crate) type DefaultHashMap = HashMap; pub(crate) type DefaultHashSet = HashSet; pub(crate) type DefaultHasher = hash_map::DefaultHasher; -#[derive(Debug, Copy, Clone, Serialize, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)] +#[derive( + Debug, Copy, Clone, Serialize, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd, Default, +)] #[serde(transparent)] pub(crate) struct WordIdInt(u32); From 7f6458a7a2d36cec646b3060b5a2fe7e31553685 Mon Sep 17 00:00:00 2001 From: Benjamin Minixhofer Date: Tue, 2 Mar 2021 09:28:05 +0100 Subject: [PATCH 03/16] more clone derives, store Arc in Rules struct --- nlprule/src/bin/run.rs | 8 +- nlprule/src/bin/test.rs | 6 +- nlprule/src/compile/impls.rs | 4 +- nlprule/src/compile/mod.rs | 14 +++- nlprule/src/rule/engine/composition.rs | 34 ++++---- nlprule/src/rule/engine/mod.rs | 4 +- nlprule/src/rule/grammar.rs | 12 +-- nlprule/src/rule/mod.rs | 4 +- nlprule/src/rules.rs | 109 +++++++++++++++++-------- nlprule/src/tokenizer.rs | 7 +- nlprule/tests/tests.rs | 20 ++--- python/src/lib.rs | 63 ++++++-------- 12 files changed, 164 
insertions(+), 121 deletions(-) diff --git a/nlprule/src/bin/run.rs b/nlprule/src/bin/run.rs index c6da086..6bc46fb 100644 --- a/nlprule/src/bin/run.rs +++ b/nlprule/src/bin/run.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use clap::Clap; use nlprule::{rules::Rules, tokenizer::Tokenizer}; @@ -18,11 +20,11 @@ fn main() { env_logger::init(); let opts = Opts::parse(); - let tokenizer = Tokenizer::new(opts.tokenizer).unwrap(); - let rules = Rules::new(opts.rules).unwrap(); + let tokenizer = Arc::new(Tokenizer::new(opts.tokenizer).unwrap()); + let rules = Rules::new(opts.rules, tokenizer.clone()).unwrap(); let tokens = tokenizer.pipe(&opts.text); println!("Tokens: {:#?}", tokens); - println!("Suggestions: {:#?}", rules.suggest(&opts.text, &tokenizer)); + println!("Suggestions: {:#?}", rules.suggest(&opts.text)); } diff --git a/nlprule/src/bin/test.rs b/nlprule/src/bin/test.rs index 3669a8e..2a81ed6 100644 --- a/nlprule/src/bin/test.rs +++ b/nlprule/src/bin/test.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use clap::Clap; use nlprule::{rules::Rules, tokenizer::Tokenizer}; @@ -19,8 +21,8 @@ fn main() { env_logger::init(); let opts = Opts::parse(); - let tokenizer = Tokenizer::new(opts.tokenizer).unwrap(); - let rules_container = Rules::new(opts.rules).unwrap(); + let tokenizer = Arc::new(Tokenizer::new(opts.tokenizer).unwrap()); + let rules_container = Rules::new(opts.rules, tokenizer.clone()).unwrap(); let rules = rules_container.rules(); println!("Runnable rules: {}", rules.len()); diff --git a/nlprule/src/compile/impls.rs b/nlprule/src/compile/impls.rs index c3607df..92b8791 100644 --- a/nlprule/src/compile/impls.rs +++ b/nlprule/src/compile/impls.rs @@ -8,6 +8,7 @@ use std::{ hash::{Hash, Hasher}, io::{self, BufRead, BufReader}, path::Path, + sync::Arc, }; use crate::{ @@ -21,7 +22,6 @@ use crate::{ DisambiguationRule, MatchGraph, Rule, }, rules::{Rules, RulesLangOptions, RulesOptions}, - spellcheck, tokenizer::{ chunk, multiword::{MultiwordTagger, MultiwordTaggerFields}, @@ -265,6 +265,7 @@ impl Rules { pub(in crate::compile) fn from_xml>( path: P, build_info: &mut BuildInfo, + tokenizer: Arc, options: RulesLangOptions, ) -> Self { let rules = super::parse_structure::read_rules(path); @@ -361,6 +362,7 @@ impl Rules { rules, options: RulesOptions::default(), spellchecker: None, + tokenizer, } } } diff --git a/nlprule/src/compile/mod.rs b/nlprule/src/compile/mod.rs index f6e3ab1..f2897de 100644 --- a/nlprule/src/compile/mod.rs +++ b/nlprule/src/compile/mod.rs @@ -82,6 +82,8 @@ pub enum Error { Unimplemented(String), #[error("error parsing to integer: {0}")] ParseError(#[from] ParseIntError), + #[error("nlprule error: {0}")] + NLPRuleError(#[from] crate::Error), #[error("unknown error")] Other(#[from] Box), } @@ -202,12 +204,16 @@ pub fn compile( srx::SRX::from_str(&fs::read_to_string(&paths.srx_path)?)?.language_rules(lang_code), tokenizer_lang_options, )?; - - bincode::serialize_into(&mut tokenizer_dest, &tokenizer)?; + tokenizer.to_writer(&mut tokenizer_dest)?; info!("Creating grammar rules."); - let rules = Rules::from_xml(&paths.grammar_path, &mut build_info, rules_lang_options); - bincode::serialize_into(&mut rules_dest, &rules)?; + let rules = Rules::from_xml( + &paths.grammar_path, + &mut build_info, + Arc::new(tokenizer), + rules_lang_options, + ); + rules.to_writer(&mut rules_dest)?; // we need to write the regex cache after building the rules, otherwise it isn't fully populated let f = BufWriter::new(File::create(&paths.regex_cache_path)?); diff --git 
a/nlprule/src/rule/engine/composition.rs b/nlprule/src/rule/engine/composition.rs index d05986e..378ea8a 100644 --- a/nlprule/src/rule/engine/composition.rs +++ b/nlprule/src/rule/engine/composition.rs @@ -4,7 +4,7 @@ use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; use unicase::UniCase; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Matcher { pub matcher: either::Either, Regex>, pub negate: bool, @@ -68,7 +68,7 @@ impl Matcher { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub(crate) struct TextMatcher { pub(crate) matcher: Matcher, pub(crate) set: Option>, @@ -107,7 +107,7 @@ impl PosMatcher { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct WordDataMatcher { pub(crate) pos_matcher: Option, pub(crate) inflect_matcher: Option, @@ -141,7 +141,7 @@ impl WordDataMatcher { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Quantifier { pub min: usize, pub max: usize, @@ -153,7 +153,7 @@ pub trait Atomable: Send + Sync { } #[enum_dispatch(Atomable)] -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum Atom { ChunkAtom(concrete::ChunkAtom), SpaceBeforeAtom(concrete::SpaceBeforeAtom), @@ -171,7 +171,7 @@ pub mod concrete { use super::{Atomable, MatchGraph, Matcher, TextMatcher, Token, WordDataMatcher}; use serde::{Deserialize, Serialize}; - #[derive(Debug, Serialize, Deserialize)] + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TextAtom { pub(crate) matcher: TextMatcher, } @@ -183,7 +183,7 @@ pub mod concrete { } } - #[derive(Debug, Serialize, Deserialize)] + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChunkAtom { pub(crate) matcher: Matcher, } @@ -195,7 +195,7 @@ pub mod concrete { } } - #[derive(Debug, Serialize, Deserialize)] + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SpaceBeforeAtom { pub(crate) value: bool, } @@ -206,7 +206,7 @@ pub mod concrete { } } - #[derive(Debug, Serialize, Deserialize)] + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct WordDataAtom { pub(crate) matcher: WordDataMatcher, pub(crate) case_sensitive: bool, @@ -222,7 +222,7 @@ pub mod concrete { } } -#[derive(Debug, Serialize, Deserialize, Default)] +#[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct TrueAtom {} impl Atomable for TrueAtom { @@ -231,7 +231,7 @@ impl Atomable for TrueAtom { } } -#[derive(Debug, Serialize, Deserialize, Default)] +#[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct FalseAtom {} impl Atomable for FalseAtom { @@ -240,7 +240,7 @@ impl Atomable for FalseAtom { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct AndAtom { pub(crate) atoms: Vec, } @@ -253,7 +253,7 @@ impl Atomable for AndAtom { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct OrAtom { pub(crate) atoms: Vec, } @@ -266,7 +266,7 @@ impl Atomable for OrAtom { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct NotAtom { pub(crate) atom: Box, } @@ -277,7 +277,7 @@ impl Atomable for NotAtom { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct OffsetAtom { pub(crate) atom: Box, pub(crate) offset: isize, @@ -449,7 +449,7 @@ impl<'t> MatchGraph<'t> { } } 
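The sweep of `Clone` derives in this patch is what makes the later `#[derive(Clone)]` on `Rules` compile: a derived `Clone` only works when every field type is itself `Clone`, so each nested rule component needs the derive too. A minimal standalone sketch with hypothetical, simplified stand-ins for the real types (not the crate's definitions):

```rust
// Hypothetical, simplified stand-ins for the real `Composition`/`Part` types.
#[derive(Clone)]
struct Part {
    atom: String,
}

#[derive(Clone)]
struct Composition {
    parts: Vec<Part>, // `Vec<T>` is `Clone` only if `T: Clone`
}

fn main() {
    let composition = Composition {
        parts: vec![Part {
            atom: "example".into(),
        }],
    };
    // Without `#[derive(Clone)]` on `Part`, this `.clone()` would not compile.
    let _copy = composition.clone();
}
```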
-#[derive(Serialize, Deserialize, Debug)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Part { pub atom: Atom, pub quantifier: Quantifier, @@ -458,7 +458,7 @@ pub struct Part { pub unify: Option, } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Composition { pub(crate) parts: Vec, pub(crate) id_to_idx: DefaultHashMap, diff --git a/nlprule/src/rule/engine/mod.rs b/nlprule/src/rule/engine/mod.rs index 75d933f..0667267 100644 --- a/nlprule/src/rule/engine/mod.rs +++ b/nlprule/src/rule/engine/mod.rs @@ -9,7 +9,7 @@ use composition::{Composition, Group, MatchGraph}; use self::composition::GraphId; -#[derive(Serialize, Deserialize, Debug)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct TokenEngine { pub(crate) composition: Composition, pub(crate) antipatterns: Vec, @@ -53,7 +53,7 @@ impl TokenEngine { } } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum Engine { Token(TokenEngine), // regex with the `fancy_regex` backend is large on the stack diff --git a/nlprule/src/rule/grammar.rs b/nlprule/src/rule/grammar.rs index a289ae3..a7cc170 100644 --- a/nlprule/src/rule/grammar.rs +++ b/nlprule/src/rule/grammar.rs @@ -16,7 +16,7 @@ impl std::cmp::PartialEq for Suggestion { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum Conversion { Nop, AllLower, @@ -38,7 +38,7 @@ impl Conversion { } /// An example associated with a [Rule][crate::rule::Rule]. -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Example { pub(crate) text: String, pub(crate) suggestion: Option, @@ -58,7 +58,7 @@ impl Example { } } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct PosReplacer { pub(crate) matcher: PosMatcher, } @@ -98,7 +98,7 @@ impl PosReplacer { } } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Match { pub(crate) id: GraphId, pub(crate) conversion: Conversion, @@ -131,14 +131,14 @@ impl Match { } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum SynthesizerPart { Text(String), // Regex with the `fancy_regex` backend is large on the stack Match(Box), } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Synthesizer { pub(crate) use_titlecase_adjust: bool, pub(crate) parts: Vec, diff --git a/nlprule/src/rule/mod.rs b/nlprule/src/rule/mod.rs index 42410c3..799d2f5 100644 --- a/nlprule/src/rule/mod.rs +++ b/nlprule/src/rule/mod.rs @@ -31,7 +31,7 @@ use self::{ /// A *Unification* makes an otherwise matching pattern invalid if no combination of its filters /// matches all tokens marked with "unify". /// Can also be negated. -#[derive(Serialize, Deserialize, Debug)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub(crate) struct Unification { pub(crate) mask: Vec>, pub(crate) filters: Vec>, @@ -375,7 +375,7 @@ impl<'a, 't> Iterator for Suggestions<'a, 't> { /// He dosn't know about it. 
/// /// ``` -#[derive(Serialize, Deserialize, Debug)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Rule { pub(crate) id: Index, pub(crate) engine: Engine, diff --git a/nlprule/src/rules.rs b/nlprule/src/rules.rs index 424ad48..7022217 100644 --- a/nlprule/src/rules.rs +++ b/nlprule/src/rules.rs @@ -2,14 +2,17 @@ use crate::{rule::id::Selector, tokenizer::Tokenizer}; use crate::{rule::Rule, Error}; -use crate::{spellcheck::SpellcheckOptions, utils::parallelism::MaybeParallelRefIterator}; -use crate::{spellcheck::Spellchecker, types::*}; +use crate::{ + spellcheck::SpellcheckOptions, spellcheck::Spellchecker, types::*, + utils::parallelism::MaybeParallelRefIterator, +}; use fs_err::File; use serde::{Deserialize, Serialize}; use std::{ - io::{BufReader, Read}, + io::{BufReader, Read, Write}, ops::{Deref, DerefMut}, path::Path, + sync::Arc, }; /// Options for a rule set. @@ -19,6 +22,7 @@ pub struct RulesOptions { spellcheck_options: SpellcheckOptions, } +/// TODO pub struct RulesOptionsGuard<'a> { rules: &'a mut Rules, } @@ -66,40 +70,84 @@ impl Default for RulesLangOptions { } } -/// A set of grammatical error correction rules. #[derive(Serialize, Deserialize, Default)] +struct RulesFields { + pub(crate) rules: Vec, +} + +impl From for RulesFields { + fn from(rules: Rules) -> Self { + RulesFields { rules: rules.rules } + } +} + +/// A set of grammatical error correction rules. +#[derive(Clone, Default, Serialize, Deserialize)] pub struct Rules { pub(crate) rules: Vec, pub(crate) options: RulesOptions, pub(crate) spellchecker: Option, + pub(crate) tokenizer: Arc, } impl Rules { fn ingest_options(&mut self) { if self.options.spellcheck && self.spellchecker.is_none() { self.spellchecker = Some(Spellchecker::new( - &self.tagger, + &self.tokenizer.tagger(), self.options.spellcheck_options.clone(), )); } } - /// Creates a new rule set from a path to a binary. - /// - /// # Errors - /// - If the file can not be opened. - /// - If the file content can not be deserialized to a rules set. - pub fn new>(p: P) -> Result { - Rules::new_with_options(p, RulesOptions::default()) + /// TODO + pub fn to_writer(&self, writer: &mut W) -> Result<(), Error> { + // TODO: the .clone() here could be avoided + let fields: RulesFields = self.clone().into(); + writer.write_all(&bincode::serialize(&fields)?)?; + Ok(()) + } + + /// TODO + pub fn from_reader_with_options( + reader: R, + tokenizer: Arc, + options: RulesOptions, + ) -> Result { + let fields: RulesFields = bincode::deserialize_from(reader)?; + let mut rules = Rules { + rules: fields.rules, + options, + spellchecker: None, + tokenizer, + }; + rules.ingest_options(); + Ok(rules) } /// Creates a new rule set with options. See [new][Rules::new]. - pub fn new_with_options>(p: P, options: RulesOptions) -> Result { + pub fn new_with_options>( + p: P, + tokenizer: Arc, + options: RulesOptions, + ) -> Result { let reader = BufReader::new(File::open(p.as_ref())?); - let mut rules: Rules = bincode::deserialize_from(reader)?; - rules.options = options; - Ok(rules) + Self::from_reader_with_options(reader, tokenizer, options) + } + + /// Creates a new rules set from a reader. + pub fn from_reader(reader: R, tokenizer: Arc) -> Result { + Self::from_reader_with_options(reader, tokenizer, RulesOptions::default()) + } + + /// Creates a new rule set from a path to a binary. + /// + /// # Errors + /// - If the file can not be opened. + /// - If the file content can not be deserialized to a rules set. 
+ pub fn new>(p: P, tokenizer: Arc) -> Result { + Self::new_with_options(p, tokenizer, RulesOptions::default()) } /// Gets the options of this rule set. @@ -112,11 +160,6 @@ impl Rules { &mut self.options } - /// Creates a new rules set from a reader. - pub fn from_reader(reader: R) -> Result { - Ok(bincode::deserialize_from(reader)?) - } - /// All rules ordered by priority. pub fn rules(&self) -> &[Rule] { &self.rules @@ -144,7 +187,7 @@ impl Rules { } /// Compute the suggestions for the given tokens by checking all rules. - pub fn apply(&self, tokens: &[Token], tokenizer: &Tokenizer) -> Vec { + pub fn apply(&self, tokens: &[Token]) -> Vec { if tokens.is_empty() { return Vec::new(); } @@ -157,7 +200,7 @@ impl Rules { .map(|(i, rule)| { let mut output = Vec::new(); - for suggestion in rule.apply(tokens, tokenizer) { + for suggestion in rule.apply(tokens, self.tokenizer.as_ref()) { output.push((i, suggestion)); } @@ -186,7 +229,7 @@ impl Rules { } /// Compute the suggestions for a text by checking all rules. - pub fn suggest(&self, text: &str, tokenizer: &Tokenizer) -> Vec { + pub fn suggest(&self, text: &str) -> Vec { if text.is_empty() { return Vec::new(); } @@ -195,19 +238,15 @@ impl Rules { let mut char_offset = 0; // get suggestions sentence by sentence - for tokens in tokenizer.pipe(text) { + for tokens in self.tokenizer.pipe(text) { if tokens.is_empty() { continue; } - suggestions.extend( - self.apply(&tokens, tokenizer) - .into_iter() - .map(|mut suggestion| { - suggestion.rshift(char_offset); - suggestion - }), - ); + suggestions.extend(self.apply(&tokens).into_iter().map(|mut suggestion| { + suggestion.rshift(char_offset); + suggestion + })); char_offset += tokens[0].sentence.chars().count(); } @@ -216,8 +255,8 @@ impl Rules { } /// Correct a text by first tokenizing, then finding all suggestions and choosing the first replacement of each suggestion. - pub fn correct(&self, text: &str, tokenizer: &Tokenizer) -> String { - let suggestions = self.suggest(text, tokenizer); + pub fn correct(&self, text: &str) -> String { + let suggestions = self.suggest(text); apply_suggestions(text, &suggestions) } } diff --git a/nlprule/src/tokenizer.rs b/nlprule/src/tokenizer.rs index 1887636..6fdcfc9 100644 --- a/nlprule/src/tokenizer.rs +++ b/nlprule/src/tokenizer.rs @@ -13,7 +13,7 @@ use crate::{ use fs_err::File; use serde::{Deserialize, Serialize}; use std::{ - io::{BufReader, Read}, + io::{BufReader, Read, Write}, path::Path, sync::Arc, }; @@ -128,6 +128,11 @@ impl Tokenizer { Ok(bincode::deserialize_from(reader)?) } + /// TODO + pub fn to_writer(&self, writer: &mut W) -> Result<(), Error> { + Ok(bincode::serialize_into(writer, &self)?) + } + /// Gets all disambigation rules in the order they are applied. pub fn rules(&self) -> &[DisambiguationRule] { &self.rules diff --git a/nlprule/tests/tests.rs b/nlprule/tests/tests.rs index 270a17d..8f0cbdc 100644 --- a/nlprule/tests/tests.rs +++ b/nlprule/tests/tests.rs @@ -1,4 +1,4 @@ -use std::convert::TryInto; +use std::{convert::TryInto, sync::Arc}; use lazy_static::lazy_static; use nlprule::{rule::id::Category, Rules, Tokenizer}; @@ -8,8 +8,8 @@ const TOKENIZER_PATH: &str = "../storage/en_tokenizer.bin"; const RULES_PATH: &str = "../storage/en_rules.bin"; lazy_static! 
{ - static ref TOKENIZER: Tokenizer = Tokenizer::new(TOKENIZER_PATH).unwrap(); - static ref RULES: Rules = Rules::new(RULES_PATH).unwrap(); + static ref TOKENIZER: Arc = Arc::new(Tokenizer::new(TOKENIZER_PATH).unwrap()); + static ref RULES: Rules = Rules::new(RULES_PATH, TOKENIZER.clone()).unwrap(); } #[test] @@ -25,12 +25,10 @@ fn can_tokenize_anything(text: String) -> bool { #[test] fn rules_can_be_disabled_enabled() { - let mut rules = Rules::new(RULES_PATH).unwrap(); + let mut rules = Rules::new(RULES_PATH, TOKENIZER.clone()).unwrap(); // enabled by default - assert!(!rules - .suggest("I can due his homework", &*TOKENIZER) - .is_empty()); + assert!(!rules.suggest("I can due his homework").is_empty()); rules .select_mut( @@ -41,17 +39,15 @@ fn rules_can_be_disabled_enabled() { .for_each(|x| x.disable()); // disabled now - assert!(rules - .suggest("I can due his homework", &*TOKENIZER) - .is_empty()); + assert!(rules.suggest("I can due his homework").is_empty()); // disabled by default - assert!(rules.suggest("I can not go", &*TOKENIZER).is_empty()); + assert!(rules.suggest("I can not go").is_empty()); rules .select_mut(&"typos/can_not".try_into().unwrap()) .for_each(|x| x.enable()); // enabled now - assert!(!rules.suggest("I can not go", &*TOKENIZER).is_empty()); + assert!(!rules.suggest("I can not go").is_empty()); } diff --git a/python/src/lib.rs b/python/src/lib.rs index 54084cb..13eafb0 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -309,11 +309,11 @@ impl From for PySuggestion { #[text_signature = "(path, sentence_splitter=None)"] #[derive(Default)] pub struct PyTokenizer { - tokenizer: Tokenizer, + tokenizer: Arc, } impl PyTokenizer { - fn tokenizer(&self) -> &Tokenizer { + fn tokenizer(&self) -> &Arc { &self.tokenizer } } @@ -327,7 +327,9 @@ impl PyTokenizer { let tokenizer: Tokenizer = bincode::deserialize_from(bytes) .map_err(|x| PyValueError::new_err(format!("{}", x)))?; - Ok(PyTokenizer { tokenizer }) + Ok(PyTokenizer { + tokenizer: Arc::new(tokenizer), + }) } #[new] @@ -339,7 +341,9 @@ impl PyTokenizer { Tokenizer::default() }; - Ok(PyTokenizer { tokenizer }) + Ok(PyTokenizer { + tokenizer: Arc::new(tokenizer), + }) } /// Get the tagger dictionary of this tokenizer. 
@@ -402,8 +406,8 @@ impl PyTokenizer { } } -impl From for PyTokenizer { - fn from(tokenizer: Tokenizer) -> Self { +impl From> for PyTokenizer { + fn from(tokenizer: Arc) -> Self { PyTokenizer { tokenizer } } } @@ -552,41 +556,39 @@ impl PyRule { #[text_signature = "(path, tokenizer, sentence_splitter=None)"] struct PyRules { rules: Arc>, - tokenizer: Py, } #[pymethods] impl PyRules { #[text_signature = "(code, tokenizer, sentence_splitter=None)"] #[staticmethod] - fn load(lang_code: &str, tokenizer: Py) -> PyResult { + fn load(lang_code: &str, tokenizer: &PyTokenizer) -> PyResult { let bytes = get_resource(lang_code, "rules.bin.gz")?; - let rules: Rules = bincode::deserialize_from(bytes) + let rules = Rules::from_reader(bytes, tokenizer.tokenizer().clone()) .map_err(|x| PyValueError::new_err(format!("{}", x)))?; Ok(PyRules { rules: Arc::from(RwLock::from(rules)), - tokenizer, }) } #[new] - fn new(py: Python, path: Option<&str>, tokenizer: Option>) -> PyResult { + fn new(path: Option<&str>, tokenizer: Option<&PyTokenizer>) -> PyResult { + let tokenizer = if let Some(tokenizer) = tokenizer { + tokenizer.tokenizer().clone() + } else { + PyTokenizer::default().tokenizer().clone() + }; + let rules = if let Some(path) = path { - Rules::new(path) + Rules::new(path, tokenizer) .map_err(|x| PyValueError::new_err(format!("error creating Rules: {}", x)))? } else { Rules::default() }; - let tokenizer = if let Some(tokenizer) = tokenizer { - tokenizer - } else { - Py::new(py, PyTokenizer::default())? - }; Ok(PyRules { rules: Arc::from(RwLock::from(rules)), - tokenizer, }) } @@ -628,12 +630,9 @@ impl PyRules { #[text_signature = "(sentence_or_sentences)"] fn suggest(&self, py: Python, sentence_or_sentences: PyObject) -> PyResult { text_guard(py, sentence_or_sentences, |sentence| { - let tokenizer = self.tokenizer.borrow(py); - let tokenizer = tokenizer.tokenizer(); - self.rules .read() - .suggest(&sentence, &tokenizer) + .suggest(&sentence) .into_iter() .map(|x| PyCell::new(py, PySuggestion::from(x))) .collect::>>() @@ -651,10 +650,7 @@ impl PyRules { #[text_signature = "(text_or_texts)"] fn correct(&self, py: Python, text_or_texts: PyObject) -> PyResult { text_guard(py, text_or_texts, |text| { - let tokenizer = self.tokenizer.borrow(py); - let tokenizer = tokenizer.tokenizer(); - - Ok(self.rules.read().correct(&text, tokenizer)) + Ok(self.rules.read().correct(&text)) }) } @@ -691,13 +687,11 @@ impl PyRules { pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { match state.extract::<&PyBytes>(py) { Ok(s) => { - let state: (Rules, Tokenizer) = - bincode::deserialize(s.as_bytes()).map_err(|_| { - PyValueError::new_err("deserializing state with `bincode` failed") - })?; + let rules: Rules = bincode::deserialize(s.as_bytes()).map_err(|_| { + PyValueError::new_err("deserializing state with `bincode` failed") + })?; // a roundtrip through pickle can not preserve references so we need to create a new Arc> - self.rules = Arc::from(RwLock::from(state.0)); - self.tokenizer = Py::new(py, PyTokenizer::from(state.1))?; + self.rules = Arc::new(RwLock::new(rules)); Ok(()) } Err(e) => Err(e), @@ -705,13 +699,10 @@ impl PyRules { } pub fn __getstate__(&self, py: Python) -> PyResult { - let tokenizer = self.tokenizer.borrow(py); // rwlock is serialized the same way as the inner type - let state = (&self.rules, tokenizer.tokenizer()); - Ok(PyBytes::new( py, - &bincode::serialize(&state) + &bincode::serialize(&self.rules) .map_err(|_| PyValueError::new_err("serializing state with `bincode` 
failed"))?, ) .to_object(py)) From 026ffb7f90759d96e0c9217d7abd0aaa74ba11ce Mon Sep 17 00:00:00 2001 From: Benjamin Minixhofer Date: Tue, 2 Mar 2021 16:33:59 +0100 Subject: [PATCH 04/16] treat spellchecker language variants properly --- build/README.md | 4 -- build/make_build_dir.py | 41 ++++++++++++--------- build/src/lib.rs | 9 +++-- nlprule/configs/de/tagger.json | 5 +++ nlprule/configs/en/tagger.json | 8 ++++ nlprule/configs/es/tagger.json | 3 +- nlprule/src/compile/impls.rs | 23 ++++++++---- nlprule/src/compile/mod.rs | 67 ++++++++++++++++++++-------------- nlprule/src/compile/utils.rs | 9 ++++- nlprule/src/tokenizer/tag.rs | 33 ++++++++--------- nlprule/src/types.rs | 33 +++++++++++------ 11 files changed, 142 insertions(+), 93 deletions(-) diff --git a/build/README.md b/build/README.md index 98d48a1..c826acc 100644 --- a/build/README.md +++ b/build/README.md @@ -79,8 +79,6 @@ python build/make_build_dir.py \ --chunker_token_model=$HOME/Downloads/nlprule/en-token.bin \ --chunker_pos_model=$HOME/Downloads/nlprule/en-pos-maxent.bin \ --chunker_chunk_model=$HOME/Downloads/nlprule/en-chunker.bin \ - --spell_dict_path=$LT_PATH/org/languagetool/resource/en/hunspell/en_GB.dict \ - --spell_info_path=$LT_PATH/org/languagetool/resource/en/hunspell/en_GB.info \ --out_dir=data/en ``` @@ -94,8 +92,6 @@ python build/make_build_dir.py \ --lang_code=de \ --tag_dict_path=$HOME/Downloads/nlprule/german-pos-dict/src/main/resources/org/languagetool/resource/de/german.dict \ --tag_info_path=$HOME/Downloads/nlprule/german-pos-dict/src/main/resources/org/languagetool/resource/de/german.info \ - --spell_dict_path=$LT_PATH/org/languagetool/resource/de/hunspell/de_DE.dict \ - --spell_info_path=$LT_PATH/org/languagetool/resource/de/hunspell/de_DE.info \ --out_dir=data/de ``` diff --git a/build/make_build_dir.py b/build/make_build_dir.py index 3b287f2..0cd357c 100644 --- a/build/make_build_dir.py +++ b/build/make_build_dir.py @@ -6,6 +6,7 @@ from zipfile import ZipFile import lxml.etree as ET import wordfreq +from glob import glob from chardet.universaldetector import UniversalDetector from chunker import write_chunker # type: ignore @@ -83,7 +84,7 @@ def dump_dict(out_path, lt_dir, tag_dict_path, tag_info_path): dump_bytes = open(out_path, "rb").read() with open(out_path, "w") as f: - f.write(dump_bytes.decode(result["encoding"])) + f.write(dump_bytes.decode(result["encoding"] or "utf-8")) if __name__ == "__main__": @@ -119,16 +120,6 @@ def dump_dict(out_path, lt_dir, tag_dict_path, tag_info_path): type=lambda p: Path(p).absolute(), help="Path to the accompanying tagger dictionary .info file.", ) - parser.add_argument( - "--spell_dict_path", - type=lambda p: Path(p).absolute(), - help="Path to a spell dictionary .dict file.", - ) - parser.add_argument( - "--spell_info_path", - type=lambda p: Path(p).absolute(), - help="Path to the accompanying spell dictionary .info file.", - ) parser.add_argument( "--chunker_token_model", default=None, @@ -168,13 +159,27 @@ def dump_dict(out_path, lt_dir, tag_dict_path, tag_info_path): args.tag_info_path, ) - # spell dictionary - dump_dict( - args.out_dir / "spell.dump", - args.lt_dir, - args.spell_dict_path, - args.spell_info_path, - ) + # spell dictionaries + (args.out_dir / "spell").mkdir() + for dic in glob( + str( + args.lt_dir + / "org" + / "languagetool" + / "resource" + / args.lang_code + / "hunspell" + / "*.dict" + ) + ): + dic = Path(dic) + info = Path(dic).with_suffix(".info") + + variant_name = dic.stem + + dump_dict( + args.out_dir / "spell" / 
f"{variant_name}.dump", args.lt_dir, dic, info, + ) if ( args.chunker_token_model is not None diff --git a/build/src/lib.rs b/build/src/lib.rs index fbc18c9..281b3b9 100644 --- a/build/src/lib.rs +++ b/build/src/lib.rs @@ -5,7 +5,7 @@ use flate2::bufread::GzDecoder; use fs::File; use fs_err as fs; use nlprule::{compile, rules_filename, tokenizer_filename}; -use std::fs::Permissions; +use std::{fs::Permissions, sync::Arc}; use std::{ io::{self, BufReader, BufWriter, Cursor, Read}, path::{Path, PathBuf}, @@ -469,10 +469,11 @@ impl BinaryBuilder { let tokenizer_out = self.out_dir.join(tokenizer_filename(lang_code)); let rules_out = self.out_dir.join(rules_filename(lang_code)); - nlprule::Rules::new(rules_out) - .map_err(|e| Error::ValidationFailed(lang_code.to_owned(), Binary::Rules, e))?; - nlprule::Tokenizer::new(tokenizer_out) + let tokenizer = nlprule::Tokenizer::new(tokenizer_out) .map_err(|e| Error::ValidationFailed(lang_code.to_owned(), Binary::Tokenizer, e))?; + + nlprule::Rules::new(rules_out, Arc::new(tokenizer)) + .map_err(|e| Error::ValidationFailed(lang_code.to_owned(), Binary::Rules, e))?; } Ok(()) diff --git a/nlprule/configs/de/tagger.json b/nlprule/configs/de/tagger.json index 9e89b87..f18359f 100644 --- a/nlprule/configs/de/tagger.json +++ b/nlprule/configs/de/tagger.json @@ -4,5 +4,10 @@ "extra_tags": [ "PKT", "PRO:IND:DAT:SIN:NEU" + ], + "variants": [ + "de_AT", + "de_DE", + "de_CH" ] } \ No newline at end of file diff --git a/nlprule/configs/en/tagger.json b/nlprule/configs/en/tagger.json index 4b2ee6d..939f316 100644 --- a/nlprule/configs/en/tagger.json +++ b/nlprule/configs/en/tagger.json @@ -6,5 +6,13 @@ "ORD", "SYM", "RB_SENT" + ], + "variants": [ + "en_GB", + "en_US", + "en_ZA", + "en_NZ", + "en_CA", + "en_AU" ] } \ No newline at end of file diff --git a/nlprule/configs/es/tagger.json b/nlprule/configs/es/tagger.json index 3ab7978..4ec1458 100644 --- a/nlprule/configs/es/tagger.json +++ b/nlprule/configs/es/tagger.json @@ -60,5 +60,6 @@ "NCCN00", "LOC_CC", "LOC_I" - ] + ], + "variants": [] } \ No newline at end of file diff --git a/nlprule/src/compile/impls.rs b/nlprule/src/compile/impls.rs index 92b8791..17f2a80 100644 --- a/nlprule/src/compile/impls.rs +++ b/nlprule/src/compile/impls.rs @@ -93,7 +93,7 @@ impl Tagger { pub(in crate::compile) fn from_dumps, S2: AsRef>( paths: &[S1], remove_paths: &[S2], - freq_words: HashMap, + spell_words: HashMap, lang_options: TaggerLangOptions, ) -> std::io::Result { let mut tags = DefaultHashMap::default(); @@ -115,17 +115,21 @@ impl Tagger { let punct = "!\"#$%&\\'()*+,-./:;<=>?@[\\]^_`{|}~"; for i in 0..punct.len() { - word_store.insert(&punct[i..(i + 1)], 0); + word_store.insert(&punct[i..(i + 1)], (0, 0)); } for (word, inflection, tag) in lines.iter() { - word_store.insert(word, 0); - word_store.insert(inflection, 0); + word_store.insert(word, (0, 0)); + word_store.insert(inflection, (0, 0)); tag_store.insert(tag); } - // extend with freq words at the end to make sure we overwrite words which existed but have 0 frequency - word_store.extend(freq_words.iter().map(|(word, freq)| (word.as_str(), *freq))); + // extend with spelling words at the end to make sure we overwrite words which existed but have 0 frequency + word_store.extend( + spell_words + .iter() + .map(|(word, freq)| (word.as_str(), *freq)), + ); // word store ids should be consistent across runs let mut word_store: Vec<_> = word_store.into_iter().collect(); @@ -138,7 +142,12 @@ impl Tagger { let word_store: BiMap<_, _> = word_store .iter() .enumerate() - .map(|(i, 
(word, freq))| ((*word).to_owned(), WordIdInt::new(i as u32, *freq))) + .map(|(i, (word, (freq, variants)))| { + ( + (*word).to_owned(), + WordIdInt::new(i as u32, *freq, *variants), + ) + }) .collect(); let tag_store: BiMap<_, _> = tag_store .iter() diff --git a/nlprule/src/compile/mod.rs b/nlprule/src/compile/mod.rs index f2897de..3f35e7b 100644 --- a/nlprule/src/compile/mod.rs +++ b/nlprule/src/compile/mod.rs @@ -37,7 +37,7 @@ struct BuildFilePaths { multiword_tag_path: PathBuf, regex_cache_path: PathBuf, srx_path: PathBuf, - spell_path: PathBuf, + spell_dir_path: PathBuf, } impl BuildFilePaths { @@ -54,7 +54,7 @@ impl BuildFilePaths { multiword_tag_path: p.join("tags/multiwords.txt"), regex_cache_path: p.join("regex_cache.bin"), srx_path: p.join("segment.srx"), - spell_path: p.join("spell.dump"), + spell_dir_path: p.join("spell"), } } } @@ -88,6 +88,40 @@ pub enum Error { Other(#[from] Box), } +fn parse_spell_dumps>( + spell_dir_path: P, + variants: &[String], +) -> Result, Error> { + let mut words = DefaultHashMap::new(); + + for (i, variant) in variants.iter().enumerate() { + let spell_path = spell_dir_path.as_ref().join(variant).with_extension("dump"); + info!("Reading spelling dictionary from {}.", spell_path.display()); + + let reader = BufReader::new(File::open(spell_path)?); + for line in reader.lines() { + match line? + .trim() + .split_whitespace() + .collect::>() + .as_slice() + { + [freq, word] => { + // frequency is denoted as letters from A to Z in LanguageTool where A is the least frequent. + // we start from 1 because 0 is reserved for words we do not know the frequency of + let freq = 1 + freq.chars().next().expect("freq must have one char - would not have been yielded by split_whitespace otherwise.") as usize - 'A' as usize; + assert!(freq < u8::MAX as usize); + assert!(i < 8); + words.insert(word.to_string(), (1u8 << i, freq as u8)); + } + _ => continue, + } + } + } + + Ok(words) +} + /// Compiles the binaries from a build directory. pub fn compile( build_dir: impl AsRef, @@ -98,31 +132,6 @@ pub fn compile( let lang_code = fs::read_to_string(paths.lang_code_path)?; - info!( - "Reading spelling words with frequency from {}.", - paths.spell_path.display() - ); - let mut freq_words = DefaultHashMap::new(); - let reader = BufReader::new(File::open(paths.spell_path)?); - - for line in reader.lines() { - match line? - .trim() - .split_whitespace() - .collect::>() - .as_slice() - { - [freq, word] => { - // frequency is denoted as letters from A to Z in LanguageTool where A is the least frequent. 
- // we start from 1 because 0 is reserved for words we do not know the frequency of - let freq = 1 + freq.chars().next().expect("freq must have one char - would not have been yielded by split_whitespace otherwise.") as usize - 'A' as usize; - assert!(freq < u8::MAX as usize); - freq_words.insert(word.to_string(), freq as u8); - } - _ => continue, - } - } - let tokenizer_lang_options = utils::tokenizer_lang_options(&lang_code).ok_or_else(|| { Error::LanguageOptionsDoNotExist { lang_code: lang_code.clone(), @@ -139,11 +148,13 @@ pub fn compile( lang_code: lang_code.clone(), })?; + let words = parse_spell_dumps(&paths.spell_dir_path, &tagger_lang_options.variants)?; + info!("Creating tagger."); let tagger = Tagger::from_dumps( &paths.tag_paths, &paths.tag_remove_paths, - freq_words, + words, tagger_lang_options, )?; diff --git a/nlprule/src/compile/utils.rs b/nlprule/src/compile/utils.rs index 73b5322..f834730 100644 --- a/nlprule/src/compile/utils.rs +++ b/nlprule/src/compile/utils.rs @@ -47,7 +47,14 @@ pub(crate) fn rules_lang_options(lang_code: &str) -> Option { /// Gets the tagger language options for the language code pub(crate) fn tagger_lang_options(lang_code: &str) -> Option { - TAGGER_LANG_OPTIONS.get(lang_code).cloned() + TAGGER_LANG_OPTIONS + .get(lang_code) + .cloned() + .map(|mut options| { + // lang_code on the tagger is special; populated automatically + options.lang_code = lang_code.to_owned(); + options + }) } pub(crate) use regex::from_java_regex; diff --git a/nlprule/src/tokenizer/tag.rs b/nlprule/src/tokenizer/tag.rs index 87b6c32..4102be4 100644 --- a/nlprule/src/tokenizer/tag.rs +++ b/nlprule/src/tokenizer/tag.rs @@ -9,7 +9,7 @@ use log::error; use serde::{Deserialize, Serialize}; use std::{borrow::Cow, iter::once}; -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Default, Serialize, Deserialize)] pub(crate) struct TaggerLangOptions { /// Whether to use a heuristic to split potential compound words. pub use_compound_split_heuristic: bool, @@ -17,16 +17,11 @@ pub(crate) struct TaggerLangOptions { pub always_add_lower_tags: bool, /// Used part-of-speech tags which are not in the tagger dictionary. pub extra_tags: Vec, -} - -impl Default for TaggerLangOptions { - fn default() -> Self { - TaggerLangOptions { - use_compound_split_heuristic: false, - always_add_lower_tags: false, - extra_tags: Vec::new(), - } - } + /// Variants of the language (e.g. "en_US", "en_GB") to consider for spellchecking. + pub variants: Vec, + /// The language code in two-letter format. Set automatically by the compile module. 
+ #[serde(skip)] + pub lang_code: String, } #[derive(Serialize, Deserialize)] @@ -52,15 +47,15 @@ impl From for TaggerFields { let key: Vec = word.as_bytes().iter().chain(once(&i)).copied().collect(); let pos_bytes = pos_id.0.to_be_bytes(); - let inflect_bytes = inflect_id.raw_value().to_be_bytes(); + let inflect_bytes = inflect_id.to_bytes(); let value = u64::from_be_bytes([ inflect_bytes[0], inflect_bytes[1], inflect_bytes[2], inflect_bytes[3], - 0, - 0, + inflect_bytes[4], + inflect_bytes[5], pos_bytes[0], pos_bytes[1], ]); @@ -74,7 +69,7 @@ impl From for TaggerFields { let mut word_store_items: Vec<_> = tagger .word_store .iter() - .map(|(key, value)| (key.clone(), value.raw_value() as u64)) + .map(|(key, value)| (key.clone(), value.to_u64())) .collect(); word_store_items.sort_by(|(a, _), (b, _)| a.cmp(b)); @@ -106,7 +101,7 @@ impl From for Tagger { .into_str_vec() .unwrap() .into_iter() - .map(|(key, value)| (key, WordIdInt::from_raw_value(value as u32))) + .map(|(key, value)| (key, WordIdInt::from_u64(value))) .collect(); let mut tags = DefaultHashMap::new(); @@ -120,12 +115,14 @@ impl From for Tagger { let word_id = *word_store.get_by_left(word).unwrap(); let value_bytes = value.to_be_bytes(); - let inflection_id = WordIdInt::from_raw_value(u32::from_be_bytes([ + let inflection_id = WordIdInt::from_bytes([ value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3], - ])); + value_bytes[4], + value_bytes[5], + ]); let pos_id = PosIdInt(u16::from_be_bytes([value_bytes[6], value_bytes[7]])); let group = groups.entry(inflection_id).or_insert_with(Vec::new); diff --git a/nlprule/src/types.rs b/nlprule/src/types.rs index db6dab9..7c94ca6 100644 --- a/nlprule/src/types.rs +++ b/nlprule/src/types.rs @@ -16,28 +16,37 @@ pub(crate) type DefaultHasher = hash_map::DefaultHasher; #[derive( Debug, Copy, Clone, Serialize, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd, Default, )] -#[serde(transparent)] -pub(crate) struct WordIdInt(u32); +pub(crate) struct WordIdInt(u32, u8, u8); impl WordIdInt { - pub fn new(index: u32, freq: u8) -> Self { - assert!(index < 2u32.pow(24)); - - let mut id = index << 8; - id |= freq as u32; - WordIdInt(id) + pub fn new(index: u32, freq: u8, variants: u8) -> Self { + WordIdInt(index, freq, variants) } pub fn freq(&self) -> u8 { (self.0 & 255) as u8 } - pub fn raw_value(&self) -> u32 { - self.0 + pub fn to_u64(&self) -> u64 { + let b = self.to_bytes(); + u64::from_be_bytes([b[0], b[1], b[2], b[3], b[4], b[5], self.1, self.2]) + } + + pub fn from_u64(value: u64) -> Self { + let b = value.to_be_bytes(); + Self::from_bytes([b[0], b[1], b[2], b[3], b[4], b[5]]) + } + + pub fn to_bytes(&self) -> [u8; 6] { + let b = self.0.to_be_bytes(); + [b[0], b[1], b[2], b[3], self.1, self.2] } - pub fn from_raw_value(id: u32) -> Self { - WordIdInt(id) + pub fn from_bytes(value: [u8; 6]) -> Self { + let index = u32::from_be_bytes([value[0], value[1], value[2], value[3]]); + let freq = value[4]; + let variants = value[5]; + WordIdInt(index, freq, variants) } } From a0be3c2e13e0a7a829cea93643a228b97bdfa839 Mon Sep 17 00:00:00 2001 From: Benjamin Minixhofer Date: Fri, 5 Mar 2021 13:49:24 +0100 Subject: [PATCH 05/16] switch to fst based spelling correction --- nlprule/Cargo.toml | 3 +- nlprule/build.rs | 1 + nlprule/configs/de/spellchecker.json | 7 + nlprule/configs/de/tagger.json | 5 - nlprule/configs/en/spellchecker.json | 9 + nlprule/configs/en/tagger.json | 8 - nlprule/configs/es/spellchecker.json | 3 + nlprule/configs/es/tagger.json | 3 +- nlprule/src/compile/impls.rs | 97 
++++++--- nlprule/src/compile/mod.rs | 60 ++---- nlprule/src/compile/utils.rs | 29 ++- nlprule/src/lib.rs | 2 +- nlprule/src/rules.rs | 69 ++---- nlprule/src/spellcheck/levenshtein.rs | 114 ++++++++++ nlprule/src/spellcheck/mod.rs | 297 +++++++++++++------------- nlprule/src/tokenizer/tag.rs | 25 +-- nlprule/src/types.rs | 34 +-- nlprule/tests/tests.rs | 21 +- 18 files changed, 449 insertions(+), 338 deletions(-) create mode 100644 nlprule/configs/de/spellchecker.json create mode 100644 nlprule/configs/en/spellchecker.json create mode 100644 nlprule/configs/es/spellchecker.json create mode 100644 nlprule/src/spellcheck/levenshtein.rs diff --git a/nlprule/Cargo.toml b/nlprule/Cargo.toml index f397430..6571798 100644 --- a/nlprule/Cargo.toml +++ b/nlprule/Cargo.toml @@ -30,8 +30,7 @@ half = { version = "1.7", features = ["serde"] } srx = { version = "^0.1.2", features = ["serde"] } lazycell = "1" cfg-if = "1" -triple_accel = "0.3" -appendlist = "1.4" +fnv = "1" rayon-cond = "0.1" rayon = "1.5" diff --git a/nlprule/build.rs b/nlprule/build.rs index 8eb2ad1..543bec2 100644 --- a/nlprule/build.rs +++ b/nlprule/build.rs @@ -20,6 +20,7 @@ fn main() { ("tokenizer.json", "tokenizer_configs.json"), ("rules.json", "rules_configs.json"), ("tagger.json", "tagger_configs.json"), + ("spellchecker.json", "spellchecker_configs.json"), ] { let mut config_map: HashMap = HashMap::new(); diff --git a/nlprule/configs/de/spellchecker.json b/nlprule/configs/de/spellchecker.json new file mode 100644 index 0000000..52f6a41 --- /dev/null +++ b/nlprule/configs/de/spellchecker.json @@ -0,0 +1,7 @@ +{ + "variants": [ + "de_AT", + "de_DE", + "de_CH" + ] +} \ No newline at end of file diff --git a/nlprule/configs/de/tagger.json b/nlprule/configs/de/tagger.json index f18359f..9e89b87 100644 --- a/nlprule/configs/de/tagger.json +++ b/nlprule/configs/de/tagger.json @@ -4,10 +4,5 @@ "extra_tags": [ "PKT", "PRO:IND:DAT:SIN:NEU" - ], - "variants": [ - "de_AT", - "de_DE", - "de_CH" ] } \ No newline at end of file diff --git a/nlprule/configs/en/spellchecker.json b/nlprule/configs/en/spellchecker.json new file mode 100644 index 0000000..ab2e94b --- /dev/null +++ b/nlprule/configs/en/spellchecker.json @@ -0,0 +1,9 @@ +{ + "variants": [ + "en_GB", + "en_US", + "en_ZA", + "en_CA", + "en_AU" + ] +} \ No newline at end of file diff --git a/nlprule/configs/en/tagger.json b/nlprule/configs/en/tagger.json index 939f316..4b2ee6d 100644 --- a/nlprule/configs/en/tagger.json +++ b/nlprule/configs/en/tagger.json @@ -6,13 +6,5 @@ "ORD", "SYM", "RB_SENT" - ], - "variants": [ - "en_GB", - "en_US", - "en_ZA", - "en_NZ", - "en_CA", - "en_AU" ] } \ No newline at end of file diff --git a/nlprule/configs/es/spellchecker.json b/nlprule/configs/es/spellchecker.json new file mode 100644 index 0000000..de221b3 --- /dev/null +++ b/nlprule/configs/es/spellchecker.json @@ -0,0 +1,3 @@ +{ + "variants": [] +} \ No newline at end of file diff --git a/nlprule/configs/es/tagger.json b/nlprule/configs/es/tagger.json index 4ec1458..3ab7978 100644 --- a/nlprule/configs/es/tagger.json +++ b/nlprule/configs/es/tagger.json @@ -60,6 +60,5 @@ "NCCN00", "LOC_CC", "LOC_I" - ], - "variants": [] + ] } \ No newline at end of file diff --git a/nlprule/src/compile/impls.rs b/nlprule/src/compile/impls.rs index 17f2a80..94f87b4 100644 --- a/nlprule/src/compile/impls.rs +++ b/nlprule/src/compile/impls.rs @@ -4,11 +4,12 @@ use indexmap::IndexMap; use log::warn; use serde::{Deserialize, Serialize}; use std::{ + cmp, collections::{HashMap, HashSet}, hash::{Hash, Hasher}, io::{self, 
BufRead, BufReader}, path::Path, - sync::Arc, + sync::{atomic::AtomicUsize, Arc}, }; use crate::{ @@ -22,6 +23,7 @@ use crate::{ DisambiguationRule, MatchGraph, Rule, }, rules::{Rules, RulesLangOptions, RulesOptions}, + spellcheck::{SpellInt, Spellchecker, SpellcheckerLangOptions}, tokenizer::{ chunk, multiword::{MultiwordTagger, MultiwordTaggerFields}, @@ -34,6 +36,58 @@ use crate::{ use super::{parse_structure::BuildInfo, Error}; +impl Spellchecker { + pub(in crate::compile) fn from_dumps>( + spell_dir_path: S, + lang_options: SpellcheckerLangOptions, + ) -> io::Result { + let mut words: HashMap = DefaultHashMap::new(); + let mut max_freq = 0; + + for (i, variant) in lang_options.variants.iter().enumerate() { + let spell_path = spell_dir_path.as_ref().join(variant).with_extension("dump"); + + let reader = BufReader::new(File::open(spell_path)?); + for line in reader.lines() { + match line? + .trim() + .split_whitespace() + .collect::>() + .as_slice() + { + [freq, word] => { + // frequency is denoted as letters from A to Z in LanguageTool where A is the least frequent. + let freq = freq.chars().next().expect("freq must have one char - would not have been yielded by split_whitespace otherwise.") as usize - 'A' as usize; + let value = words.entry(word.to_string()).or_default(); + + max_freq = cmp::max(max_freq, freq); + + value.update_freq(freq); + value.add_variant(i); + } + _ => continue, + } + } + } + let mut words: Vec<_> = words + .into_iter() + .map(|(key, value)| (key, value.as_u64())) + .collect(); + words.sort_by(|(a, _), (b, _)| a.cmp(b)); + + let map = + fst::Map::from_iter(words.into_iter()).expect("words are lexicographically sorted."); + + Ok(Spellchecker { + fst: map.into_fst().to_vec(), + max_freq, + lang_options, + used_variant: Arc::new(AtomicUsize::new(usize::MAX)), + used_fst: Arc::new(Vec::new().into()), + }) + } +} + impl Tagger { fn get_lines, S2: AsRef>( paths: &[S1], @@ -93,14 +147,14 @@ impl Tagger { pub(in crate::compile) fn from_dumps, S2: AsRef>( paths: &[S1], remove_paths: &[S2], - spell_words: HashMap, + common_words: &HashSet, lang_options: TaggerLangOptions, ) -> std::io::Result { let mut tags = DefaultHashMap::default(); let mut groups = DefaultHashMap::default(); let mut tag_store = HashSet::new(); - let mut word_store = HashMap::new(); + let mut word_store = HashSet::new(); // hardcoded special tags tag_store.insert(""); @@ -115,39 +169,29 @@ impl Tagger { let punct = "!\"#$%&\\'()*+,-./:;<=>?@[\\]^_`{|}~"; for i in 0..punct.len() { - word_store.insert(&punct[i..(i + 1)], (0, 0)); + word_store.insert(&punct[i..(i + 1)]); } + word_store.extend(common_words.iter().map(|x| x.as_str())); + for (word, inflection, tag) in lines.iter() { - word_store.insert(word, (0, 0)); - word_store.insert(inflection, (0, 0)); + word_store.insert(word); + word_store.insert(inflection); tag_store.insert(tag); } - // extend with spelling words at the end to make sure we overwrite words which existed but have 0 frequency - word_store.extend( - spell_words - .iter() - .map(|(word, freq)| (word.as_str(), *freq)), - ); - // word store ids should be consistent across runs - let mut word_store: Vec<_> = word_store.into_iter().collect(); - word_store.sort_unstable(); + let mut word_store: Vec<_> = word_store.iter().collect(); + word_store.sort(); - // tag store ids should be consistent across runs - let mut tag_store: Vec<_> = tag_store.into_iter().collect(); - tag_store.sort_unstable(); + // tag store ids should be consistent across runs + let mut tag_store: Vec<_> = 
tag_store.iter().collect(); + tag_store.sort(); let word_store: BiMap<_, _> = word_store .iter() .enumerate() - .map(|(i, (word, (freq, variants)))| { - ( - (*word).to_owned(), - WordIdInt::new(i as u32, *freq, *variants), - ) - }) + .map(|(i, x)| (x.to_string(), WordIdInt(i as u32))) .collect(); let tag_store: BiMap<_, _> = tag_store .iter() @@ -274,6 +318,7 @@ impl Rules { pub(in crate::compile) fn from_xml>( path: P, build_info: &mut BuildInfo, + spellchecker: Spellchecker, tokenizer: Arc, options: RulesLangOptions, ) -> Self { @@ -369,9 +414,9 @@ impl Rules { Rules { rules, - options: RulesOptions::default(), - spellchecker: None, + spellchecker, tokenizer, + options: RulesOptions::default(), } } } diff --git a/nlprule/src/compile/mod.rs b/nlprule/src/compile/mod.rs index 3f35e7b..365aa78 100644 --- a/nlprule/src/compile/mod.rs +++ b/nlprule/src/compile/mod.rs @@ -5,7 +5,7 @@ use fs_err as fs; use std::{ hash::{Hash, Hasher}, - io::{self, BufRead, BufReader, BufWriter}, + io::{self, BufReader, BufWriter}, num::ParseIntError, path::{Path, PathBuf}, str::FromStr, @@ -14,6 +14,7 @@ use std::{ use crate::{ rules::Rules, + spellcheck::Spellchecker, tokenizer::{chunk::Chunker, multiword::MultiwordTagger, tag::Tagger, Tokenizer}, types::*, }; @@ -37,6 +38,7 @@ struct BuildFilePaths { multiword_tag_path: PathBuf, regex_cache_path: PathBuf, srx_path: PathBuf, + common_words_path: PathBuf, spell_dir_path: PathBuf, } @@ -54,6 +56,7 @@ impl BuildFilePaths { multiword_tag_path: p.join("tags/multiwords.txt"), regex_cache_path: p.join("regex_cache.bin"), srx_path: p.join("segment.srx"), + common_words_path: p.join("common.txt"), spell_dir_path: p.join("spell"), } } @@ -88,40 +91,6 @@ pub enum Error { Other(#[from] Box), } -fn parse_spell_dumps>( - spell_dir_path: P, - variants: &[String], -) -> Result, Error> { - let mut words = DefaultHashMap::new(); - - for (i, variant) in variants.iter().enumerate() { - let spell_path = spell_dir_path.as_ref().join(variant).with_extension("dump"); - info!("Reading spelling dictionary from {}.", spell_path.display()); - - let reader = BufReader::new(File::open(spell_path)?); - for line in reader.lines() { - match line? - .trim() - .split_whitespace() - .collect::>() - .as_slice() - { - [freq, word] => { - // frequency is denoted as letters from A to Z in LanguageTool where A is the least frequent. - // we start from 1 because 0 is reserved for words we do not know the frequency of - let freq = 1 + freq.chars().next().expect("freq must have one char - would not have been yielded by split_whitespace otherwise.") as usize - 'A' as usize; - assert!(freq < u8::MAX as usize); - assert!(i < 8); - words.insert(word.to_string(), (1u8 << i, freq as u8)); - } - _ => continue, - } - } - } - - Ok(words) -} - /// Compiles the binaries from a build directory. pub fn compile( build_dir: impl AsRef, @@ -132,6 +101,15 @@ pub fn compile( let lang_code = fs::read_to_string(paths.lang_code_path)?; + info!( + "Reading common words from {}.", + paths.common_words_path.display() + ); + let common_words = fs::read_to_string(paths.common_words_path)? 
+ .lines() + .map(|x| x.to_string()) + .collect(); + let tokenizer_lang_options = utils::tokenizer_lang_options(&lang_code).ok_or_else(|| { Error::LanguageOptionsDoNotExist { lang_code: lang_code.clone(), @@ -148,13 +126,18 @@ pub fn compile( lang_code: lang_code.clone(), })?; - let words = parse_spell_dumps(&paths.spell_dir_path, &tagger_lang_options.variants)?; + let spellchecker_lang_options = + utils::spellchecker_lang_options(&lang_code).ok_or_else(|| { + Error::LanguageOptionsDoNotExist { + lang_code: lang_code.clone(), + } + })?; info!("Creating tagger."); let tagger = Tagger::from_dumps( &paths.tag_paths, &paths.tag_remove_paths, - words, + &common_words, tagger_lang_options, )?; @@ -206,6 +189,8 @@ pub fn compile( None }; + let spellchecker = Spellchecker::from_dumps(paths.spell_dir_path, spellchecker_lang_options)?; + info!("Creating tokenizer."); let tokenizer = Tokenizer::from_xml( &paths.disambiguation_path, @@ -221,6 +206,7 @@ pub fn compile( let rules = Rules::from_xml( &paths.grammar_path, &mut build_info, + spellchecker, Arc::new(tokenizer), rules_lang_options, ); diff --git a/nlprule/src/compile/utils.rs b/nlprule/src/compile/utils.rs index f834730..d4d5acf 100644 --- a/nlprule/src/compile/utils.rs +++ b/nlprule/src/compile/utils.rs @@ -1,4 +1,6 @@ -use crate::{rules::RulesLangOptions, tokenizer::TokenizerLangOptions}; +use crate::{ + rules::RulesLangOptions, spellcheck::SpellcheckerLangOptions, tokenizer::TokenizerLangOptions, +}; use crate::{tokenizer::tag::TaggerLangOptions, types::*}; use lazy_static::lazy_static; @@ -35,6 +37,17 @@ lazy_static! { }; } +lazy_static! { + static ref SPELLCHECKER_LANG_OPTIONS: DefaultHashMap = { + serde_json::from_slice(include_bytes!(concat!( + env!("OUT_DIR"), + "/", + "spellchecker_configs.json" + ))) + .expect("tagger configs must be valid JSON") + }; +} + /// Gets the tokenizer language options for the language code pub(crate) fn tokenizer_lang_options(lang_code: &str) -> Option { TOKENIZER_LANG_OPTIONS.get(lang_code).cloned() @@ -47,14 +60,12 @@ pub(crate) fn rules_lang_options(lang_code: &str) -> Option { /// Gets the tagger language options for the language code pub(crate) fn tagger_lang_options(lang_code: &str) -> Option { - TAGGER_LANG_OPTIONS - .get(lang_code) - .cloned() - .map(|mut options| { - // lang_code on the tagger is special; populated automatically - options.lang_code = lang_code.to_owned(); - options - }) + TAGGER_LANG_OPTIONS.get(lang_code).cloned() +} + +/// Gets the spellchecker language options for the language code +pub(crate) fn spellchecker_lang_options(lang_code: &str) -> Option { + SPELLCHECKER_LANG_OPTIONS.get(lang_code).cloned() } pub(crate) use regex::from_java_regex; diff --git a/nlprule/src/lib.rs b/nlprule/src/lib.rs index 79215a5..10b443d 100644 --- a/nlprule/src/lib.rs +++ b/nlprule/src/lib.rs @@ -63,7 +63,7 @@ pub mod compile; mod filter; pub mod rule; pub mod rules; -mod spellcheck; +pub mod spellcheck; pub mod tokenizer; pub mod types; pub(crate) mod utils; diff --git a/nlprule/src/rules.rs b/nlprule/src/rules.rs index 7022217..9c20000 100644 --- a/nlprule/src/rules.rs +++ b/nlprule/src/rules.rs @@ -3,48 +3,22 @@ use crate::{rule::id::Selector, tokenizer::Tokenizer}; use crate::{rule::Rule, Error}; use crate::{ - spellcheck::SpellcheckOptions, spellcheck::Spellchecker, types::*, + spellcheck::Spellchecker, spellcheck::SpellcheckerOptions, types::*, utils::parallelism::MaybeParallelRefIterator, }; use fs_err::File; use serde::{Deserialize, Serialize}; use std::{ io::{BufReader, Read, Write}, - 
ops::{Deref, DerefMut}, path::Path, sync::Arc, }; /// Options for a rule set. -#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] pub struct RulesOptions { - spellcheck: bool, - spellcheck_options: SpellcheckOptions, -} - -/// TODO -pub struct RulesOptionsGuard<'a> { - rules: &'a mut Rules, -} - -impl<'a> Deref for RulesOptionsGuard<'a> { - type Target = RulesOptions; - - fn deref(&self) -> &Self::Target { - &self.rules.options - } -} - -impl<'a> DerefMut for RulesOptionsGuard<'a> { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.rules.options - } -} - -impl<'a> Drop for RulesOptionsGuard<'a> { - fn drop(&mut self) { - self.rules.ingest_options() - } + /// TODO + pub spellchecker_options: SpellcheckerOptions, } /// Language-dependent options for a rule set. @@ -73,11 +47,15 @@ impl Default for RulesLangOptions { #[derive(Serialize, Deserialize, Default)] struct RulesFields { pub(crate) rules: Vec, + pub(crate) spellchecker: Spellchecker, } impl From for RulesFields { fn from(rules: Rules) -> Self { - RulesFields { rules: rules.rules } + RulesFields { + rules: rules.rules, + spellchecker: rules.spellchecker, + } } } @@ -85,21 +63,12 @@ impl From for RulesFields { #[derive(Clone, Default, Serialize, Deserialize)] pub struct Rules { pub(crate) rules: Vec, - pub(crate) options: RulesOptions, - pub(crate) spellchecker: Option, + pub(crate) spellchecker: Spellchecker, pub(crate) tokenizer: Arc, + pub(crate) options: RulesOptions, } impl Rules { - fn ingest_options(&mut self) { - if self.options.spellcheck && self.spellchecker.is_none() { - self.spellchecker = Some(Spellchecker::new( - &self.tokenizer.tagger(), - self.options.spellcheck_options.clone(), - )); - } - } - /// TODO pub fn to_writer(&self, writer: &mut W) -> Result<(), Error> { // TODO: the .clone() here could be avoided @@ -115,13 +84,12 @@ impl Rules { options: RulesOptions, ) -> Result { let fields: RulesFields = bincode::deserialize_from(reader)?; - let mut rules = Rules { + let rules = Rules { rules: fields.rules, options, - spellchecker: None, + spellchecker: fields.spellchecker, tokenizer, }; - rules.ingest_options(); Ok(rules) } @@ -155,8 +123,8 @@ impl Rules { &self.options } - /// Gets the options of this rule set (mutable). - pub fn options_mut(&mut self) -> &mut RulesOptions { + /// Sets the options of this rule set. 
+ pub fn mut_options(&mut self) -> &mut RulesOptions { &mut self.options } @@ -209,6 +177,13 @@ impl Rules { .flatten() .collect(); + output.extend( + self.spellchecker + .suggest(tokens, &self.options.spellchecker_options) + .into_iter() + .map(|x| (0, x)), + ); + output.sort_by(|(ia, a), (ib, b)| a.start.cmp(&b.start).then_with(|| ib.cmp(ia))); let mut mask = vec![false; tokens[0].sentence.chars().count()]; diff --git a/nlprule/src/spellcheck/levenshtein.rs b/nlprule/src/spellcheck/levenshtein.rs new file mode 100644 index 0000000..18f42ea --- /dev/null +++ b/nlprule/src/spellcheck/levenshtein.rs @@ -0,0 +1,114 @@ +use fnv::FnvHasher; +use fst::Automaton; +use std::{ + cmp::{self, min}, + hash::{Hash, Hasher}, +}; + +#[derive(Clone, Debug)] +pub struct LevenshteinState { + dist: usize, + n: usize, + row: Vec, + hash: u64, +} + +impl LevenshteinState { + pub fn dist(&self) -> usize { + self.dist + } +} + +#[derive(Debug, Clone)] +pub struct Levenshtein<'a> { + query: &'a [u8], + distance: usize, + prefix: usize, +} + +impl<'a> Levenshtein<'a> { + pub fn new(query: &'a str, distance: usize, prefix: usize) -> Self { + Levenshtein { + query: query.as_bytes(), + distance, + prefix, + } + } +} + +impl<'a> Automaton for Levenshtein<'a> { + type State = Option; + + fn start(&self) -> Self::State { + Some(LevenshteinState { + dist: self.query.len(), + n: 0, + row: (0..=self.query.len()).collect(), + hash: FnvHasher::default().finish(), + }) + } + + fn is_match(&self, state: &Self::State) -> bool { + state + .as_ref() + .map_or(false, |state| state.dist <= self.distance) + } + + fn can_match(&self, state: &Self::State) -> bool { + state.is_some() + } + + fn accept(&self, state: &Self::State, byte: u8) -> Self::State { + state.as_ref().and_then(|state| { + let mut next_hasher = FnvHasher::with_key(state.hash); + byte.hash(&mut next_hasher); + let next_hash = next_hasher.finish(); + + let prev_row = &state.row; + let mut next_row = state.row.to_vec(); + + next_row[0] = state.n + 1; + + for i in 1..next_row.len() { + let cost = if byte == self.query[i - 1] { + prev_row[i - 1] + } else { + min( + next_row[i - 1] + 1, + min(prev_row[i - 1] + 1, prev_row[i] + 1), + ) + }; + next_row[i] = cost; + } + + let distance = if state.n >= self.prefix { + self.distance + } else { + 1 + }; + + let lower_bound = state.n.saturating_sub(distance); + let upper_bound = cmp::min(state.n + distance, self.query.len()); + + let cutoff = if lower_bound > upper_bound { + 0 + } else { + *next_row[lower_bound..=upper_bound] + .iter() + .min() + .unwrap_or(&0) + }; + + if cutoff > distance { + return None; + } + + Some(LevenshteinState { + dist: next_row[self.query.len()], + n: state.n + 1, + row: next_row, + hash: next_hash, + }) + }) + } +} diff --git a/nlprule/src/spellcheck/mod.rs b/nlprule/src/spellcheck/mod.rs index 72a70c5..34b9c07 100644 --- a/nlprule/src/spellcheck/mod.rs +++ b/nlprule/src/spellcheck/mod.rs @@ -1,198 +1,189 @@ use std::{ - cmp, - hash::{Hash, Hasher}, + ops::Deref, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, RwLock, RwLockReadGuard, + }, }; -use appendlist::AppendList; -use lazy_static::lazy_static; +use fst::{IntoStreamer, Map, MapBuilder, Streamer}; use serde::{Deserialize, Serialize}; -use triple_accel::levenshtein; -use crate::{tokenizer::tag::Tagger, types::*}; +use crate::types::*; -fn hash(string: H) -> u64 { - let mut hasher = DefaultHasher::new(); - string.hash(&mut hasher); - hasher.finish() -} +mod levenshtein; + +#[derive(Debug, Clone, Default, Copy)] +pub(crate) struct 
SpellInt(u64); + +impl SpellInt { + pub fn as_u64(&self) -> u64 { + self.0 + } + + pub fn update_freq(&mut self, freq: usize) { + assert!(freq < u32::MAX as usize); + + // erase previous frequency + self.0 = self.0 & (u64::MAX - u32::MAX as u64); + // set new frequency + self.0 |= freq as u64; + } + + pub fn add_variant(&mut self, index: usize) { + assert!(index < 32); + self.0 |= 1 << (32 + index); + } -fn distance(a: usize, b: usize) -> usize { - if a > b { - a - b - } else { - b - a + pub fn contains_variant(&self, index: usize) -> bool { + (self.0 >> (32 + index)) & 1 == 1 + } + + pub fn freq(&self) -> usize { + (self.0 & u32::MAX as u64) as usize } } -#[derive(Debug, Clone, Default, PartialEq, PartialOrd, Ord, Eq)] -pub struct Candidate<'a> { - pub distance: usize, - pub freq: u8, - pub term: &'a str, +#[derive(Debug, Clone, Default, PartialEq, PartialOrd)] +struct Candidate { + pub score: f32, + pub term: String, } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SpellcheckOptions { - max_dictionary_distance: usize, - prefix_length: usize, +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +/// TODO +pub struct SpellcheckerOptions { + pub variant: Option, + pub max_distance: usize, + pub prefix: usize, + pub frequency_weight: f32, + pub n_suggestions: usize, } -impl Default for SpellcheckOptions { +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)] +pub(crate) struct SpellcheckerLangOptions { + /// Variants of the language (e. g. "en_US", "en_GB") to consider for spellchecking. + pub variants: Vec, +} + +impl Default for SpellcheckerOptions { fn default() -> Self { - SpellcheckOptions { - max_dictionary_distance: 2, - prefix_length: 7, + SpellcheckerOptions { + variant: None, + max_distance: 2, + prefix: 2, + frequency_weight: 2., + n_suggestions: 10, } } } #[derive(Debug, Clone, Default, Serialize, Deserialize)] -pub struct Spellchecker { - deletes: DefaultHashMap>, - max_length: usize, - options: SpellcheckOptions, -} - -lazy_static! 
{ - static ref EMPTY_HASH: u64 = { - let empty: &[u8] = &[]; - hash(empty) - }; +pub(crate) struct Spellchecker { + pub(crate) fst: Vec, + pub(crate) max_freq: usize, + pub(crate) lang_options: SpellcheckerLangOptions, + pub(crate) used_variant: Arc, + pub(crate) used_fst: Arc>>, } impl Spellchecker { - fn deletes_prefix(key: &[u8], options: &SpellcheckOptions) -> DefaultHashSet { - let mut out = DefaultHashSet::new(); - - if key.len() <= options.max_dictionary_distance { - out.insert(*EMPTY_HASH); - } - - Self::deletes_recurse(&key[..options.prefix_length], 0, &mut out, options); - out.insert(hash(key)); - - out - } - - fn deletes_recurse( - bytes: &[u8], - distance: usize, - out: &mut DefaultHashSet, - options: &SpellcheckOptions, - ) { - if bytes.len() > 1 { - for i in 0..bytes.len() { - let mut delete = bytes.to_vec(); - delete.remove(i); - - if distance + 1 < options.max_dictionary_distance { - Self::deletes_recurse(&delete, distance + 1, out, options); + fn update_used_fst( + &self, + options: &SpellcheckerOptions, + ) -> Option>> { + let variant = if let Some(variant) = options.variant.as_ref() { + variant.as_str() + } else { + return None; + }; + let variant_index = self + .lang_options + .variants + .iter() + .position(|x| x == variant)?; + + if self.used_variant.swap(variant_index, Ordering::Relaxed) != variant_index { + let mut used_fst_builder = MapBuilder::memory(); + + let fst = Map::new(&self.fst).expect("serialized fst must be valid."); + let mut stream = fst.into_stream(); + + while let Some((k, v)) = stream.next() { + if SpellInt(v).contains_variant(variant_index) { + used_fst_builder + .insert(k, v) + .expect("fst stream returns values in lexicographic order."); } - - out.insert(hash(delete)); } - } - } - pub fn new(tagger: &Tagger, options: SpellcheckOptions) -> Self { - let mut deletes = DefaultHashMap::new(); - let mut max_length = 0; + let mut guard = self.used_fst.write(); + let used_fst = guard.as_deref_mut().expect("lock must not be poisoned."); - for (word, id) in tagger.word_store() { - for delete in Self::deletes_prefix(word.as_bytes(), &options) { - deletes.entry(delete).or_insert_with(Vec::new).push(*id); - } - max_length = cmp::max(word.len(), max_length); + *used_fst = used_fst_builder + .into_inner() + .expect("subset of valid fst must be valid."); } - Spellchecker { - deletes, - max_length, - options, - } + Some(self.used_fst.read().expect("lock must not be poisoned")) } - pub fn lookup<'t>(&self, token: &'t Token, max_distance: usize) -> Option>> { - if token.word.text.id().is_some() { - return None; - } + fn lookup(&self, token: &Token, options: &SpellcheckerOptions) -> Option> { + let guard = self.update_used_fst(options)?; + let used_fst = Map::new(guard.deref()).expect("used fst must be valid."); - let word = token.word.text.0.as_ref(); - let input_length = word.len(); - - if input_length - self.options.max_dictionary_distance > self.max_length { - return Some(Vec::new()); + let text = token.word.text.as_ref(); + // no text => nothing to correct, only the case for special tokens (e.g. 
SENT_START) + if text.is_empty() { + return None; } - let mut candidates = Vec::new(); - - // deletes we've considered already - let mut known_deletes: DefaultHashSet = DefaultHashSet::new(); - // suggestions we've considered already - let mut known_suggestions: DefaultHashSet = DefaultHashSet::new(); - - let mut candidate_index = 0; - let deletes: AppendList> = AppendList::new(); - - let input_prefix_length = cmp::min(input_length, self.options.prefix_length); - deletes.push(word.as_bytes()[..input_prefix_length].to_vec()); + let query = levenshtein::Levenshtein::new(text, options.max_distance, 2); - while candidate_index < deletes.len() { - let candidate = deletes[candidate_index].as_slice(); - let candidate_length = candidate.len(); + let mut out = Vec::new(); - candidate_index += 1; - - let length_diff = input_prefix_length - candidate_length; - - if let Some(suggestions) = self.deletes.get(&hash(candidate)) { - for suggestion_id in suggestions { - let suggestion = token.tagger.str_for_word_id(suggestion_id); - let suggestion_length = suggestion.len(); - - if distance(suggestion_length, input_length) > max_distance - // suggestion must be for a different delete string, in same bin only because of hash collision - || suggestion_length < candidate_length - // in the same bin only because of hash collision, a valid suggestion is always longer than the delete - || suggestion_length == candidate_length - // we already added the suggestion - || known_suggestions.contains(suggestion_id) - { - continue; - } - - // SymSpell.cs covers some additional cases here where it is not necessary to compute the edit distance - // would have to be benchmarked if they are worth it considering `triple_accel` is presumably faster than - // the C# implementation of edit distance - - let distance = levenshtein(suggestion.as_bytes(), word.as_bytes()) as usize; - let freq = suggestion_id.freq(); - - candidates.push(Candidate { - term: suggestion, - distance, - freq, - }); - known_suggestions.insert(*suggestion_id); - } + let mut stream = used_fst.search_with_state(query).into_stream(); + while let Some((k, v, s)) = stream.next() { + let state = s.expect("matching levenshtein state is always `Some`."); + if state.dist() == 0 { + return None; } - if length_diff < max_distance && candidate_length <= self.options.prefix_length { - for i in 0..candidate.len() { - let mut delete = candidate.to_owned(); - delete.remove(i); + let id = SpellInt(v); - let delete_hash = hash(&delete); + let string = String::from_utf8(k.to_vec()).expect("fst keys must be valid utf-8."); + out.push(Candidate { + score: (options.max_distance - state.dist()) as f32 + + id.freq() as f32 / self.max_freq as f32 * options.frequency_weight, + term: string, + }) + } - if !known_deletes.contains(&delete_hash) { - deletes.push(delete); - known_deletes.insert(delete_hash); - } - } + // we want higher scores first + out.sort_by(|a, b| b.partial_cmp(a).expect("candidate scores are never NaN.")); + Some(out) + } + + pub fn suggest(&self, tokens: &[Token], options: &SpellcheckerOptions) -> Vec { + let mut suggestions = Vec::new(); + + for token in tokens { + if let Some(candidates) = self.lookup(token, options) { + // TODO: disallow empty / properly treat empty + suggestions.push(Suggestion { + source: "SPELLCHECK/SINGLE".into(), + message: "Possibly misspelled word.".into(), + start: token.char_span.0, + end: token.char_span.1, + replacements: candidates + .into_iter() + .map(|x| x.term.to_owned()) + .take(options.n_suggestions) + .collect(), + }) } } - 
candidates.sort(); - Some(candidates) + suggestions } } diff --git a/nlprule/src/tokenizer/tag.rs b/nlprule/src/tokenizer/tag.rs index 4102be4..4e8534a 100644 --- a/nlprule/src/tokenizer/tag.rs +++ b/nlprule/src/tokenizer/tag.rs @@ -17,11 +17,6 @@ pub(crate) struct TaggerLangOptions { pub always_add_lower_tags: bool, /// Used part-of-speech tags which are not in the tagger dictionary. pub extra_tags: Vec, - /// Variants of the language (e.g. "en_US", "en_GB") to consider for spellchecking. - pub variants: Vec, - /// The language code in two-letter format. Set automatically by the compile module. - #[serde(skip)] - pub lang_code: String, } #[derive(Serialize, Deserialize)] @@ -47,15 +42,15 @@ impl From for TaggerFields { let key: Vec = word.as_bytes().iter().chain(once(&i)).copied().collect(); let pos_bytes = pos_id.0.to_be_bytes(); - let inflect_bytes = inflect_id.to_bytes(); + let inflect_bytes = inflect_id.0.to_be_bytes(); let value = u64::from_be_bytes([ inflect_bytes[0], inflect_bytes[1], inflect_bytes[2], inflect_bytes[3], - inflect_bytes[4], - inflect_bytes[5], + 0, + 0, pos_bytes[0], pos_bytes[1], ]); @@ -69,7 +64,7 @@ impl From for TaggerFields { let mut word_store_items: Vec<_> = tagger .word_store .iter() - .map(|(key, value)| (key.clone(), value.to_u64())) + .map(|(key, value)| (key.clone(), value.0 as u64)) .collect(); word_store_items.sort_by(|(a, _), (b, _)| a.cmp(b)); @@ -101,7 +96,7 @@ impl From for Tagger { .into_str_vec() .unwrap() .into_iter() - .map(|(key, value)| (key, WordIdInt::from_u64(value))) + .map(|(key, value)| (key, WordIdInt(value as u32))) .collect(); let mut tags = DefaultHashMap::new(); @@ -115,14 +110,12 @@ impl From for Tagger { let word_id = *word_store.get_by_left(word).unwrap(); let value_bytes = value.to_be_bytes(); - let inflection_id = WordIdInt::from_bytes([ + let inflection_id = WordIdInt(u32::from_be_bytes([ value_bytes[0], value_bytes[1], value_bytes[2], value_bytes[3], - value_bytes[4], - value_bytes[5], - ]); + ])); let pos_id = PosIdInt(u16::from_be_bytes([value_bytes[6], value_bytes[7]])); let group = groups.entry(inflection_id).or_insert_with(Vec::new); @@ -201,6 +194,10 @@ impl Tagger { tags } + pub(crate) fn lang_options(&self) -> &TaggerLangOptions { + &self.lang_options + } + #[allow(dead_code)] // used by compile module pub(crate) fn tag_store(&self) -> &BiMap { &self.tag_store diff --git a/nlprule/src/types.rs b/nlprule/src/types.rs index 7c94ca6..9d11260 100644 --- a/nlprule/src/types.rs +++ b/nlprule/src/types.rs @@ -16,39 +16,7 @@ pub(crate) type DefaultHasher = hash_map::DefaultHasher; #[derive( Debug, Copy, Clone, Serialize, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd, Default, )] -pub(crate) struct WordIdInt(u32, u8, u8); - -impl WordIdInt { - pub fn new(index: u32, freq: u8, variants: u8) -> Self { - WordIdInt(index, freq, variants) - } - - pub fn freq(&self) -> u8 { - (self.0 & 255) as u8 - } - - pub fn to_u64(&self) -> u64 { - let b = self.to_bytes(); - u64::from_be_bytes([b[0], b[1], b[2], b[3], b[4], b[5], self.1, self.2]) - } - - pub fn from_u64(value: u64) -> Self { - let b = value.to_be_bytes(); - Self::from_bytes([b[0], b[1], b[2], b[3], b[4], b[5]]) - } - - pub fn to_bytes(&self) -> [u8; 6] { - let b = self.0.to_be_bytes(); - [b[0], b[1], b[2], b[3], self.1, self.2] - } - - pub fn from_bytes(value: [u8; 6]) -> Self { - let index = u32::from_be_bytes([value[0], value[1], value[2], value[3]]); - let freq = value[4]; - let variants = value[5]; - WordIdInt(index, freq, variants) - } -} +pub(crate) struct 
WordIdInt(pub u32); #[derive(Debug, Copy, Clone, Serialize, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)] #[serde(transparent)] diff --git a/nlprule/tests/tests.rs b/nlprule/tests/tests.rs index 8f0cbdc..051cdd8 100644 --- a/nlprule/tests/tests.rs +++ b/nlprule/tests/tests.rs @@ -1,7 +1,9 @@ use std::{convert::TryInto, sync::Arc}; use lazy_static::lazy_static; -use nlprule::{rule::id::Category, Rules, Tokenizer}; +use nlprule::{ + rule::id::Category, rules::RulesOptions, spellcheck::SpellcheckerOptions, Rules, Tokenizer, +}; use quickcheck_macros::quickcheck; const TOKENIZER_PATH: &str = "../storage/en_tokenizer.bin"; @@ -51,3 +53,20 @@ fn rules_can_be_disabled_enabled() { // enabled now assert!(!rules.suggest("I can not go").is_empty()); } + +#[test] +fn spellchecker_works() { + let rules = Rules::new_with_options( + RULES_PATH, + TOKENIZER.clone(), + RulesOptions { + spellchecker_options: SpellcheckerOptions { + variant: Some("en_GB".into()), + ..SpellcheckerOptions::default() + }, + }, + ) + .unwrap(); + + println!("{:#?}", rules.suggest("mom")); +} From f086fc5a8c0125db671695c7ac980349915d049d Mon Sep 17 00:00:00 2001 From: Benjamin Minixhofer Date: Sat, 6 Mar 2021 09:10:51 +0100 Subject: [PATCH 06/16] add tests for spellint, add hashset to check if word is valid --- bench/__init__.py | 7 ++- nlprule/src/compile/impls.rs | 2 +- nlprule/src/spellcheck/mod.rs | 103 +++++++++++++++++++++++++--------- nlprule/src/tokenizer/tag.rs | 4 -- python/Cargo.toml | 1 + python/src/lib.rs | 41 +++++++++++--- 6 files changed, 115 insertions(+), 43 deletions(-) diff --git a/bench/__init__.py b/bench/__init__.py index 92c668b..eebff50 100644 --- a/bench/__init__.py +++ b/bench/__init__.py @@ -34,7 +34,6 @@ def __init__(self, lang_code: str, ids: Set[str]): lt_code, remote_server="http://localhost:8081/" ) self.tool.disabled_rules = { - "MORFOLOGIK_RULE_EN_US", "GERMAN_SPELLER_RULE", "COMMA_PARENTHESIS_WHITESPACE", "DOUBLE_PUNCTUATION", @@ -116,7 +115,11 @@ def suggest(self, sentence: str) -> Set[Suggestion]: class NLPRule: def __init__(self, lang_code: str): self.tokenizer = nlprule.Tokenizer(f"storage/{lang_code}_tokenizer.bin") - self.rules = nlprule.Rules(f"storage/{lang_code}_rules.bin", self.tokenizer) + self.rules = nlprule.Rules( + f"storage/{lang_code}_rules.bin", + self.tokenizer, + spellchecker_options={"variant": "en_US", "max_distance": 0, "prefix": 0,}, + ) def suggest(self, sentence: str) -> Set[Suggestion]: suggestions = { diff --git a/nlprule/src/compile/impls.rs b/nlprule/src/compile/impls.rs index 94f87b4..abc3183 100644 --- a/nlprule/src/compile/impls.rs +++ b/nlprule/src/compile/impls.rs @@ -83,7 +83,7 @@ impl Spellchecker { max_freq, lang_options, used_variant: Arc::new(AtomicUsize::new(usize::MAX)), - used_fst: Arc::new(Vec::new().into()), + used: Arc::new(Default::default()), }) } } diff --git a/nlprule/src/spellcheck/mod.rs b/nlprule/src/spellcheck/mod.rs index 34b9c07..24a65af 100644 --- a/nlprule/src/spellcheck/mod.rs +++ b/nlprule/src/spellcheck/mod.rs @@ -1,5 +1,6 @@ use std::{ - ops::Deref, + collections::HashSet, + ops::{Deref, DerefMut}, sync::{ atomic::{AtomicUsize, Ordering}, Arc, RwLock, RwLockReadGuard, @@ -13,37 +14,76 @@ use crate::types::*; mod levenshtein; -#[derive(Debug, Clone, Default, Copy)] -pub(crate) struct SpellInt(u64); +mod spell_int { + #[derive(Debug, Clone, Default, Copy)] + pub(crate) struct SpellInt(pub(super) u64); -impl SpellInt { - pub fn as_u64(&self) -> u64 { - self.0 + type FreqType = u8; + + const fn freq_size() -> usize { + 
std::mem::size_of::() * 8 } - pub fn update_freq(&mut self, freq: usize) { - assert!(freq < u32::MAX as usize); + impl SpellInt { + pub fn as_u64(&self) -> u64 { + self.0 + } - // erase previous frequency - self.0 = self.0 & (u64::MAX - u32::MAX as u64); - // set new frequency - self.0 |= freq as u64; - } + pub fn update_freq(&mut self, freq: usize) { + assert!(freq < FreqType::MAX as usize); - pub fn add_variant(&mut self, index: usize) { - assert!(index < 32); - self.0 |= 1 << (32 + index); - } + // erase previous frequency + self.0 = self.0 & (u64::MAX - FreqType::MAX as u64); + // set new frequency + self.0 |= freq as u64; + } + + pub fn add_variant(&mut self, index: usize) { + assert!(index < 64 - freq_size()); + self.0 |= 1 << (freq_size() + index); + } + + pub fn contains_variant(&self, index: usize) -> bool { + (self.0 >> (freq_size() + index)) & 1 == 1 + } - pub fn contains_variant(&self, index: usize) -> bool { - (self.0 >> (32 + index)) & 1 == 1 + pub fn freq(&self) -> usize { + (self.0 & FreqType::MAX as u64) as usize + } } - pub fn freq(&self) -> usize { - (self.0 & u32::MAX as u64) as usize + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn can_encode_freq() { + let mut int = SpellInt::default(); + int.update_freq(100); + int.add_variant(1); + int.add_variant(10); + + assert!(int.freq() == 100); + } + + #[test] + fn can_encode_variants() { + let mut int = SpellInt::default(); + int.update_freq(100); + int.add_variant(1); + int.add_variant(10); + int.update_freq(10); + + assert!(int.contains_variant(1)); + assert!(int.contains_variant(10)); + assert!(!int.contains_variant(2)); + assert!(int.freq() == 10); + } } } +pub(crate) use spell_int::SpellInt; + #[derive(Debug, Clone, Default, PartialEq, PartialOrd)] struct Candidate { pub score: f32, @@ -51,6 +91,7 @@ struct Candidate { } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(default)] /// TODO pub struct SpellcheckerOptions { pub variant: Option, @@ -84,14 +125,14 @@ pub(crate) struct Spellchecker { pub(crate) max_freq: usize, pub(crate) lang_options: SpellcheckerLangOptions, pub(crate) used_variant: Arc, - pub(crate) used_fst: Arc>>, + pub(crate) used: Arc, HashSet)>>, } impl Spellchecker { fn update_used_fst( &self, options: &SpellcheckerOptions, - ) -> Option>> { + ) -> Option, HashSet)>> { let variant = if let Some(variant) = options.variant.as_ref() { variant.as_str() } else { @@ -105,36 +146,42 @@ impl Spellchecker { if self.used_variant.swap(variant_index, Ordering::Relaxed) != variant_index { let mut used_fst_builder = MapBuilder::memory(); + let mut set = DefaultHashSet::new(); let fst = Map::new(&self.fst).expect("serialized fst must be valid."); let mut stream = fst.into_stream(); while let Some((k, v)) = stream.next() { if SpellInt(v).contains_variant(variant_index) { + set.insert( + String::from_utf8(k.to_vec()).expect("fst keys must be valid utf-8."), + ); used_fst_builder .insert(k, v) .expect("fst stream returns values in lexicographic order."); } } - let mut guard = self.used_fst.write(); - let used_fst = guard.as_deref_mut().expect("lock must not be poisoned."); + let mut guard = self.used.write().expect("lock must not be poisoned."); + let (used_fst, used_set) = guard.deref_mut(); *used_fst = used_fst_builder .into_inner() .expect("subset of valid fst must be valid."); + *used_set = set; } - Some(self.used_fst.read().expect("lock must not be poisoned")) + Some(self.used.read().expect("lock must not be poisoned")) } fn lookup(&self, token: &Token, options: &SpellcheckerOptions) 
-> Option> { let guard = self.update_used_fst(options)?; - let used_fst = Map::new(guard.deref()).expect("used fst must be valid."); + let (used_fst, used_set) = guard.deref(); + let used_fst = Map::new(used_fst).expect("used fst must be valid."); let text = token.word.text.as_ref(); // no text => nothing to correct, only the case for special tokens (e.g. SENT_START) - if text.is_empty() { + if text.is_empty() || used_set.contains(text) { return None; } diff --git a/nlprule/src/tokenizer/tag.rs b/nlprule/src/tokenizer/tag.rs index 4e8534a..2693541 100644 --- a/nlprule/src/tokenizer/tag.rs +++ b/nlprule/src/tokenizer/tag.rs @@ -194,10 +194,6 @@ impl Tagger { tags } - pub(crate) fn lang_options(&self) -> &TaggerLangOptions { - &self.lang_options - } - #[allow(dead_code)] // used by compile module pub(crate) fn tag_store(&self) -> &BiMap { &self.tag_store diff --git a/python/Cargo.toml b/python/Cargo.toml index 6979312..539f3d9 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -15,6 +15,7 @@ parking_lot = { version = "0.11", features = ["serde"] } reqwest = { version = "0.11", default_features = false, features = ["blocking", "rustls-tls"]} flate2 = "1" directories = "3" +pythonize = "0.13" syn = "=1.0.57" # workaround for "could not find `export` in `syn`" by enum_dispatch nlprule = { path = "../nlprule" } # BUILD_BINDINGS_COMMENT # nlprule = { package = "nlprule-core", path = "../nlprule" } # BUILD_BINDINGS_UNCOMMENT diff --git a/python/src/lib.rs b/python/src/lib.rs index 13eafb0..a334670 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1,7 +1,7 @@ use flate2::read::GzDecoder; use nlprule::{ rule::{id::Selector, Example, Rule}, - rules::{apply_suggestions, Rules}, + rules::{apply_suggestions, Rules, RulesOptions}, tokenizer::tag::Tagger, tokenizer::Tokenizer, types::*, @@ -9,9 +9,10 @@ use nlprule::{ use parking_lot::{ MappedRwLockReadGuard, MappedRwLockWriteGuard, RwLock, RwLockReadGuard, RwLockWriteGuard, }; -use pyo3::prelude::*; -use pyo3::types::PyString; use pyo3::{exceptions::PyValueError, types::PyBytes}; +use pyo3::{prelude::*, types::PyDict}; +use pyo3::{types::PyString, ToPyObject}; +use pythonize::depythonize; use std::{ convert::TryFrom, fs, @@ -560,12 +561,24 @@ struct PyRules { #[pymethods] impl PyRules { - #[text_signature = "(code, tokenizer, sentence_splitter=None)"] + #[text_signature = "(code, tokenizer, sentence_splitter=None, **kwargs)"] + #[args(kwargs = "**")] #[staticmethod] - fn load(lang_code: &str, tokenizer: &PyTokenizer) -> PyResult { + fn load( + py: Python, + lang_code: &str, + tokenizer: &PyTokenizer, + kwargs: Option<&PyDict>, + ) -> PyResult { let bytes = get_resource(lang_code, "rules.bin.gz")?; - let rules = Rules::from_reader(bytes, tokenizer.tokenizer().clone()) + let options = if let Some(options) = kwargs { + depythonize(options.to_object(py).as_ref(py))? 
+ } else { + RulesOptions::default() + }; + + let rules = Rules::from_reader_with_options(bytes, tokenizer.tokenizer().clone(), options) .map_err(|x| PyValueError::new_err(format!("{}", x)))?; Ok(PyRules { rules: Arc::from(RwLock::from(rules)), @@ -573,15 +586,27 @@ impl PyRules { } #[new] - fn new(path: Option<&str>, tokenizer: Option<&PyTokenizer>) -> PyResult { + #[args(kwargs = "**")] + fn new( + py: Python, + path: Option<&str>, + tokenizer: Option<&PyTokenizer>, + kwargs: Option<&PyDict>, + ) -> PyResult { let tokenizer = if let Some(tokenizer) = tokenizer { tokenizer.tokenizer().clone() } else { PyTokenizer::default().tokenizer().clone() }; + let options = if let Some(options) = kwargs { + depythonize(options.to_object(py).as_ref(py))? + } else { + RulesOptions::default() + }; + let rules = if let Some(path) = path { - Rules::new(path, tokenizer) + Rules::new_with_options(path, tokenizer, options) .map_err(|x| PyValueError::new_err(format!("error creating Rules: {}", x)))? } else { Rules::default() From 0c66c780f40510e147c364f5cee213a723787725 Mon Sep 17 00:00:00 2001 From: Benjamin Minixhofer Date: Sat, 6 Mar 2021 18:27:45 +0100 Subject: [PATCH 07/16] add python bindings --- nlprule/src/compile/impls.rs | 28 ++-- nlprule/src/compile/mod.rs | 4 +- nlprule/src/compile/utils.rs | 6 +- nlprule/src/lib.rs | 2 + nlprule/src/rules.rs | 67 ++------- nlprule/src/spellcheck/mod.rs | 276 ++++++++++++++++++++++------------ nlprule/tests/tests.rs | 21 +-- python/src/lib.rs | 272 +++++++++++++++++++++++++++------ python/test.py | 20 +++ 9 files changed, 464 insertions(+), 232 deletions(-) diff --git a/nlprule/src/compile/impls.rs b/nlprule/src/compile/impls.rs index abc3183..758b330 100644 --- a/nlprule/src/compile/impls.rs +++ b/nlprule/src/compile/impls.rs @@ -9,7 +9,7 @@ use std::{ hash::{Hash, Hasher}, io::{self, BufRead, BufReader}, path::Path, - sync::{atomic::AtomicUsize, Arc}, + sync::Arc, }; use crate::{ @@ -22,8 +22,8 @@ use crate::{ id::Category, DisambiguationRule, MatchGraph, Rule, }, - rules::{Rules, RulesLangOptions, RulesOptions}, - spellcheck::{SpellInt, Spellchecker, SpellcheckerLangOptions}, + rules::{Rules, RulesLangOptions}, + spellcheck::{Spell, SpellInt, SpellLangOptions}, tokenizer::{ chunk, multiword::{MultiwordTagger, MultiwordTaggerFields}, @@ -36,16 +36,19 @@ use crate::{ use super::{parse_structure::BuildInfo, Error}; -impl Spellchecker { +impl Spell { pub(in crate::compile) fn from_dumps>( spell_dir_path: S, - lang_options: SpellcheckerLangOptions, + lang_options: SpellLangOptions, ) -> io::Result { let mut words: HashMap = DefaultHashMap::new(); let mut max_freq = 0; for (i, variant) in lang_options.variants.iter().enumerate() { - let spell_path = spell_dir_path.as_ref().join(variant).with_extension("dump"); + let spell_path = spell_dir_path + .as_ref() + .join(variant.as_str()) + .with_extension("dump"); let reader = BufReader::new(File::open(spell_path)?); for line in reader.lines() { @@ -78,13 +81,7 @@ impl Spellchecker { let map = fst::Map::from_iter(words.into_iter()).expect("words are lexicographically sorted."); - Ok(Spellchecker { - fst: map.into_fst().to_vec(), - max_freq, - lang_options, - used_variant: Arc::new(AtomicUsize::new(usize::MAX)), - used: Arc::new(Default::default()), - }) + Ok(Spell::new(map.into_fst().to_vec(), max_freq, lang_options)) } } @@ -318,7 +315,7 @@ impl Rules { pub(in crate::compile) fn from_xml>( path: P, build_info: &mut BuildInfo, - spellchecker: Spellchecker, + spell: Spell, tokenizer: Arc, options: RulesLangOptions, ) -> Self 
{ @@ -414,9 +411,8 @@ impl Rules { Rules { rules, - spellchecker, + spell, tokenizer, - options: RulesOptions::default(), } } } diff --git a/nlprule/src/compile/mod.rs b/nlprule/src/compile/mod.rs index 365aa78..123776d 100644 --- a/nlprule/src/compile/mod.rs +++ b/nlprule/src/compile/mod.rs @@ -14,7 +14,7 @@ use std::{ use crate::{ rules::Rules, - spellcheck::Spellchecker, + spellcheck::Spell, tokenizer::{chunk::Chunker, multiword::MultiwordTagger, tag::Tagger, Tokenizer}, types::*, }; @@ -189,7 +189,7 @@ pub fn compile( None }; - let spellchecker = Spellchecker::from_dumps(paths.spell_dir_path, spellchecker_lang_options)?; + let spellchecker = Spell::from_dumps(paths.spell_dir_path, spellchecker_lang_options)?; info!("Creating tokenizer."); let tokenizer = Tokenizer::from_xml( diff --git a/nlprule/src/compile/utils.rs b/nlprule/src/compile/utils.rs index d4d5acf..3a9d8aa 100644 --- a/nlprule/src/compile/utils.rs +++ b/nlprule/src/compile/utils.rs @@ -1,5 +1,5 @@ use crate::{ - rules::RulesLangOptions, spellcheck::SpellcheckerLangOptions, tokenizer::TokenizerLangOptions, + rules::RulesLangOptions, spellcheck::SpellLangOptions, tokenizer::TokenizerLangOptions, }; use crate::{tokenizer::tag::TaggerLangOptions, types::*}; use lazy_static::lazy_static; @@ -38,7 +38,7 @@ lazy_static! { } lazy_static! { - static ref SPELLCHECKER_LANG_OPTIONS: DefaultHashMap = { + static ref SPELLCHECKER_LANG_OPTIONS: DefaultHashMap = { serde_json::from_slice(include_bytes!(concat!( env!("OUT_DIR"), "/", @@ -64,7 +64,7 @@ pub(crate) fn tagger_lang_options(lang_code: &str) -> Option } /// Gets the spellchecker language options for the language code -pub(crate) fn spellchecker_lang_options(lang_code: &str) -> Option { +pub(crate) fn spellchecker_lang_options(lang_code: &str) -> Option { SPELLCHECKER_LANG_OPTIONS.get(lang_code).cloned() } diff --git a/nlprule/src/lib.rs b/nlprule/src/lib.rs index 10b443d..3cb4d6a 100644 --- a/nlprule/src/lib.rs +++ b/nlprule/src/lib.rs @@ -78,6 +78,8 @@ pub enum Error { Io(#[from] io::Error), #[error("deserialization error: {0}")] Deserialization(#[from] bincode::Error), + #[error("unknown language variant: {0}")] + UnknownVariant(String), } /// Gets the canonical filename for the tokenizer binary for a language code in ISO 639-1 (two-letter) format. diff --git a/nlprule/src/rules.rs b/nlprule/src/rules.rs index 9c20000..b57a056 100644 --- a/nlprule/src/rules.rs +++ b/nlprule/src/rules.rs @@ -2,10 +2,7 @@ use crate::{rule::id::Selector, tokenizer::Tokenizer}; use crate::{rule::Rule, Error}; -use crate::{ - spellcheck::Spellchecker, spellcheck::SpellcheckerOptions, types::*, - utils::parallelism::MaybeParallelRefIterator, -}; +use crate::{spellcheck::Spell, types::*, utils::parallelism::MaybeParallelRefIterator}; use fs_err::File; use serde::{Deserialize, Serialize}; use std::{ @@ -14,13 +11,6 @@ use std::{ sync::Arc, }; -/// Options for a rule set. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] -pub struct RulesOptions { - /// TODO - pub spellchecker_options: SpellcheckerOptions, -} - /// Language-dependent options for a rule set. 
#[derive(Debug, Clone, Serialize, Deserialize)] pub(crate) struct RulesLangOptions { @@ -47,14 +37,14 @@ impl Default for RulesLangOptions { #[derive(Serialize, Deserialize, Default)] struct RulesFields { pub(crate) rules: Vec, - pub(crate) spellchecker: Spellchecker, + pub(crate) spell: Spell, } impl From for RulesFields { fn from(rules: Rules) -> Self { RulesFields { rules: rules.rules, - spellchecker: rules.spellchecker, + spell: rules.spell, } } } @@ -63,9 +53,8 @@ impl From for RulesFields { #[derive(Clone, Default, Serialize, Deserialize)] pub struct Rules { pub(crate) rules: Vec, - pub(crate) spellchecker: Spellchecker, + pub(crate) spell: Spell, pub(crate) tokenizer: Arc, - pub(crate) options: RulesOptions, } impl Rules { @@ -77,55 +66,34 @@ impl Rules { Ok(()) } - /// TODO - pub fn from_reader_with_options( - reader: R, - tokenizer: Arc, - options: RulesOptions, - ) -> Result { + /// Creates a new rules set from a reader. + pub fn from_reader(reader: R, tokenizer: Arc) -> Result { let fields: RulesFields = bincode::deserialize_from(reader)?; let rules = Rules { rules: fields.rules, - options, - spellchecker: fields.spellchecker, + spell: fields.spell, tokenizer, }; Ok(rules) } - /// Creates a new rule set with options. See [new][Rules::new]. - pub fn new_with_options>( - p: P, - tokenizer: Arc, - options: RulesOptions, - ) -> Result { - let reader = BufReader::new(File::open(p.as_ref())?); - - Self::from_reader_with_options(reader, tokenizer, options) - } - - /// Creates a new rules set from a reader. - pub fn from_reader(reader: R, tokenizer: Arc) -> Result { - Self::from_reader_with_options(reader, tokenizer, RulesOptions::default()) - } - /// Creates a new rule set from a path to a binary. /// /// # Errors /// - If the file can not be opened. /// - If the file content can not be deserialized to a rules set. pub fn new>(p: P, tokenizer: Arc) -> Result { - Self::new_with_options(p, tokenizer, RulesOptions::default()) + let reader = BufReader::new(File::open(p.as_ref())?); + + Self::from_reader(reader, tokenizer) } - /// Gets the options of this rule set. - pub fn options(&self) -> &RulesOptions { - &self.options + pub fn spell(&self) -> &Spell { + &self.spell } - /// Sets the options of this rule set. - pub fn mut_options(&mut self) -> &mut RulesOptions { - &mut self.options + pub fn spell_mut(&mut self) -> &mut Spell { + &mut self.spell } /// All rules ordered by priority. 
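A minimal usage sketch of the `spell()` / `spell_mut()` accessors added above, assuming compiled `en` binaries at the paths used by the test suite; it mirrors the `spellchecker_works` test later in this series: select a variant, then let `suggest` return grammar and spelling suggestions through the same call.

    use std::sync::Arc;

    use nlprule::{Error, Rules, Tokenizer};

    fn main() -> Result<(), Error> {
        // Assumed paths; point these at the compiled binaries for your setup.
        let tokenizer = Arc::new(Tokenizer::new("storage/en_tokenizer.bin")?);
        let mut rules = Rules::new("storage/en_rules.bin", tokenizer)?;

        // Spellchecking stays disabled until a variant is selected.
        rules.spell_mut().options_mut().variant = Some(rules.spell().variant("en_GB")?);

        // Grammar and spelling suggestions are returned together.
        for suggestion in rules.suggest("She has been here since Monday instad.") {
            println!("{:?}", suggestion);
        }

        Ok(())
    }

Note that the guard returned by `options_mut()` re-derives the variant-specific fst subset when it is dropped, so switching variants repeatedly works but is not free.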
@@ -177,12 +145,7 @@ impl Rules { .flatten() .collect(); - output.extend( - self.spellchecker - .suggest(tokens, &self.options.spellchecker_options) - .into_iter() - .map(|x| (0, x)), - ); + output.extend(self.spell.suggest(tokens).into_iter().map(|x| (0, x))); output.sort_by(|(ia, a), (ib, b)| a.start.cmp(&b.start).then_with(|| ib.cmp(ia))); diff --git a/nlprule/src/spellcheck/mod.rs b/nlprule/src/spellcheck/mod.rs index 24a65af..d13ae0b 100644 --- a/nlprule/src/spellcheck/mod.rs +++ b/nlprule/src/spellcheck/mod.rs @@ -1,16 +1,13 @@ use std::{ - collections::HashSet, + cmp::Ordering, + collections::{BinaryHeap, HashSet}, ops::{Deref, DerefMut}, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, RwLock, RwLockReadGuard, - }, }; use fst::{IntoStreamer, Map, MapBuilder, Streamer}; use serde::{Deserialize, Serialize}; -use crate::types::*; +use crate::{types::*, Error}; mod levenshtein; @@ -84,149 +81,240 @@ mod spell_int { pub(crate) use spell_int::SpellInt; -#[derive(Debug, Clone, Default, PartialEq, PartialOrd)] -struct Candidate { - pub score: f32, - pub term: String, +#[derive(Debug, Clone, Default, PartialEq)] +pub struct Candidate { + score: f32, + distance: usize, + freq: usize, + term: String, +} +impl Eq for Candidate {} +impl PartialOrd for Candidate { + fn partial_cmp(&self, other: &Self) -> Option { + // higher score => lower order such that sorting puts highest scores first + other.score.partial_cmp(&self.score) + } +} +impl Ord for Candidate { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.partial_cmp(other).expect("scores are never NaN") + } +} + +impl Candidate { + pub fn score(&self) -> f32 { + self.score + } + + pub fn freq(&self) -> usize { + self.freq + } + + pub fn distance(&self) -> usize { + self.distance + } + + pub fn term(&self) -> &str { + self.term.as_str() + } } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(default)] /// TODO -pub struct SpellcheckerOptions { - pub variant: Option, +pub struct SpellOptions { + pub variant: Option, pub max_distance: usize, - pub prefix: usize, - pub frequency_weight: f32, - pub n_suggestions: usize, + pub prefix_length: usize, + pub freq_weight: f32, + pub top_n: usize, + pub whitelist: HashSet, +} + +pub struct SpellOptionsGuard<'a> { + spell: &'a mut Spell, +} + +impl<'a> Deref for SpellOptionsGuard<'a> { + type Target = SpellOptions; + + fn deref(&self) -> &Self::Target { + &self.spell.options + } +} + +impl<'a> DerefMut for SpellOptionsGuard<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.spell.options + } +} + +impl<'a> Drop for SpellOptionsGuard<'a> { + fn drop(&mut self) { + self.spell.ingest_options() + } } #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)] -pub(crate) struct SpellcheckerLangOptions { +pub(crate) struct SpellLangOptions { /// Variants of the language (e. g. "en_US", "en_GB") to consider for spellchecking. 
- pub variants: Vec, + pub variants: Vec, } -impl Default for SpellcheckerOptions { +impl Default for SpellOptions { fn default() -> Self { - SpellcheckerOptions { + SpellOptions { variant: None, max_distance: 2, - prefix: 2, - frequency_weight: 2., - n_suggestions: 10, + prefix_length: 2, + freq_weight: 2., + top_n: 10, + whitelist: HashSet::new(), } } } +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)] +#[serde(transparent)] +pub struct Variant(String); + +impl Variant { + pub fn as_str(&self) -> &str { + self.0.as_str() + } +} + #[derive(Debug, Clone, Default, Serialize, Deserialize)] -pub(crate) struct Spellchecker { - pub(crate) fst: Vec, - pub(crate) max_freq: usize, - pub(crate) lang_options: SpellcheckerLangOptions, - pub(crate) used_variant: Arc, - pub(crate) used: Arc, HashSet)>>, +pub struct Spell { + fst: Vec, + max_freq: usize, + lang_options: SpellLangOptions, + options: SpellOptions, + used_variant: Option, + used_fst: Vec, + used_set: HashSet, } -impl Spellchecker { - fn update_used_fst( - &self, - options: &SpellcheckerOptions, - ) -> Option, HashSet)>> { - let variant = if let Some(variant) = options.variant.as_ref() { - variant.as_str() - } else { - return None; +impl Spell { + pub(crate) fn new(fst: Vec, max_freq: usize, lang_options: SpellLangOptions) -> Self { + let mut spell = Spell { + fst, + max_freq, + lang_options, + options: SpellOptions::default(), + ..Default::default() }; - let variant_index = self - .lang_options + spell.ingest_options(); + spell + } + + pub fn options(&self) -> &SpellOptions { + &self.options + } + + pub fn options_mut(&mut self) -> SpellOptionsGuard { + SpellOptionsGuard { spell: self } + } + + pub fn variants(&self) -> &[Variant] { + self.lang_options.variants.as_slice() + } + + pub fn variant(&self, variant: &str) -> Result { + self.lang_options .variants .iter() - .position(|x| x == variant)?; - - if self.used_variant.swap(variant_index, Ordering::Relaxed) != variant_index { - let mut used_fst_builder = MapBuilder::memory(); - let mut set = DefaultHashSet::new(); - - let fst = Map::new(&self.fst).expect("serialized fst must be valid."); - let mut stream = fst.into_stream(); - - while let Some((k, v)) = stream.next() { - if SpellInt(v).contains_variant(variant_index) { - set.insert( - String::from_utf8(k.to_vec()).expect("fst keys must be valid utf-8."), - ); - used_fst_builder - .insert(k, v) - .expect("fst stream returns values in lexicographic order."); - } - } + .find(|x| x.as_str() == variant) + .cloned() + .ok_or_else(|| Error::UnknownVariant(variant.to_owned())) + } + + pub(crate) fn ingest_options(&mut self) { + if self.used_variant == self.options.variant { + return; + } + + let variant = if let Some(variant) = self.options.variant.as_ref() { + variant + } else { + self.used_variant = None; + self.used_fst = Vec::new(); + self.used_set = DefaultHashSet::new(); + return; + }; + + let mut used_fst_builder = MapBuilder::memory(); + let mut set = DefaultHashSet::new(); - let mut guard = self.used.write().expect("lock must not be poisoned."); - let (used_fst, used_set) = guard.deref_mut(); + let fst = Map::new(&self.fst).expect("serialized fst must be valid."); + let mut stream = fst.into_stream(); - *used_fst = used_fst_builder - .into_inner() - .expect("subset of valid fst must be valid."); - *used_set = set; + let variant_index = self + .variants() + .iter() + .position(|x| x == variant) + .expect("only valid variants are created."); + + while let Some((k, v)) = stream.next() { + if 
SpellInt(v).contains_variant(variant_index) { + set.insert(String::from_utf8(k.to_vec()).expect("fst keys must be valid utf-8.")); + used_fst_builder + .insert(k, v) + .expect("fst stream returns values in lexicographic order."); + } } - Some(self.used.read().expect("lock must not be poisoned")) + self.used_variant = Some(variant.clone()); + self.used_fst = used_fst_builder + .into_inner() + .expect("subset of valid fst must be valid."); + self.used_set = set; } - fn lookup(&self, token: &Token, options: &SpellcheckerOptions) -> Option> { - let guard = self.update_used_fst(options)?; - let (used_fst, used_set) = guard.deref(); - let used_fst = Map::new(used_fst).expect("used fst must be valid."); - - let text = token.word.text.as_ref(); - // no text => nothing to correct, only the case for special tokens (e.g. SENT_START) - if text.is_empty() || used_set.contains(text) { + pub fn search(&self, word: &str) -> Option> { + if self.used_variant.is_none() || word.is_empty() || self.used_set.contains(word) { return None; } - let query = levenshtein::Levenshtein::new(text, options.max_distance, 2); + let used_fst = Map::new(self.used_fst.as_slice()).expect("used fst must be valid."); + let query = levenshtein::Levenshtein::new(word, self.options.max_distance, 2); - let mut out = Vec::new(); + let mut out = BinaryHeap::with_capacity(self.options.top_n); let mut stream = used_fst.search_with_state(query).into_stream(); while let Some((k, v, s)) = stream.next() { let state = s.expect("matching levenshtein state is always `Some`."); - if state.dist() == 0 { - return None; - } + assert!(state.dist() > 0); let id = SpellInt(v); - let string = String::from_utf8(k.to_vec()).expect("fst keys must be valid utf-8."); + let term = String::from_utf8(k.to_vec()).expect("fst keys must be valid utf-8."); out.push(Candidate { - score: (options.max_distance - state.dist()) as f32 - + id.freq() as f32 / self.max_freq as f32 * options.frequency_weight, - term: string, - }) + distance: state.dist(), + freq: id.freq(), + term, + score: (self.options.max_distance - state.dist()) as f32 + + id.freq() as f32 / self.max_freq as f32 * self.options.freq_weight, + }); + if out.len() > self.options.top_n { + out.pop(); + } } - // we want higher scores first - out.sort_by(|a, b| b.partial_cmp(a).expect("candidate scores are never NaN.")); - Some(out) + Some(out.into_sorted_vec()) } - pub fn suggest(&self, tokens: &[Token], options: &SpellcheckerOptions) -> Vec { + pub fn suggest(&self, tokens: &[Token]) -> Vec { let mut suggestions = Vec::new(); for token in tokens { - if let Some(candidates) = self.lookup(token, options) { - // TODO: disallow empty / properly treat empty + if let Some(candidates) = self.search(token.word.text.as_ref()) { suggestions.push(Suggestion { source: "SPELLCHECK/SINGLE".into(), message: "Possibly misspelled word.".into(), start: token.char_span.0, end: token.char_span.1, - replacements: candidates - .into_iter() - .map(|x| x.term.to_owned()) - .take(options.n_suggestions) - .collect(), + replacements: candidates.into_iter().map(|x| x.term.to_owned()).collect(), }) } } diff --git a/nlprule/tests/tests.rs b/nlprule/tests/tests.rs index 051cdd8..f71789f 100644 --- a/nlprule/tests/tests.rs +++ b/nlprule/tests/tests.rs @@ -1,9 +1,7 @@ use std::{convert::TryInto, sync::Arc}; use lazy_static::lazy_static; -use nlprule::{ - rule::id::Category, rules::RulesOptions, spellcheck::SpellcheckerOptions, Rules, Tokenizer, -}; +use nlprule::{rule::id::Category, Error, Rules, Tokenizer}; use 
quickcheck_macros::quickcheck; const TOKENIZER_PATH: &str = "../storage/en_tokenizer.bin"; @@ -55,18 +53,9 @@ fn rules_can_be_disabled_enabled() { } #[test] -fn spellchecker_works() { - let rules = Rules::new_with_options( - RULES_PATH, - TOKENIZER.clone(), - RulesOptions { - spellchecker_options: SpellcheckerOptions { - variant: Some("en_GB".into()), - ..SpellcheckerOptions::default() - }, - }, - ) - .unwrap(); +fn spellchecker_works() -> Result<(), Error> { + let mut rules = Rules::new(RULES_PATH, TOKENIZER.clone()).unwrap(); + rules.spell_mut().options_mut().variant = Some(rules.spell().variant("en_GB")?); - println!("{:#?}", rules.suggest("mom")); + Ok(()) } diff --git a/python/src/lib.rs b/python/src/lib.rs index a334670..d3321bc 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1,7 +1,8 @@ use flate2::read::GzDecoder; use nlprule::{ rule::{id::Selector, Example, Rule}, - rules::{apply_suggestions, Rules, RulesOptions}, + rules::{apply_suggestions, Rules}, + spellcheck::{Candidate, Spell}, tokenizer::tag::Tagger, tokenizer::Tokenizer, types::*, @@ -10,10 +11,15 @@ use parking_lot::{ MappedRwLockReadGuard, MappedRwLockWriteGuard, RwLock, RwLockReadGuard, RwLockWriteGuard, }; use pyo3::{exceptions::PyValueError, types::PyBytes}; -use pyo3::{prelude::*, types::PyDict}; +use pyo3::{ + prelude::*, + types::{PyDict, PyFrozenSet}, + wrap_pymodule, +}; use pyo3::{types::PyString, ToPyObject}; use pythonize::depythonize; use std::{ + collections::HashSet, convert::TryFrom, fs, io::{Cursor, Read}, @@ -21,6 +27,10 @@ use std::{ sync::Arc, }; +fn err(error: nlprule::Error) -> PyErr { + PyValueError::new_err(format!("{}", error)) +} + fn get_resource(lang_code: &str, name: &str) -> PyResult { let version = env!("CARGO_PKG_VERSION"); let mut cache_path: Option = None; @@ -307,7 +317,7 @@ impl From for PySuggestion { /// When created from a language code, the binary is downloaded from the internet the first time. /// Then it is stored at your cache and loaded from there. #[pyclass(name = "Tokenizer", module = "nlprule")] -#[text_signature = "(path, sentence_splitter=None)"] +#[text_signature = "(path)"] #[derive(Default)] pub struct PyTokenizer { tokenizer: Arc, @@ -321,13 +331,12 @@ impl PyTokenizer { #[pymethods] impl PyTokenizer { - #[text_signature = "(code, sentence_splitter=None)"] + #[text_signature = "(code)"] #[staticmethod] fn load(lang_code: &str) -> PyResult { let bytes = get_resource(lang_code, "tokenizer.bin.gz")?; - let tokenizer: Tokenizer = bincode::deserialize_from(bytes) - .map_err(|x| PyValueError::new_err(format!("{}", x)))?; + let tokenizer = Tokenizer::from_reader(bytes).map_err(err)?; Ok(PyTokenizer { tokenizer: Arc::new(tokenizer), }) @@ -336,8 +345,7 @@ impl PyTokenizer { #[new] fn new(path: Option<&str>) -> PyResult { let tokenizer = if let Some(path) = path { - Tokenizer::new(path) - .map_err(|x| PyValueError::new_err(format!("error creating Tokenizer: {}", x)))? + Tokenizer::new(path).map_err(err)? 
} else { Tokenizer::default() }; @@ -479,7 +487,7 @@ impl PyRule { RwLockWriteGuard::map(self.rules.write(), |x| &mut x.rules_mut()[self.index]) } - fn from_rule(index: usize, rules: Arc>) -> Self { + fn from_rules(index: usize, rules: Arc>) -> Self { PyRule { rules, index } } } @@ -541,6 +549,181 @@ impl PyRule { } } +#[pyclass(name = "SpellOptions", module = "nlprule.spell")] +struct PySpellOptions { + rules: Arc>, +} + +impl PySpellOptions { + fn spell(&self) -> MappedRwLockReadGuard<'_, Spell> { + RwLockReadGuard::map(self.rules.read(), |x| x.spell()) + } + + fn spell_mut(&self) -> MappedRwLockWriteGuard<'_, Spell> { + RwLockWriteGuard::map(self.rules.write(), |x| x.spell_mut()) + } +} + +#[pymethods] +impl PySpellOptions { + #[getter] + fn get_variant(&self) -> Option { + self.spell() + .options() + .variant + .as_ref() + .map(|x| x.as_str().to_owned()) + } + + #[setter] + fn set_variant(&self, variant: Option<&str>) -> PyResult<()> { + if let Some(variant) = variant { + let mut spell = self.spell_mut(); + let variant = spell.variant(variant).map_err(err)?; + + spell.options_mut().variant = Some(variant); + } else { + self.spell_mut().options_mut().variant = None; + } + + Ok(()) + } + + #[getter] + fn get_max_distance(&self) -> usize { + self.spell().options().max_distance + } + + #[setter] + fn set_max_distance(&self, max_distance: usize) { + self.spell_mut().options_mut().max_distance = max_distance + } + + #[getter] + fn get_prefix_length(&self) -> usize { + self.spell().options().prefix_length + } + + #[setter] + fn set_prefix_length(&self, prefix_length: usize) { + self.spell_mut().options_mut().prefix_length = prefix_length + } + + #[getter] + fn get_freq_weight(&self) -> f32 { + self.spell().options().freq_weight + } + + #[setter] + fn set_freq_weight(&self, freq_weight: f32) { + self.spell_mut().options_mut().freq_weight = freq_weight + } + + #[getter] + fn get_top_n(&self) -> usize { + self.spell().options().top_n + } + + #[setter] + fn set_top_n(&self, top_n: usize) { + self.spell_mut().options_mut().top_n = top_n + } + + #[getter] + fn get_whitelist<'py>(&self, py: Python<'py>) -> PyResult<&'py PyFrozenSet> { + let spell = self.spell(); + let whitelist: Vec<&str> = spell + .options() + .whitelist + .iter() + .map(|x| x.as_str()) + .collect(); + + PyFrozenSet::new(py, &whitelist) + } + + #[setter] + fn set_whitelist(&self, py: Python, whitelist: PyObject) -> PyResult<()> { + let whitelist: PyResult> = whitelist + .as_ref(py) + .iter()? 
+ .map(|x| x.and_then(PyAny::extract::)) + .collect(); + self.spell_mut().options_mut().whitelist = whitelist?; + Ok(()) + } +} + +#[pyclass(name = "Candidate", module = "nlprule.spell")] +struct PyCandidate { + candidate: Candidate, +} + +#[pymethods] +impl PyCandidate { + #[getter] + fn score(&self) -> f32 { + self.candidate.score() + } + + #[getter] + fn distance(&self) -> usize { + self.candidate.distance() + } + + #[getter] + fn freq(&self) -> usize { + self.candidate.freq() + } + + #[getter] + fn term(&self) -> &str { + self.candidate.term() + } +} + +#[pyclass(name = "Spell", module = "nlprule.spell")] +struct PySpell { + rules: Arc>, +} + +#[pymethods] +impl PySpell { + #[getter] + fn variants(&self) -> Vec { + self.rules + .read() + .spell() + .variants() + .iter() + .map(|x| x.as_str().to_owned()) + .collect() + } + + #[getter] + fn get_options(&self) -> PySpellOptions { + PySpellOptions { + rules: self.rules.clone(), + } + } + + #[setter] + fn set_options(&self, py: Python, options: &PyDict) -> PyResult<()> { + let mut guard = self.rules.write(); + *guard.spell_mut().options_mut() = depythonize(options.to_object(py).as_ref(py))?; + Ok(()) + } + + fn search(&self, word: &str) -> Option> { + self.rules.read().spell().search(word).map(|candidates| { + candidates + .into_iter() + .map(|candidate| PyCandidate { candidate }) + .collect::>() + }) + } +} + /// The grammatical rules. /// Can be created from a rules binary: /// ```python @@ -554,60 +737,34 @@ impl PyRule { /// When created from a language code, the binary is downloaded from the internet the first time. /// Then it is stored at your cache and loaded from there. #[pyclass(name = "Rules", module = "nlprule")] -#[text_signature = "(path, tokenizer, sentence_splitter=None)"] +#[text_signature = "(path, tokenizer)"] struct PyRules { rules: Arc>, } #[pymethods] impl PyRules { - #[text_signature = "(code, tokenizer, sentence_splitter=None, **kwargs)"] - #[args(kwargs = "**")] + #[text_signature = "(code, tokenizer)"] #[staticmethod] - fn load( - py: Python, - lang_code: &str, - tokenizer: &PyTokenizer, - kwargs: Option<&PyDict>, - ) -> PyResult { + fn load(lang_code: &str, tokenizer: &PyTokenizer) -> PyResult { let bytes = get_resource(lang_code, "rules.bin.gz")?; - let options = if let Some(options) = kwargs { - depythonize(options.to_object(py).as_ref(py))? - } else { - RulesOptions::default() - }; - - let rules = Rules::from_reader_with_options(bytes, tokenizer.tokenizer().clone(), options) - .map_err(|x| PyValueError::new_err(format!("{}", x)))?; + let rules = Rules::from_reader(bytes, tokenizer.tokenizer().clone()).map_err(err)?; Ok(PyRules { rules: Arc::from(RwLock::from(rules)), }) } #[new] - #[args(kwargs = "**")] - fn new( - py: Python, - path: Option<&str>, - tokenizer: Option<&PyTokenizer>, - kwargs: Option<&PyDict>, - ) -> PyResult { + fn new(path: Option<&str>, tokenizer: Option<&PyTokenizer>) -> PyResult { let tokenizer = if let Some(tokenizer) = tokenizer { tokenizer.tokenizer().clone() } else { PyTokenizer::default().tokenizer().clone() }; - let options = if let Some(options) = kwargs { - depythonize(options.to_object(py).as_ref(py))? - } else { - RulesOptions::default() - }; - let rules = if let Some(path) = path { - Rules::new_with_options(path, tokenizer, options) - .map_err(|x| PyValueError::new_err(format!("error creating Rules: {}", x)))? + Rules::new(path, tokenizer).map_err(err)? 
} else { Rules::default() }; @@ -617,6 +774,13 @@ impl PyRules { }) } + #[getter] + fn spell(&self) -> PySpell { + PySpell { + rules: self.rules.clone(), + } + } + #[getter] fn rules(&self) -> Vec { self.rules @@ -624,14 +788,14 @@ impl PyRules { .rules() .iter() .enumerate() - .map(|(i, _)| PyRule::from_rule(i, self.rules.clone())) + .map(|(i, _)| PyRule::from_rules(i, self.rules.clone())) .collect() } /// Finds a rule by selector. fn select(&self, id: &str) -> PyResult> { let selector = Selector::try_from(id.to_owned()) - .map_err(|err| PyValueError::new_err(format!("error creating selector: {}", err)))?; + .map_err(|err| PyValueError::new_err(format!("{}", err)))?; Ok(self .rules @@ -640,7 +804,7 @@ impl PyRules { .iter() .enumerate() .filter(|(_, rule)| selector.is_match(rule.id())) - .map(|(i, _)| PyRule::from_rule(i, self.rules.clone())) + .map(|(i, _)| PyRule::from_rules(i, self.rules.clone())) .collect()) } @@ -652,12 +816,12 @@ impl PyRules { /// Returns: /// suggestions (Union[List[Suggestion], List[List[Suggestion]]]): /// The computed suggestions. Batched if the input is batched. - #[text_signature = "(sentence_or_sentences)"] - fn suggest(&self, py: Python, sentence_or_sentences: PyObject) -> PyResult { - text_guard(py, sentence_or_sentences, |sentence| { + #[text_signature = "(text_or_texts)"] + fn suggest(&self, py: Python, text_or_texts: PyObject) -> PyResult { + text_guard(py, text_or_texts, |text| { self.rules .read() - .suggest(&sentence) + .suggest(&text) .into_iter() .map(|x| PyCell::new(py, PySuggestion::from(x))) .collect::>>() @@ -724,9 +888,9 @@ impl PyRules { } pub fn __getstate__(&self, py: Python) -> PyResult { - // rwlock is serialized the same way as the inner type Ok(PyBytes::new( py, + // rwlock serialization is transparent &bincode::serialize(&self.rules) .map_err(|_| PyValueError::new_err("serializing state with `bincode` failed"))?, ) @@ -734,13 +898,23 @@ impl PyRules { } } +#[pymodule] +fn spell(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + Ok(()) +} + #[pymodule] fn nlprule(_py: Python, m: &PyModule) -> PyResult<()> { m.add("__version__", env!("CARGO_PKG_VERSION"))?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_wrapped(wrap_pymodule!(spell))?; + Ok(()) } diff --git a/python/test.py b/python/test.py index 086b1a4..bc397c6 100644 --- a/python/test.py +++ b/python/test.py @@ -138,3 +138,23 @@ def test_rules_can_be_disabled(tokenizer_and_rules): rule.disable() assert len(rules.suggest("I can due his homework")) == 0 + +def test_spell_options_can_be_read(tokenizer_and_rules): + (tokenizer, rules) = tokenizer_and_rules + + assert rules.spell.options.max_distance > 0 + assert rules.spell.options.variant is None + +def test_spell_options_can_be_set(tokenizer_and_rules): + (tokenizer, rules) = tokenizer_and_rules + + with pytest.raises(ValueError): + rules.spell.options.variant = "en_INVALID" + + rules.spell.options.variant = "en_GB" + assert rules.spell.options.variant == "en_GB" + +def test_spellchecker_works(tokenizer_and_rules): + (tokenizer, rules) = tokenizer_and_rules + + print(rules.spell.search("lämp")) \ No newline at end of file From 93881da02ff15c8620fd7fbb150fecc2fa62e19f Mon Sep 17 00:00:00 2001 From: Benjamin Minixhofer Date: Sat, 13 Mar 2021 11:55:25 +0100 Subject: [PATCH 08/16] add support for multiword spelling correction, add contraction mapping --- bench/__init__.py | 7 +-- build/README.md | 1 + build/make_build_dir.py | 80 +++++++++++++++++++++++++++ 
nlprule/Cargo.toml | 1 + nlprule/configs/de/spellchecker.json | 3 +- nlprule/configs/en/rules.json | 3 +- nlprule/configs/en/spellchecker.json | 5 +- nlprule/configs/es/spellchecker.json | 3 +- nlprule/src/compile/impls.rs | 81 +++++++++++++++++++++++++--- nlprule/src/compile/mod.rs | 62 +++++++++++++++++---- nlprule/src/spellcheck/mod.rs | 81 ++++++++++++++++++++-------- nlprule/src/tokenizer.rs | 3 +- nlprule/src/tokenizer/multiword.rs | 16 +++--- nlprule/src/types.rs | 7 ++- nlprule/tests/tests.rs | 2 + python/src/lib.rs | 19 ++++--- 16 files changed, 309 insertions(+), 65 deletions(-) diff --git a/bench/__init__.py b/bench/__init__.py index eebff50..3dfb05c 100644 --- a/bench/__init__.py +++ b/bench/__init__.py @@ -115,11 +115,8 @@ def suggest(self, sentence: str) -> Set[Suggestion]: class NLPRule: def __init__(self, lang_code: str): self.tokenizer = nlprule.Tokenizer(f"storage/{lang_code}_tokenizer.bin") - self.rules = nlprule.Rules( - f"storage/{lang_code}_rules.bin", - self.tokenizer, - spellchecker_options={"variant": "en_US", "max_distance": 0, "prefix": 0,}, - ) + self.rules = nlprule.Rules(f"storage/{lang_code}_rules.bin", self.tokenizer) + self.rules.spell.options.variant = "en_US" def suggest(self, sentence: str) -> Set[Suggestion]: suggestions = { diff --git a/build/README.md b/build/README.md index c826acc..458f7da 100644 --- a/build/README.md +++ b/build/README.md @@ -79,6 +79,7 @@ python build/make_build_dir.py \ --chunker_token_model=$HOME/Downloads/nlprule/en-token.bin \ --chunker_pos_model=$HOME/Downloads/nlprule/en-pos-maxent.bin \ --chunker_chunk_model=$HOME/Downloads/nlprule/en-chunker.bin \ + --spell_map_path=$LT_PATH/org/languagetool/rules/en/contractions.txt \ --out_dir=data/en ``` diff --git a/build/make_build_dir.py b/build/make_build_dir.py index 0cd357c..849adac 100644 --- a/build/make_build_dir.py +++ b/build/make_build_dir.py @@ -87,6 +87,45 @@ def dump_dict(out_path, lt_dir, tag_dict_path, tag_info_path): f.write(dump_bytes.decode(result["encoding"] or "utf-8")) +def proc_spelling_text(in_paths, out_path, lang_code): + with open(out_path, "w") as f: + for in_path in in_paths: + if in_path.exists(): + for line in open(in_path): + # strip comments + comment_index = line.find("#") + if comment_index != -1: + line = line[:comment_index] + + line = line.strip() + if len(line) == 0: + continue + + try: + word, suffix = line.split("/") + + assert lang_code == "de", "Flags are only supported for German!" + + for flag in suffix: + assert flag != "Ä" + if flag == "A" and word.endswith("e"): + flag = "Ä" + + f.write(word + "\n") + + for ending in { + "S": ["s"], + "N": ["n"], + "E": ["e"], + "F": ["in"], + "A": ["e", "er", "es", "en", "em"], + "Ä": ["r", "s", "n", "m"], + }[flag]: + f.write(word + ending + "\n") + except ValueError: + f.write(line + "\n") + + if __name__ == "__main__": parser = ArgumentParser( description=""" @@ -139,6 +178,12 @@ def dump_dict(out_path, lt_dir, tag_dict_path, tag_info_path): default=None, help="Path to the OpenNLP chunker binary. See token model message for details.", ) + parser.add_argument( + "--spell_map_path", + default=None, + action="append", + help="Paths to files containing a mapping from incorrect words to correct ones e.g. 
contractions.txt for English.", + ) parser.add_argument( "--out_dir", type=lambda p: Path(p).absolute(), @@ -180,6 +225,41 @@ def dump_dict(out_path, lt_dir, tag_dict_path, tag_info_path): dump_dict( args.out_dir / "spell" / f"{variant_name}.dump", args.lt_dir, dic, info, ) + proc_spelling_text( + [ + ( + dic / ".." / ("spelling_" + variant_name.replace("_", "-") + ".txt") + ).resolve(), + ( + dic / ".." / ("spelling-" + variant_name.replace("_", "-") + ".txt") + ).resolve(), + ], + args.out_dir / "spell" / f"{variant_name}.txt", + args.lang_code, + ) + + proc_spelling_text( + [ + args.lt_dir + / "org" + / "languagetool" + / "resource" + / args.lang_code + / "hunspell" + / "spelling.txt" + ], + args.out_dir / "spell" / "spelling.txt", + args.lang_code, + ) + + with open(args.out_dir / "spell" / "map.txt", "w") as f: + for path in args.spell_map_path: + for line in open(path): + if line.startswith("#"): + continue + + assert "#" not in line + f.write(line) if ( args.chunker_token_model is not None diff --git a/nlprule/Cargo.toml b/nlprule/Cargo.toml index 6571798..a0c59d3 100644 --- a/nlprule/Cargo.toml +++ b/nlprule/Cargo.toml @@ -31,6 +31,7 @@ srx = { version = "^0.1.2", features = ["serde"] } lazycell = "1" cfg-if = "1" fnv = "1" +unicode_categories = "0.1" rayon-cond = "0.1" rayon = "1.5" diff --git a/nlprule/configs/de/spellchecker.json b/nlprule/configs/de/spellchecker.json index 52f6a41..0500649 100644 --- a/nlprule/configs/de/spellchecker.json +++ b/nlprule/configs/de/spellchecker.json @@ -3,5 +3,6 @@ "de_AT", "de_DE", "de_CH" - ] + ], + "split_hyphens": true } \ No newline at end of file diff --git a/nlprule/configs/en/rules.json b/nlprule/configs/en/rules.json index a1a27a0..2fd88b3 100644 --- a/nlprule/configs/en/rules.json +++ b/nlprule/configs/en/rules.json @@ -3,5 +3,6 @@ "ignore_ids": [ "GRAMMAR/PRP_MD_NN/2", "TYPOS/VERB_APOSTROPHE_S/3" - ] + ], + "split_hyphens": true } \ No newline at end of file diff --git a/nlprule/configs/en/spellchecker.json b/nlprule/configs/en/spellchecker.json index ab2e94b..d9774ff 100644 --- a/nlprule/configs/en/spellchecker.json +++ b/nlprule/configs/en/spellchecker.json @@ -2,8 +2,7 @@ "variants": [ "en_GB", "en_US", - "en_ZA", - "en_CA", "en_AU" - ] + ], + "split_hyphens": true } \ No newline at end of file diff --git a/nlprule/configs/es/spellchecker.json b/nlprule/configs/es/spellchecker.json index de221b3..51a957d 100644 --- a/nlprule/configs/es/spellchecker.json +++ b/nlprule/configs/es/spellchecker.json @@ -1,3 +1,4 @@ { - "variants": [] + "variants": [], + "split_hyphens": true } \ No newline at end of file diff --git a/nlprule/src/compile/impls.rs b/nlprule/src/compile/impls.rs index 758b330..3368abb 100644 --- a/nlprule/src/compile/impls.rs +++ b/nlprule/src/compile/impls.rs @@ -37,9 +37,12 @@ use crate::{ use super::{parse_structure::BuildInfo, Error}; impl Spell { - pub(in crate::compile) fn from_dumps>( - spell_dir_path: S, + pub(in crate::compile) fn from_dumps, S2: AsRef>( + spell_dir_path: S1, + map_path: S2, + extra_words: &HashSet, lang_options: SpellLangOptions, + tokenizer: &Tokenizer, ) -> io::Result { let mut words: HashMap = DefaultHashMap::new(); let mut max_freq = 0; @@ -50,7 +53,7 @@ impl Spell { .join(variant.as_str()) .with_extension("dump"); - let reader = BufReader::new(File::open(spell_path)?); + let reader = BufReader::new(File::open(&spell_path)?); for line in reader.lines() { match line? 
.trim() @@ -61,7 +64,16 @@ impl Spell { [freq, word] => { // frequency is denoted as letters from A to Z in LanguageTool where A is the least frequent. let freq = freq.chars().next().expect("freq must have one char - would not have been yielded by split_whitespace otherwise.") as usize - 'A' as usize; - let value = words.entry(word.to_string()).or_default(); + let value = words.entry((*word).to_owned()).or_default(); + + if tokenizer.get_token_strs(word).len() > 1 { + warn!( + "phrase '{}' ignored by {} spellchecker.", + word, + variant.as_str() + ); + continue; + } max_freq = cmp::max(max_freq, freq); @@ -71,6 +83,29 @@ impl Spell { _ => continue, } } + + let extra_word_path = spell_path.with_extension("txt"); + let reader = BufReader::new(File::open(&extra_word_path)?); + for line in reader + .lines() + .collect::>>()? + .into_iter() + .chain(extra_words.iter().cloned()) + { + let word = line.trim(); + + if tokenizer.get_token_strs(word).len() > 1 { + warn!( + "phrase '{}' ignored by {} spellchecker.", + word, + variant.as_str() + ); + continue; + } + + let value = words.entry((*word).to_owned()).or_default(); + value.add_variant(i); + } } let mut words: Vec<_> = words .into_iter() @@ -78,10 +113,37 @@ impl Spell { .collect(); words.sort_by(|(a, _), (b, _)| a.cmp(b)); - let map = + let fst = fst::Map::from_iter(words.into_iter()).expect("words are lexicographically sorted."); - Ok(Spell::new(map.into_fst().to_vec(), max_freq, lang_options)) + let mut map = DefaultHashMap::new(); + let reader = BufReader::new(File::open(map_path.as_ref())?); + for line in reader.lines() { + let line = line?; + + let mut parts = line.split('='); + let wrong = parts + .next() + .expect("spell map line must have part before =") + .to_owned(); + let right = parts + .next() + .expect("spell map line must have part after =") + .to_owned(); + + // map lookup happens on token level, so the key has to be exactly one token + assert_eq!(tokenizer.get_token_strs(&wrong).len(), 1); + + map.insert(wrong, right); + assert!(parts.next().is_none()); + } + + Ok(Spell::new( + fst.into_fst().to_vec(), + max_freq, + map, + lang_options, + )) } } @@ -226,6 +288,7 @@ impl Tagger { impl MultiwordTagger { pub(in crate::compile) fn from_dump>( dump: P, + extra_phrases: impl Iterator, info: &BuildInfo, ) -> Result { let reader = BufReader::new(File::open(dump.as_ref())?); @@ -246,7 +309,11 @@ impl MultiwordTagger { .collect::>() .join(" "); let pos = info.tagger().id_tag(tab_split[1]).to_owned_id(); - multiwords.push((word, pos)); + multiwords.push((word, Some(pos))); + } + + for phrase in extra_phrases { + multiwords.push((phrase, None)); } Ok((MultiwordTaggerFields { multiwords }).into()) diff --git a/nlprule/src/compile/mod.rs b/nlprule/src/compile/mod.rs index 123776d..f3f0e0b 100644 --- a/nlprule/src/compile/mod.rs +++ b/nlprule/src/compile/mod.rs @@ -5,7 +5,7 @@ use fs_err as fs; use std::{ hash::{Hash, Hasher}, - io::{self, BufReader, BufWriter}, + io::{self, BufRead, BufReader, BufWriter}, num::ParseIntError, path::{Path, PathBuf}, str::FromStr, @@ -40,6 +40,8 @@ struct BuildFilePaths { srx_path: PathBuf, common_words_path: PathBuf, spell_dir_path: PathBuf, + spell_map_path: PathBuf, + spell_extra_path: PathBuf, } impl BuildFilePaths { @@ -58,6 +60,8 @@ impl BuildFilePaths { srx_path: p.join("segment.srx"), common_words_path: p.join("common.txt"), spell_dir_path: p.join("spell"), + spell_map_path: p.join("spell/map.txt"), + spell_extra_path: p.join("spell/spelling.txt"), } } } @@ -176,6 +180,41 @@ pub fn compile( } else { 
None }; + + info!("Creating tokenizer."); + + let mut tokenizer = Tokenizer::from_xml( + &paths.disambiguation_path, + &mut build_info, + chunker, + None, + srx::SRX::from_str(&fs::read_to_string(&paths.srx_path)?)?.language_rules(lang_code), + tokenizer_lang_options, + )?; + + let mut extra_phrases = DefaultHashSet::new(); + let mut extra_spell_words = DefaultHashSet::new(); + + // comments must already be stripped from this file such that each line contains one word or phrase + let reader = BufReader::new(File::open(&paths.spell_extra_path)?); + for line in reader.lines() { + let line = line?; + let content = line.trim(); + + match tokenizer.get_token_strs(content).len() { + 0 => { + return Err(Error::Unexpected(format!( + "empty lines in {} are not allowed.", + paths.spell_extra_path.display() + ))) + } + // if the content is exactly one token, we just add it to the spellchecker regularly + 1 => extra_spell_words.insert(content.to_owned()), + // if the content is a phrase (i.e multiple tokens) we add it to the multiword tagger, since words found by the multiword tagger are considered correct + _ => extra_phrases.insert(content.to_owned()), + }; + } + let multiword_tagger = if paths.multiword_tag_path.exists() { info!( "{} exists. Building multiword tagger.", @@ -183,26 +222,27 @@ pub fn compile( ); Some(MultiwordTagger::from_dump( paths.multiword_tag_path, + extra_phrases.into_iter(), &build_info, )?) } else { None }; + tokenizer.multiword_tagger = multiword_tagger; + tokenizer.to_writer(&mut tokenizer_dest)?; - let spellchecker = Spell::from_dumps(paths.spell_dir_path, spellchecker_lang_options)?; + info!("Creating spellchecker."); - info!("Creating tokenizer."); - let tokenizer = Tokenizer::from_xml( - &paths.disambiguation_path, - &mut build_info, - chunker, - multiword_tagger, - srx::SRX::from_str(&fs::read_to_string(&paths.srx_path)?)?.language_rules(lang_code), - tokenizer_lang_options, + let spellchecker = Spell::from_dumps( + paths.spell_dir_path, + paths.spell_map_path, + &extra_spell_words, + spellchecker_lang_options, + &tokenizer, )?; - tokenizer.to_writer(&mut tokenizer_dest)?; info!("Creating grammar rules."); + let rules = Rules::from_xml( &paths.grammar_path, &mut build_info, diff --git a/nlprule/src/spellcheck/mod.rs b/nlprule/src/spellcheck/mod.rs index d13ae0b..6d43c6c 100644 --- a/nlprule/src/spellcheck/mod.rs +++ b/nlprule/src/spellcheck/mod.rs @@ -1,16 +1,19 @@ +use fst::{IntoStreamer, Map, MapBuilder, Streamer}; +use serde::{Deserialize, Serialize}; use std::{ cmp::Ordering, collections::{BinaryHeap, HashSet}, ops::{Deref, DerefMut}, }; +use unicode_categories::UnicodeCategories; -use fst::{IntoStreamer, Map, MapBuilder, Streamer}; -use serde::{Deserialize, Serialize}; - -use crate::{types::*, Error}; +use crate::{ + types::*, + utils::{apply_to_first, is_title_case}, + Error, +}; mod levenshtein; - mod spell_int { #[derive(Debug, Clone, Default, Copy)] pub(crate) struct SpellInt(pub(super) u64); @@ -30,7 +33,7 @@ mod spell_int { assert!(freq < FreqType::MAX as usize); // erase previous frequency - self.0 = self.0 & (u64::MAX - FreqType::MAX as u64); + self.0 &= u64::MAX - FreqType::MAX as u64; // set new frequency self.0 |= freq as u64; } @@ -81,7 +84,7 @@ mod spell_int { pub(crate) use spell_int::SpellInt; -#[derive(Debug, Clone, Default, PartialEq)] +#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] pub struct Candidate { score: f32, distance: usize, @@ -159,6 +162,7 @@ impl<'a> Drop for SpellOptionsGuard<'a> { pub(crate) struct 
SpellLangOptions { /// Variants of the language (e. g. "en_US", "en_GB") to consider for spellchecking. pub variants: Vec, + pub split_hyphens: bool, } impl Default for SpellOptions { @@ -188,18 +192,26 @@ impl Variant { pub struct Spell { fst: Vec, max_freq: usize, + map: DefaultHashMap, lang_options: SpellLangOptions, options: SpellOptions, + // fields below are computed depending on the selected variant used_variant: Option, used_fst: Vec, - used_set: HashSet, + used_set: DefaultHashSet, } impl Spell { - pub(crate) fn new(fst: Vec, max_freq: usize, lang_options: SpellLangOptions) -> Self { + pub(crate) fn new( + fst: Vec, + max_freq: usize, + map: DefaultHashMap, + lang_options: SpellLangOptions, + ) -> Self { let mut spell = Spell { fst, max_freq, + map, lang_options, options: SpellOptions::default(), ..Default::default() @@ -271,9 +283,31 @@ impl Spell { self.used_set = set; } - pub fn search(&self, word: &str) -> Option> { - if self.used_variant.is_none() || word.is_empty() || self.used_set.contains(word) { - return None; + fn check_flat(&self, word: &str) -> bool { + self.used_variant.is_none() + || word.is_empty() + || word.chars().all(|x| x.is_punctuation() || x.is_numeric()) + || self.used_set.contains(word) + || (is_title_case(word) + && self.check(&apply_to_first(word, |x| x.to_lowercase().collect()))) + } + + pub fn check(&self, word: &str) -> bool { + self.check_flat(word) + || (self.lang_options.split_hyphens + && word + .split(&['-', '\u{2010}', '\u{2011}'][..]) + .all(|x| self.check_flat(x))) + } + + pub fn search(&self, word: &str) -> Vec { + if let Some(candidate) = self.map.get(word) { + return vec![Candidate { + score: 0., // numerical values here do not matter since there is always exactly one candidate - ranking is irrelevant + freq: 0, + distance: 0, + term: candidate.to_owned(), + }]; } let used_fst = Map::new(self.used_fst.as_slice()).expect("used fst must be valid."); @@ -301,22 +335,27 @@ impl Spell { } } - Some(out.into_sorted_vec()) + out.into_sorted_vec() } pub fn suggest(&self, tokens: &[Token]) -> Vec { let mut suggestions = Vec::new(); for token in tokens { - if let Some(candidates) = self.search(token.word.text.as_ref()) { - suggestions.push(Suggestion { - source: "SPELLCHECK/SINGLE".into(), - message: "Possibly misspelled word.".into(), - start: token.char_span.0, - end: token.char_span.1, - replacements: candidates.into_iter().map(|x| x.term.to_owned()).collect(), - }) + let text = token.word.text.as_ref(); + + if token.ignore_spelling || self.check(text) { + continue; } + + let candidates = self.search(text); + suggestions.push(Suggestion { + source: "SPELLCHECK/SINGLE".into(), + message: "Possibly misspelled word.".into(), + start: token.char_span.0, + end: token.char_span.1, + replacements: candidates.into_iter().map(|x| x.term).collect(), + }); } suggestions diff --git a/nlprule/src/tokenizer.rs b/nlprule/src/tokenizer.rs index 6fdcfc9..8a3d2c5 100644 --- a/nlprule/src/tokenizer.rs +++ b/nlprule/src/tokenizer.rs @@ -201,7 +201,7 @@ impl Tokenizer { self.disambiguate_up_to_id(tokens, None) } - fn get_token_strs<'t>(&self, text: &'t str) -> Vec<&'t str> { + pub(crate) fn get_token_strs<'t>(&self, text: &'t str) -> Vec<&'t str> { let mut tokens = Vec::new(); let split_char = |c: char| c.is_whitespace() || crate::utils::splitting_chars().contains(c); @@ -279,6 +279,7 @@ impl Tokenizer { char_span: (char_start, current_char), byte_span: (byte_start, byte_start + x.len()), is_sentence_end, + ignore_spelling: false, has_space_before: 
sentence[..byte_start].ends_with(char::is_whitespace), chunks: Vec::new(), multiword_data: None, diff --git a/nlprule/src/tokenizer/multiword.rs b/nlprule/src/tokenizer/multiword.rs index e3e9bfc..e53760e 100644 --- a/nlprule/src/tokenizer/multiword.rs +++ b/nlprule/src/tokenizer/multiword.rs @@ -8,7 +8,7 @@ use super::tag::Tagger; #[derive(Serialize, Deserialize)] pub(crate) struct MultiwordTaggerFields { - pub(crate) multiwords: Vec<(String, owned::PosId)>, + pub(crate) multiwords: Vec<(String, Option)>, } impl From for MultiwordTagger { @@ -35,7 +35,7 @@ impl From for MultiwordTagger { pub struct MultiwordTagger { #[serde(skip)] matcher: AhoCorasick, - multiwords: Vec<(String, owned::PosId)>, + multiwords: Vec<(String, Option)>, } impl MultiwordTagger { @@ -66,10 +66,14 @@ impl MultiwordTagger { let (word, pos) = &self.multiwords[m.pattern()]; // end index is inclusive for token in tokens[*start..(*end + 1)].iter_mut() { - token.multiword_data = Some(WordData::new( - tagger.id_word(word.as_str().into()), - pos.as_ref_id(), - )); + if let Some(pos) = pos.as_ref() { + token.multiword_data = Some(WordData::new( + tagger.id_word(word.as_str().into()), + pos.as_ref_id(), + )); + } + + token.ignore_spelling = true; } } } diff --git a/nlprule/src/types.rs b/nlprule/src/types.rs index 9d11260..68104bb 100644 --- a/nlprule/src/types.rs +++ b/nlprule/src/types.rs @@ -198,7 +198,7 @@ pub struct IncompleteToken<'t> { pub byte_span: (usize, usize), /// Char start (inclusive) and end (exclusive) of this token in the sentence. pub char_span: (usize, usize), - /// Whether this token is the last token in the sentence- + /// Whether this token is the last token in the sentence. pub is_sentence_end: bool, /// Whether this token has one or more whitespace characters before. pub has_space_before: bool, @@ -206,6 +206,8 @@ pub struct IncompleteToken<'t> { pub chunks: Vec, /// A *multiword* lemma and part-of-speech tag. Set if the token was found in a list of phrases. pub multiword_data: Option>, + /// Whether to ignore spelling for this token. + pub ignore_spelling: bool, /// The sentence this token is in. pub sentence: &'t str, /// The tagger used for lookup related to this token. 
@@ -227,6 +229,7 @@ pub struct Token<'t> { pub word: Word<'t>, pub char_span: (usize, usize), pub byte_span: (usize, usize), + pub ignore_spelling: bool, pub has_space_before: bool, pub chunks: Vec, pub sentence: &'t str, @@ -249,6 +252,7 @@ impl<'t> Token<'t> { ), char_span: (0, 0), byte_span: (0, 0), + ignore_spelling: true, has_space_before: false, chunks: Vec::new(), sentence, @@ -298,6 +302,7 @@ impl<'t> From> for Token<'t> { word, byte_span: data.byte_span, char_span: data.char_span, + ignore_spelling: data.ignore_spelling, has_space_before: data.has_space_before, chunks: data.chunks, sentence: data.sentence, diff --git a/nlprule/tests/tests.rs b/nlprule/tests/tests.rs index f71789f..0990939 100644 --- a/nlprule/tests/tests.rs +++ b/nlprule/tests/tests.rs @@ -57,5 +57,7 @@ fn spellchecker_works() -> Result<(), Error> { let mut rules = Rules::new(RULES_PATH, TOKENIZER.clone()).unwrap(); rules.spell_mut().options_mut().variant = Some(rules.spell().variant("en_GB")?); + println!("{:?}", rules.suggest("Unicode punctuation: —")); + Ok(()) } diff --git a/python/src/lib.rs b/python/src/lib.rs index d3321bc..40fceed 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -714,13 +714,18 @@ impl PySpell { Ok(()) } - fn search(&self, word: &str) -> Option> { - self.rules.read().spell().search(word).map(|candidates| { - candidates - .into_iter() - .map(|candidate| PyCandidate { candidate }) - .collect::>() - }) + fn check(&self, word: &str) -> bool { + self.rules.read().spell().check(word) + } + + fn search(&self, word: &str) -> Vec { + self.rules + .read() + .spell() + .search(word) + .into_iter() + .map(|candidate| PyCandidate { candidate }) + .collect() } } From 16781b57c3093de888fa967d86e3e7c58d80cc1b Mon Sep 17 00:00:00 2001 From: Benjamin Minixhofer Date: Mon, 15 Mar 2021 14:42:50 +0100 Subject: [PATCH 09/16] proper multiword spellchecker support, ignore action=ignore_spelling tokens --- build/make_build_dir.py | 2 +- nlprule/src/bin/run.rs | 3 +- nlprule/src/compile/impls.rs | 109 ++++++++++++++++++------- nlprule/src/compile/mod.rs | 29 +------ nlprule/src/compile/parse_structure.rs | 2 +- nlprule/src/rule/disambiguation.rs | 8 ++ nlprule/src/rule/grammar.rs | 17 ++-- nlprule/src/spellcheck/mod.rs | 81 +++++++++++++++--- nlprule/src/tokenizer/multiword.rs | 15 ++-- python/src/lib.rs | 14 ---- 10 files changed, 175 insertions(+), 105 deletions(-) diff --git a/build/make_build_dir.py b/build/make_build_dir.py index 849adac..80f9b7d 100644 --- a/build/make_build_dir.py +++ b/build/make_build_dir.py @@ -253,7 +253,7 @@ def proc_spelling_text(in_paths, out_path, lang_code): ) with open(args.out_dir / "spell" / "map.txt", "w") as f: - for path in args.spell_map_path: + for path in args.spell_map_path or []: for line in open(path): if line.startswith("#"): continue diff --git a/nlprule/src/bin/run.rs b/nlprule/src/bin/run.rs index 6bc46fb..9825769 100644 --- a/nlprule/src/bin/run.rs +++ b/nlprule/src/bin/run.rs @@ -21,7 +21,8 @@ fn main() { let opts = Opts::parse(); let tokenizer = Arc::new(Tokenizer::new(opts.tokenizer).unwrap()); - let rules = Rules::new(opts.rules, tokenizer.clone()).unwrap(); + let mut rules = Rules::new(opts.rules, tokenizer.clone()).unwrap(); + rules.spell_mut().options_mut().variant = Some(rules.spell().variant("en_GB").unwrap()); let tokens = tokenizer.pipe(&opts.text); diff --git a/nlprule/src/compile/impls.rs b/nlprule/src/compile/impls.rs index 3368abb..1f1f7e9 100644 --- a/nlprule/src/compile/impls.rs +++ b/nlprule/src/compile/impls.rs @@ -19,6 +19,7 @@ use 
crate::{ composition::{GraphId, Matcher, PosMatcher, TextMatcher}, Engine, }, + grammar::PosReplacer, id::Category, DisambiguationRule, MatchGraph, Rule, }, @@ -37,14 +38,17 @@ use crate::{ use super::{parse_structure::BuildInfo, Error}; impl Spell { - pub(in crate::compile) fn from_dumps, S2: AsRef>( - spell_dir_path: S1, - map_path: S2, - extra_words: &HashSet, + pub(in crate::compile) fn from_dumps( + spell_dir_path: impl AsRef, + map_path: impl AsRef, + global_word_path: impl AsRef, + build_info: &mut BuildInfo, lang_options: SpellLangOptions, tokenizer: &Tokenizer, ) -> io::Result { - let mut words: HashMap = DefaultHashMap::new(); + let mut words: DefaultHashMap = DefaultHashMap::new(); + let mut multiwords: DefaultHashMap, SpellInt)>> = + DefaultHashMap::new(); let mut max_freq = 0; for (i, variant) in lang_options.variants.iter().enumerate() { @@ -66,12 +70,24 @@ impl Spell { let freq = freq.chars().next().expect("freq must have one char - would not have been yielded by split_whitespace otherwise.") as usize - 'A' as usize; let value = words.entry((*word).to_owned()).or_default(); - if tokenizer.get_token_strs(word).len() > 1 { - warn!( - "phrase '{}' ignored by {} spellchecker.", - word, - variant.as_str() - ); + let tokens = tokenizer.get_token_strs(word); + + if tokens.len() > 1 { + let mut int = SpellInt::default(); + int.add_variant(i); + + multiwords + .entry(tokens[0].to_owned()) + .or_insert_with(Vec::new) + .push(( + tokens[1..] + .iter() + .filter(|x| !x.trim().is_empty()) + .map(|x| (*x).to_owned()) + .collect(), + int, + )); + continue; } @@ -84,27 +100,62 @@ impl Spell { } } + let global_word_reader = BufReader::new(File::open(global_word_path.as_ref())?); + let extra_word_path = spell_path.with_extension("txt"); let reader = BufReader::new(File::open(&extra_word_path)?); - for line in reader - .lines() - .collect::>>()? - .into_iter() - .chain(extra_words.iter().cloned()) - { + for line in reader.lines().chain(global_word_reader.lines()) { + let line = line?; let word = line.trim(); - if tokenizer.get_token_strs(word).len() > 1 { - warn!( - "phrase '{}' ignored by {} spellchecker.", - word, - variant.as_str() - ); + let tokens = tokenizer.get_token_strs(word); + + if tokens.len() > 1 { + let mut int = SpellInt::default(); + int.add_variant(i); + multiwords + .entry(tokens[0].to_owned()) + .or_insert_with(Vec::new) + .push(( + tokens[1..] 
+ .iter() + .filter(|x| !x.trim().is_empty()) + .map(|x| (*x).to_owned()) + .collect(), + int, + )); + continue; } - let value = words.entry((*word).to_owned()).or_default(); - value.add_variant(i); + if word.contains('_') { + assert!(!word.contains('\\')); // escaped underlines are not supported + let mut parts = word.split('_'); + + let prefix = parts.next().unwrap(); + let suffix = parts.next().unwrap(); + + // this will presumably always be covered by the extra suffixes, but add it just to make sure + words + .entry(format!("{}{}", prefix, suffix)) + .or_default() + .add_variant(i); + + let replacer = PosReplacer { + matcher: PosMatcher::new( + Matcher::new_regex(Regex::new("^VER:.*".into()), false, true), + build_info, + ), + }; + + for new_suffix in replacer.apply(suffix, tokenizer) { + let new_word = format!("{}{}", prefix, new_suffix); + + words.entry(new_word).or_default().add_variant(i); + } + } + + words.entry((*word).to_owned()).or_default().add_variant(i); } } let mut words: Vec<_> = words @@ -140,6 +191,7 @@ impl Spell { Ok(Spell::new( fst.into_fst().to_vec(), + multiwords, max_freq, map, lang_options, @@ -288,7 +340,6 @@ impl Tagger { impl MultiwordTagger { pub(in crate::compile) fn from_dump>( dump: P, - extra_phrases: impl Iterator, info: &BuildInfo, ) -> Result { let reader = BufReader::new(File::open(dump.as_ref())?); @@ -309,11 +360,7 @@ impl MultiwordTagger { .collect::>() .join(" "); let pos = info.tagger().id_tag(tab_split[1]).to_owned_id(); - multiwords.push((word, Some(pos))); - } - - for phrase in extra_phrases { - multiwords.push((phrase, None)); + multiwords.push((word, pos)); } Ok((MultiwordTaggerFields { multiwords }).into()) diff --git a/nlprule/src/compile/mod.rs b/nlprule/src/compile/mod.rs index f3f0e0b..3f121a8 100644 --- a/nlprule/src/compile/mod.rs +++ b/nlprule/src/compile/mod.rs @@ -5,7 +5,7 @@ use fs_err as fs; use std::{ hash::{Hash, Hasher}, - io::{self, BufRead, BufReader, BufWriter}, + io::{self, BufReader, BufWriter}, num::ParseIntError, path::{Path, PathBuf}, str::FromStr, @@ -192,29 +192,6 @@ pub fn compile( tokenizer_lang_options, )?; - let mut extra_phrases = DefaultHashSet::new(); - let mut extra_spell_words = DefaultHashSet::new(); - - // comments must already be stripped from this file such that each line contains one word or phrase - let reader = BufReader::new(File::open(&paths.spell_extra_path)?); - for line in reader.lines() { - let line = line?; - let content = line.trim(); - - match tokenizer.get_token_strs(content).len() { - 0 => { - return Err(Error::Unexpected(format!( - "empty lines in {} are not allowed.", - paths.spell_extra_path.display() - ))) - } - // if the content is exactly one token, we just add it to the spellchecker regularly - 1 => extra_spell_words.insert(content.to_owned()), - // if the content is a phrase (i.e multiple tokens) we add it to the multiword tagger, since words found by the multiword tagger are considered correct - _ => extra_phrases.insert(content.to_owned()), - }; - } - let multiword_tagger = if paths.multiword_tag_path.exists() { info!( "{} exists. Building multiword tagger.", @@ -222,7 +199,6 @@ pub fn compile( ); Some(MultiwordTagger::from_dump( paths.multiword_tag_path, - extra_phrases.into_iter(), &build_info, )?) 
} else { @@ -236,7 +212,8 @@ pub fn compile( let spellchecker = Spell::from_dumps( paths.spell_dir_path, paths.spell_map_path, - &extra_spell_words, + paths.spell_extra_path, + &mut build_info, spellchecker_lang_options, &tokenizer, )?; diff --git a/nlprule/src/compile/parse_structure.rs b/nlprule/src/compile/parse_structure.rs index 5fe417c..44aeb4d 100644 --- a/nlprule/src/compile/parse_structure.rs +++ b/nlprule/src/compile/parse_structure.rs @@ -1054,7 +1054,7 @@ impl DisambiguationRule { }) .collect(), )), - Some("ignore_spelling") => Ok(Disambiguation::Nop), // ignore_spelling can be ignored since we dont check spelling + Some("ignore_spelling") => Ok(Disambiguation::IgnoreSpelling), Some("immunize") => Ok(Disambiguation::Nop), // immunize can probably not be ignored Some("filterall") => { let mut disambig = Vec::new(); diff --git a/nlprule/src/rule/disambiguation.rs b/nlprule/src/rule/disambiguation.rs index 2fc80b4..6f067df 100644 --- a/nlprule/src/rule/disambiguation.rs +++ b/nlprule/src/rule/disambiguation.rs @@ -44,6 +44,7 @@ pub enum Disambiguation { Replace(Vec), Filter(Vec>>), Unify(Vec>, Vec>, Vec), + IgnoreSpelling, Nop, } @@ -190,6 +191,13 @@ impl Disambiguation { } } } + Disambiguation::IgnoreSpelling => { + for group in groups { + for token in group { + token.ignore_spelling = true; + } + } + } Disambiguation::Nop => {} } } diff --git a/nlprule/src/rule/grammar.rs b/nlprule/src/rule/grammar.rs index a7cc170..3b10a07 100644 --- a/nlprule/src/rule/grammar.rs +++ b/nlprule/src/rule/grammar.rs @@ -64,7 +64,7 @@ pub struct PosReplacer { } impl PosReplacer { - fn apply(&self, text: &str, tokenizer: &Tokenizer) -> Option { + pub fn apply(&self, text: &str, tokenizer: &Tokenizer) -> Vec { let mut candidates: Vec<_> = tokenizer .tagger() .get_tags(text) @@ -75,13 +75,13 @@ impl PosReplacer { .get_group_members(&x.lemma.as_ref().to_string()); let mut data = Vec::new(); for word in group_words { - if let Some(i) = tokenizer + if let Some(_i) = tokenizer .tagger() .get_tags(word) .iter() .position(|x| self.matcher.is_match(&x.pos)) { - data.push((word.to_string(), i)); + data.push(word.to_string()); } } data @@ -89,12 +89,9 @@ impl PosReplacer { .rev() .flatten() .collect(); - candidates.sort_by(|(_, a), (_, b)| a.cmp(b)); - if candidates.is_empty() { - None - } else { - Some(candidates.remove(0).0) - } + candidates.sort_unstable(); + candidates.dedup(); + candidates } } @@ -111,7 +108,7 @@ impl Match { let text = graph.by_id(self.id).text(graph.tokens()[0].sentence); let mut text = if let Some(replacer) = &self.pos_replacer { - replacer.apply(text, tokenizer)? + replacer.apply(text, tokenizer).into_iter().next()? 
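+            // `PosReplacer::apply` now returns every matching candidate, sorted and
+            // deduplicated, so `next()` here picks the lexicographically first one;
+            // the trailing `?` still returns early when no candidate exists.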
} else { text.to_string() }; diff --git a/nlprule/src/spellcheck/mod.rs b/nlprule/src/spellcheck/mod.rs index 6d43c6c..194d06a 100644 --- a/nlprule/src/spellcheck/mod.rs +++ b/nlprule/src/spellcheck/mod.rs @@ -15,7 +15,11 @@ use crate::{ mod levenshtein; mod spell_int { - #[derive(Debug, Clone, Default, Copy)] + use std::cmp; + + use serde::{Deserialize, Serialize}; + + #[derive(Debug, Clone, Default, Copy, Serialize, Deserialize)] pub(crate) struct SpellInt(pub(super) u64); type FreqType = u8; @@ -32,10 +36,12 @@ mod spell_int { pub fn update_freq(&mut self, freq: usize) { assert!(freq < FreqType::MAX as usize); + let prev_freq = self.freq(); // erase previous frequency self.0 &= u64::MAX - FreqType::MAX as u64; - // set new frequency - self.0 |= freq as u64; + // set new frequency, strictly speaking we would have to store a frequency for each variant + // but that would need significantly more space, so we just store the highest frequency + self.0 |= cmp::max(prev_freq, freq) as u64; } pub fn add_variant(&mut self, index: usize) { @@ -191,25 +197,33 @@ impl Variant { #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct Spell { fst: Vec, + multiwords: DefaultHashMap, SpellInt)>>, max_freq: usize, map: DefaultHashMap, lang_options: SpellLangOptions, options: SpellOptions, // fields below are computed depending on the selected variant + #[serde(skip)] used_variant: Option, + #[serde(skip)] used_fst: Vec, + #[serde(skip)] + used_multiwords: DefaultHashMap>>, + #[serde(skip)] used_set: DefaultHashSet, } impl Spell { pub(crate) fn new( fst: Vec, + multiwords: DefaultHashMap, SpellInt)>>, max_freq: usize, map: DefaultHashMap, lang_options: SpellLangOptions, ) -> Self { let mut spell = Spell { fst, + multiwords, max_freq, map, lang_options, @@ -251,6 +265,7 @@ impl Spell { } else { self.used_variant = None; self.used_fst = Vec::new(); + self.used_multiwords = DefaultHashMap::new(); self.used_set = DefaultHashSet::new(); return; }; @@ -280,27 +295,65 @@ impl Spell { self.used_fst = used_fst_builder .into_inner() .expect("subset of valid fst must be valid."); + self.used_multiwords = self + .multiwords + .iter() + .map(|(key, value)| { + let value = value + .iter() + .filter_map(|(continuations, int)| { + if int.contains_variant(variant_index) { + Some(continuations) + } else { + None + } + }) + .cloned() + .collect(); + (key.to_owned(), value) + }) + .collect(); self.used_set = set; } - fn check_flat(&self, word: &str) -> bool { + fn check_word(&self, word: &str) -> bool { self.used_variant.is_none() || word.is_empty() - || word.chars().all(|x| x.is_punctuation() || x.is_numeric()) + || word + .chars() + .all(|x| x.is_symbol() || x.is_punctuation() || x.is_numeric()) || self.used_set.contains(word) || (is_title_case(word) - && self.check(&apply_to_first(word, |x| x.to_lowercase().collect()))) + && self.check_word(&apply_to_first(word, |x| x.to_lowercase().collect()))) } - pub fn check(&self, word: &str) -> bool { - self.check_flat(word) + fn check(&self, tokens: &[Token], correct_mask: &mut [bool]) { + let word = tokens[0].word.text.as_ref(); + + let word_is_correct = self.check_word(word) || (self.lang_options.split_hyphens && word .split(&['-', '\u{2010}', '\u{2011}'][..]) - .all(|x| self.check_flat(x))) + .all(|x| self.check_word(x))); + + correct_mask[0] = word_is_correct; + + if let Some(continuations) = self.used_multiwords.get(word) { + if let Some(matching_cont) = continuations.iter().find(|cont| { + (tokens.len() - 1) >= cont.len() + && cont + .iter() + .enumerate() + 
.all(|(i, x)| tokens[i + 1].word.text.as_ref() == x) + }) { + correct_mask[..1 + matching_cont.len()] + .iter_mut() + .for_each(|x| *x = true); + } + } } - pub fn search(&self, word: &str) -> Vec { + fn search(&self, word: &str) -> Vec { if let Some(candidate) = self.map.get(word) { return vec![Candidate { score: 0., // numerical values here do not matter since there is always exactly one candidate - ranking is irrelevant @@ -340,11 +393,15 @@ impl Spell { pub fn suggest(&self, tokens: &[Token]) -> Vec { let mut suggestions = Vec::new(); + let mut correct_mask = vec![false; tokens.len()]; - for token in tokens { + for (i, token) in tokens.iter().enumerate() { let text = token.word.text.as_ref(); - if token.ignore_spelling || self.check(text) { + if !correct_mask[i] { + self.check(&tokens[i..], &mut correct_mask[i..]); + } + if correct_mask[i] || token.ignore_spelling { continue; } diff --git a/nlprule/src/tokenizer/multiword.rs b/nlprule/src/tokenizer/multiword.rs index e53760e..2694524 100644 --- a/nlprule/src/tokenizer/multiword.rs +++ b/nlprule/src/tokenizer/multiword.rs @@ -8,7 +8,7 @@ use super::tag::Tagger; #[derive(Serialize, Deserialize)] pub(crate) struct MultiwordTaggerFields { - pub(crate) multiwords: Vec<(String, Option)>, + pub(crate) multiwords: Vec<(String, owned::PosId)>, } impl From for MultiwordTagger { @@ -35,7 +35,7 @@ impl From for MultiwordTagger { pub struct MultiwordTagger { #[serde(skip)] matcher: AhoCorasick, - multiwords: Vec<(String, Option)>, + multiwords: Vec<(String, owned::PosId)>, } impl MultiwordTagger { @@ -66,13 +66,10 @@ impl MultiwordTagger { let (word, pos) = &self.multiwords[m.pattern()]; // end index is inclusive for token in tokens[*start..(*end + 1)].iter_mut() { - if let Some(pos) = pos.as_ref() { - token.multiword_data = Some(WordData::new( - tagger.id_word(word.as_str().into()), - pos.as_ref_id(), - )); - } - + token.multiword_data = Some(WordData::new( + tagger.id_word(word.as_str().into()), + pos.as_ref_id(), + )); token.ignore_spelling = true; } } diff --git a/python/src/lib.rs b/python/src/lib.rs index 40fceed..c88c019 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -713,20 +713,6 @@ impl PySpell { *guard.spell_mut().options_mut() = depythonize(options.to_object(py).as_ref(py))?; Ok(()) } - - fn check(&self, word: &str) -> bool { - self.rules.read().spell().check(word) - } - - fn search(&self, word: &str) -> Vec { - self.rules - .read() - .spell() - .search(word) - .into_iter() - .map(|candidate| PyCandidate { candidate }) - .collect() - } } /// The grammatical rules. 
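Throughout these patches, the per-word value stored in the spell fst is a `SpellInt`: a single `u64` that packs the word's frequency class (the `FreqType = u8` field masked out in `update_freq` above) together with one membership flag per language variant, so `contains_variant` can filter the fst down to the selected variant without a second lookup table. A minimal standalone sketch of that packing follows; the 8-bit frequency field mirrors the code above, but the exact bit positions of the variant flags and all names here are illustrative assumptions, not the nlprule implementation:

    // Illustrative re-implementation of the frequency + variant-flag packing.
    // Assumption: low 8 bits = frequency, bits 8 and up = one flag per variant.
    #[derive(Debug, Default, Clone, Copy)]
    struct PackedWord(u64);

    impl PackedWord {
        const FREQ_MASK: u64 = u8::MAX as u64;

        fn update_freq(&mut self, freq: u8) {
            self.0 &= !Self::FREQ_MASK; // erase the previous frequency
            self.0 |= freq as u64; // store the new one in the low 8 bits
        }

        fn add_variant(&mut self, index: usize) {
            assert!(index < 56); // 64 bits minus the 8-bit frequency field
            self.0 |= 1u64 << (8 + index);
        }

        fn freq(&self) -> u8 {
            (self.0 & Self::FREQ_MASK) as u8
        }

        fn contains_variant(&self, index: usize) -> bool {
            self.0 & (1u64 << (8 + index)) != 0
        }
    }

    fn main() {
        let mut word = PackedWord::default();
        word.update_freq(17);
        word.add_variant(1); // e.g. the second variant of a language
        assert_eq!(word.freq(), 17);
        assert!(word.contains_variant(1) && !word.contains_variant(0));
    }
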
From 79ac9266ad5058119799f377f59c23f196975bbd Mon Sep 17 00:00:00 2001 From: Benjamin Minixhofer Date: Mon, 15 Mar 2021 16:45:38 +0100 Subject: [PATCH 10/16] add whitelist support, add variant_checker struct for cleaner separation --- nlprule/src/spellcheck/mod.rs | 231 ++++++++++++++++++++-------------- 1 file changed, 139 insertions(+), 92 deletions(-) diff --git a/nlprule/src/spellcheck/mod.rs b/nlprule/src/spellcheck/mod.rs index 194d06a..2fd24ab 100644 --- a/nlprule/src/spellcheck/mod.rs +++ b/nlprule/src/spellcheck/mod.rs @@ -194,6 +194,99 @@ impl Variant { } } +#[derive(Debug, Clone)] +struct VariantChecker { + variant: Variant, + fst: Vec, + max_freq: usize, + multiwords: DefaultHashMap>>, + set: DefaultHashSet, + map: DefaultHashMap, + lang_options: SpellLangOptions, + options: SpellOptions, +} + +impl VariantChecker { + fn check_word(&self, word: &str, recurse: bool) -> bool { + word.is_empty() + || word + .chars() + .all(|x| x.is_symbol() || x.is_punctuation() || x.is_numeric()) + || self.set.contains(word) + || (recurse + && is_title_case(word) + && self.check_word(&apply_to_first(word, |x| x.to_lowercase().collect()), false)) + } + + fn check(&self, tokens: &[Token], correct_mask: &mut [bool]) { + let word = tokens[0].word.text.as_ref(); + let mut word_is_correct = self.check_word(word, true); + + if !word_is_correct && self.lang_options.split_hyphens { + let hyphens = &['-', '\u{2010}', '\u{2011}'][..]; + + if word.contains(hyphens) && word.split(hyphens).all(|x| self.check_word(x, true)) { + word_is_correct = true; + } + } + + correct_mask[0] = word_is_correct; + + if let Some(continuations) = self.multiwords.get(word) { + if let Some(matching_cont) = continuations.iter().find(|cont| { + // important: an empty continuation matches! so single words can also validly be part of `multiwords` + (tokens.len() - 1) >= cont.len() + && cont + .iter() + .enumerate() + .all(|(i, x)| tokens[i + 1].word.text.as_ref() == x) + }) { + correct_mask[..1 + matching_cont.len()] + .iter_mut() + .for_each(|x| *x = true); + } + } + } + + fn search(&self, word: &str) -> Vec { + if let Some(candidate) = self.map.get(word) { + return vec![Candidate { + score: 0., // numerical values here do not matter since there is always exactly one candidate - ranking is irrelevant + freq: 0, + distance: 0, + term: candidate.to_owned(), + }]; + } + + let used_fst = Map::new(self.fst.as_slice()).expect("used fst must be valid."); + let query = levenshtein::Levenshtein::new(word, self.options.max_distance, 2); + + let mut out = BinaryHeap::with_capacity(self.options.top_n); + + let mut stream = used_fst.search_with_state(query).into_stream(); + while let Some((k, v, s)) = stream.next() { + let state = s.expect("matching levenshtein state is always `Some`."); + assert!(state.dist() > 0); + + let id = SpellInt(v); + + let term = String::from_utf8(k.to_vec()).expect("fst keys must be valid utf-8."); + out.push(Candidate { + distance: state.dist(), + freq: id.freq(), + term, + score: (self.options.max_distance - state.dist()) as f32 + + id.freq() as f32 / self.max_freq as f32 * self.options.freq_weight, + }); + if out.len() > self.options.top_n { + out.pop(); + } + } + + out.into_sorted_vec() + } +} + #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct Spell { fst: Vec, @@ -202,15 +295,9 @@ pub struct Spell { map: DefaultHashMap, lang_options: SpellLangOptions, options: SpellOptions, - // fields below are computed depending on the selected variant - #[serde(skip)] - used_variant: Option, - #[serde(skip)] 
- used_fst: Vec, - #[serde(skip)] - used_multiwords: DefaultHashMap>>, + // the `variant_checker` is computed based on the selected variant #[serde(skip)] - used_set: DefaultHashSet, + variant_checker: Option, } impl Spell { @@ -256,17 +343,14 @@ impl Spell { } pub(crate) fn ingest_options(&mut self) { - if self.used_variant == self.options.variant { + if self.variant_checker.as_ref().map(|x| &x.variant) == self.options.variant.as_ref() { return; } let variant = if let Some(variant) = self.options.variant.as_ref() { - variant + variant.clone() } else { - self.used_variant = None; - self.used_fst = Vec::new(); - self.used_multiwords = DefaultHashMap::new(); - self.used_set = DefaultHashSet::new(); + self.variant_checker = None; return; }; @@ -279,7 +363,7 @@ impl Spell { let variant_index = self .variants() .iter() - .position(|x| x == variant) + .position(|x| *x == variant) .expect("only valid variants are created."); while let Some((k, v)) = stream.next() { @@ -291,11 +375,10 @@ impl Spell { } } - self.used_variant = Some(variant.clone()); - self.used_fst = used_fst_builder + let fst = used_fst_builder .into_inner() .expect("subset of valid fst must be valid."); - self.used_multiwords = self + let mut multiwords: DefaultHashMap<_, _> = self .multiwords .iter() .map(|(key, value)| { @@ -313,85 +396,49 @@ impl Spell { (key.to_owned(), value) }) .collect(); - self.used_set = set; - } - - fn check_word(&self, word: &str) -> bool { - self.used_variant.is_none() - || word.is_empty() - || word - .chars() - .all(|x| x.is_symbol() || x.is_punctuation() || x.is_numeric()) - || self.used_set.contains(word) - || (is_title_case(word) - && self.check_word(&apply_to_first(word, |x| x.to_lowercase().collect()))) - } - - fn check(&self, tokens: &[Token], correct_mask: &mut [bool]) { - let word = tokens[0].word.text.as_ref(); - - let word_is_correct = self.check_word(word) - || (self.lang_options.split_hyphens - && word - .split(&['-', '\u{2010}', '\u{2011}'][..]) - .all(|x| self.check_word(x))); - - correct_mask[0] = word_is_correct; - - if let Some(continuations) = self.used_multiwords.get(word) { - if let Some(matching_cont) = continuations.iter().find(|cont| { - (tokens.len() - 1) >= cont.len() - && cont - .iter() - .enumerate() - .all(|(i, x)| tokens[i + 1].word.text.as_ref() == x) - }) { - correct_mask[..1 + matching_cont.len()] - .iter_mut() - .for_each(|x| *x = true); - } - } - } - - fn search(&self, word: &str) -> Vec { - if let Some(candidate) = self.map.get(word) { - return vec![Candidate { - score: 0., // numerical values here do not matter since there is always exactly one candidate - ranking is irrelevant - freq: 0, - distance: 0, - term: candidate.to_owned(), - }]; - } - - let used_fst = Map::new(self.used_fst.as_slice()).expect("used fst must be valid."); - let query = levenshtein::Levenshtein::new(word, self.options.max_distance, 2); - - let mut out = BinaryHeap::with_capacity(self.options.top_n); - let mut stream = used_fst.search_with_state(query).into_stream(); - while let Some((k, v, s)) = stream.next() { - let state = s.expect("matching levenshtein state is always `Some`."); - assert!(state.dist() > 0); - - let id = SpellInt(v); + for phrase in self + .options + .whitelist + .iter() + .map(|x| x.as_str()) + // for some important words we have to manually make sure they are ignored :) + .chain(vec!["nlprule", "Minixhofer"]) + { + let mut parts = phrase.trim().split_whitespace(); + + let first = if let Some(first) = parts.next() { + first + } else { + // silently ignore empty words + 
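The whitelist handling above keys every phrase by its first token; a single word ends up with an empty continuation, which `check` treats as an immediate match. A self-contained sketch of that folding, using a plain `HashMap` instead of nlprule's internal types:

```rust
use std::collections::HashMap;

// Illustrative sketch: fold whitelist phrases into a map keyed by their first token,
// with the remaining tokens stored as the continuation.
fn ingest_whitelist<'a>(phrases: impl Iterator<Item = &'a str>) -> HashMap<String, Vec<Vec<String>>> {
    let mut multiwords: HashMap<String, Vec<Vec<String>>> = HashMap::new();

    for phrase in phrases {
        let mut parts = phrase.trim().split_whitespace();
        let first = match parts.next() {
            Some(first) => first,
            None => continue, // silently ignore empty entries
        };

        multiwords
            .entry(first.to_owned())
            .or_insert_with(Vec::new)
            .push(parts.map(|x| x.to_owned()).collect());
    }

    multiwords
}

fn main() {
    let map = ingest_whitelist(["nlprule", "Optimal String Alignment"].into_iter());
    // a single word gets an empty continuation, which matches on its own
    assert_eq!(map["nlprule"], vec![Vec::<String>::new()]);
    // a phrase is keyed by its first token, the rest becomes the continuation
    assert_eq!(
        map["Optimal"],
        vec![vec!["String".to_string(), "Alignment".to_string()]]
    );
}
```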
continue; + }; - let term = String::from_utf8(k.to_vec()).expect("fst keys must be valid utf-8."); - out.push(Candidate { - distance: state.dist(), - freq: id.freq(), - term, - score: (self.options.max_distance - state.dist()) as f32 - + id.freq() as f32 / self.max_freq as f32 * self.options.freq_weight, - }); - if out.len() > self.options.top_n { - out.pop(); - } + multiwords + .entry(first.to_owned()) + .or_insert_with(Vec::new) + .push(parts.map(|x| x.to_owned()).collect()); } - out.into_sorted_vec() + self.variant_checker = Some(VariantChecker { + variant, + fst, + multiwords, + set, + map: self.map.clone(), + max_freq: self.max_freq, + options: self.options.clone(), + lang_options: self.lang_options.clone(), + }) } pub fn suggest(&self, tokens: &[Token]) -> Vec { + let variant_checker = if let Some(checker) = self.variant_checker.as_ref() { + checker + } else { + return Vec::new(); + }; + let mut suggestions = Vec::new(); let mut correct_mask = vec![false; tokens.len()]; @@ -399,13 +446,13 @@ impl Spell { let text = token.word.text.as_ref(); if !correct_mask[i] { - self.check(&tokens[i..], &mut correct_mask[i..]); + variant_checker.check(&tokens[i..], &mut correct_mask[i..]); } if correct_mask[i] || token.ignore_spelling { continue; } - let candidates = self.search(text); + let candidates = variant_checker.search(text); suggestions.push(Suggestion { source: "SPELLCHECK/SINGLE".into(), message: "Possibly misspelled word.".into(), From 9f15524ce16ff23e86bbf677bf7845a1d2549075 Mon Sep 17 00:00:00 2001 From: Benjamin Minixhofer Date: Mon, 15 Mar 2021 17:35:37 +0100 Subject: [PATCH 11/16] extend distance to optimal string alignment distance --- nlprule/src/spellcheck/levenshtein.rs | 30 ++++++++++++++++++++++----- nlprule/src/spellcheck/mod.rs | 15 +++++--------- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/nlprule/src/spellcheck/levenshtein.rs b/nlprule/src/spellcheck/levenshtein.rs index 18f42ea..af66e3a 100644 --- a/nlprule/src/spellcheck/levenshtein.rs +++ b/nlprule/src/spellcheck/levenshtein.rs @@ -9,6 +9,9 @@ use std::{ pub struct LevenshteinState { dist: usize, n: usize, + // to compute the next row of the matrix, we also need the row two rows up for transposes + prev_row: Option>, + prev_byte: u8, row: Vec, hash: u64, } @@ -43,6 +46,8 @@ impl<'a> Automaton for Levenshtein<'a> { Some(LevenshteinState { dist: self.query.len(), n: 0, + prev_row: None, + prev_byte: 0, row: (0..=self.query.len()).collect(), hash: FnvHasher::default().finish(), }) @@ -64,20 +69,33 @@ impl<'a> Automaton for Levenshtein<'a> { byte.hash(&mut next_hasher); let next_hash = next_hasher.finish(); - let prev_row = &state.row; + let row = &state.row; let mut next_row = state.row.to_vec(); next_row[0] = state.n + 1; for i in 1..next_row.len() { - let cost = if byte == self.query[i - 1] { - prev_row[i - 1] + let mut cost = if byte == self.query[i - 1] { + row[i - 1] } else { min( - next_row[i - 1] + 1, - min(prev_row[i - 1] + 1, prev_row[i] + 1), + next_row[i - 1] + 1, // deletes + min( + row[i - 1] + 1, // inserts + row[i] + 1, // substitutes + ), ) }; + + if i > 1 { + // transposes + if let Some(prev_row) = state.prev_row.as_ref() { + if byte == self.query[i - 2] && state.prev_byte == self.query[i - 1] { + cost = min(cost, prev_row[i - 2] + 1); + } + } + } + next_row[i] = cost; } @@ -106,6 +124,8 @@ impl<'a> Automaton for Levenshtein<'a> { Some(LevenshteinState { dist: next_row[self.query.len()], n: state.n + 1, + prev_row: Some(row.clone()), + prev_byte: byte, row: next_row, hash: next_hash, 
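For reference, the recurrence the extended automaton mirrors is the Optimal String Alignment distance: Levenshtein plus a transposition of two adjacent characters, which is why the state now carries the row from two steps back (`prev_row`) and the previous byte. A plain, non-automaton sketch of the same recurrence over chars:

```rust
// Standalone sketch of Optimal String Alignment distance (not part of nlprule):
// the usual insert/delete/substitute steps plus an adjacent transposition.
fn osa_distance(a: &str, b: &str) -> usize {
    let a: Vec<char> = a.chars().collect();
    let b: Vec<char> = b.chars().collect();

    // d[i][j] = distance between a[..i] and b[..j]
    let mut d = vec![vec![0usize; b.len() + 1]; a.len() + 1];
    for i in 0..=a.len() {
        d[i][0] = i;
    }
    for j in 0..=b.len() {
        d[0][j] = j;
    }

    for i in 1..=a.len() {
        for j in 1..=b.len() {
            let cost = if a[i - 1] == b[j - 1] { 0 } else { 1 };
            d[i][j] = (d[i - 1][j] + 1) // deletion
                .min(d[i][j - 1] + 1) // insertion
                .min(d[i - 1][j - 1] + cost); // substitution
            if i > 1 && j > 1 && a[i - 1] == b[j - 2] && a[i - 2] == b[j - 1] {
                d[i][j] = d[i][j].min(d[i - 2][j - 2] + 1); // transposition
            }
        }
    }

    d[a.len()][b.len()]
}

fn main() {
    // plain Levenshtein counts two edits for an adjacent swap, OSA counts one
    assert_eq!(osa_distance("ab", "ba"), 1);
    // one transposition ("hc" -> "ch") plus one insertion
    assert_eq!(osa_distance("spellhceking", "spellchecking"), 2);
}
```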
}) diff --git a/nlprule/src/spellcheck/mod.rs b/nlprule/src/spellcheck/mod.rs index 2fd24ab..02fe4bc 100644 --- a/nlprule/src/spellcheck/mod.rs +++ b/nlprule/src/spellcheck/mod.rs @@ -248,14 +248,9 @@ impl VariantChecker { } } - fn search(&self, word: &str) -> Vec { + fn search(&self, word: &str) -> Vec { if let Some(candidate) = self.map.get(word) { - return vec![Candidate { - score: 0., // numerical values here do not matter since there is always exactly one candidate - ranking is irrelevant - freq: 0, - distance: 0, - term: candidate.to_owned(), - }]; + return vec![candidate.to_owned()]; } let used_fst = Map::new(self.fst.as_slice()).expect("used fst must be valid."); @@ -283,7 +278,8 @@ impl VariantChecker { } } - out.into_sorted_vec() + // `into_iter_sorted` is unstable - see https://github.com/rust-lang/rust/issues/59278 + out.into_sorted_vec().into_iter().map(|x| x.term).collect() } } @@ -452,13 +448,12 @@ impl Spell { continue; } - let candidates = variant_checker.search(text); suggestions.push(Suggestion { source: "SPELLCHECK/SINGLE".into(), message: "Possibly misspelled word.".into(), start: token.char_span.0, end: token.char_span.1, - replacements: candidates.into_iter().map(|x| x.term).collect(), + replacements: variant_checker.search(text), }); } From 57c06e2ed19219319d0c7f01b1690bf2ecad3d93 Mon Sep 17 00:00:00 2001 From: Benjamin Minixhofer Date: Tue, 16 Mar 2021 10:38:06 +0100 Subject: [PATCH 12/16] improve docs --- nlprule/src/compile/impls.rs | 2 +- nlprule/src/compile/mod.rs | 2 +- nlprule/src/compile/utils.rs | 4 +- nlprule/src/lib.rs | 2 +- nlprule/src/rules.rs | 6 +- .../src/{spellcheck => spell}/levenshtein.rs | 0 nlprule/src/{spellcheck => spell}/mod.rs | 145 +++++++++++------- nlprule/src/tokenizer.rs | 2 +- python/src/lib.rs | 30 +--- 9 files changed, 100 insertions(+), 93 deletions(-) rename nlprule/src/{spellcheck => spell}/levenshtein.rs (100%) rename nlprule/src/{spellcheck => spell}/mod.rs (67%) diff --git a/nlprule/src/compile/impls.rs b/nlprule/src/compile/impls.rs index 1f1f7e9..1d8b635 100644 --- a/nlprule/src/compile/impls.rs +++ b/nlprule/src/compile/impls.rs @@ -24,7 +24,7 @@ use crate::{ DisambiguationRule, MatchGraph, Rule, }, rules::{Rules, RulesLangOptions}, - spellcheck::{Spell, SpellInt, SpellLangOptions}, + spell::{Spell, SpellInt, SpellLangOptions}, tokenizer::{ chunk, multiword::{MultiwordTagger, MultiwordTaggerFields}, diff --git a/nlprule/src/compile/mod.rs b/nlprule/src/compile/mod.rs index 3f121a8..1654a19 100644 --- a/nlprule/src/compile/mod.rs +++ b/nlprule/src/compile/mod.rs @@ -14,7 +14,7 @@ use std::{ use crate::{ rules::Rules, - spellcheck::Spell, + spell::Spell, tokenizer::{chunk::Chunker, multiword::MultiwordTagger, tag::Tagger, Tokenizer}, types::*, }; diff --git a/nlprule/src/compile/utils.rs b/nlprule/src/compile/utils.rs index 3a9d8aa..e86d30e 100644 --- a/nlprule/src/compile/utils.rs +++ b/nlprule/src/compile/utils.rs @@ -1,6 +1,4 @@ -use crate::{ - rules::RulesLangOptions, spellcheck::SpellLangOptions, tokenizer::TokenizerLangOptions, -}; +use crate::{rules::RulesLangOptions, spell::SpellLangOptions, tokenizer::TokenizerLangOptions}; use crate::{tokenizer::tag::TaggerLangOptions, types::*}; use lazy_static::lazy_static; diff --git a/nlprule/src/lib.rs b/nlprule/src/lib.rs index 3cb4d6a..32750b7 100644 --- a/nlprule/src/lib.rs +++ b/nlprule/src/lib.rs @@ -63,7 +63,7 @@ pub mod compile; mod filter; pub mod rule; pub mod rules; -pub mod spellcheck; +pub mod spell; pub mod tokenizer; pub mod types; pub(crate) mod utils; diff --git 
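On the `into_sorted_vec` workaround above: the heap acts as a bounded buffer, pushing every candidate and popping whenever the size exceeds `top_n`, then sorting the survivors once at the end. The same pattern in isolation, with plain integers standing in for `Candidate` (there, the `Ord` implementation decides which element counts as worst):

```rust
use std::collections::BinaryHeap;

// Sketch of the bounded-heap selection used in `search`. `BinaryHeap` is a max-heap,
// so popping removes the largest element, keeping the `top_n` smallest values.
fn top_n_smallest(values: impl IntoIterator<Item = u32>, top_n: usize) -> Vec<u32> {
    let mut heap = BinaryHeap::with_capacity(top_n + 1);

    for value in values {
        heap.push(value);
        if heap.len() > top_n {
            heap.pop(); // drop the current worst value
        }
    }

    // `into_iter_sorted` is unstable, so collect via `into_sorted_vec` (ascending order)
    heap.into_sorted_vec()
}

fn main() {
    assert_eq!(top_n_smallest([5, 1, 4, 2, 3], 3), vec![1, 2, 3]);
}
```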
a/nlprule/src/rules.rs b/nlprule/src/rules.rs index b57a056..63d7cb9 100644 --- a/nlprule/src/rules.rs +++ b/nlprule/src/rules.rs @@ -2,7 +2,7 @@ use crate::{rule::id::Selector, tokenizer::Tokenizer}; use crate::{rule::Rule, Error}; -use crate::{spellcheck::Spell, types::*, utils::parallelism::MaybeParallelRefIterator}; +use crate::{spell::Spell, types::*, utils::parallelism::MaybeParallelRefIterator}; use fs_err::File; use serde::{Deserialize, Serialize}; use std::{ @@ -58,7 +58,7 @@ pub struct Rules { } impl Rules { - /// TODO + /// Serializes the rules set to a writer. pub fn to_writer(&self, writer: &mut W) -> Result<(), Error> { // TODO: the .clone() here could be avoided let fields: RulesFields = self.clone().into(); @@ -88,10 +88,12 @@ impl Rules { Self::from_reader(reader, tokenizer) } + /// Gets the spellchecker associated with this rules set. The spellchecker always exists, even if spellchecking is disabled (default). pub fn spell(&self) -> &Spell { &self.spell } + /// Mutably gets the spellchecker. pub fn spell_mut(&mut self) -> &mut Spell { &mut self.spell } diff --git a/nlprule/src/spellcheck/levenshtein.rs b/nlprule/src/spell/levenshtein.rs similarity index 100% rename from nlprule/src/spellcheck/levenshtein.rs rename to nlprule/src/spell/levenshtein.rs diff --git a/nlprule/src/spellcheck/mod.rs b/nlprule/src/spell/mod.rs similarity index 67% rename from nlprule/src/spellcheck/mod.rs rename to nlprule/src/spell/mod.rs index 02fe4bc..45ba5cf 100644 --- a/nlprule/src/spellcheck/mod.rs +++ b/nlprule/src/spell/mod.rs @@ -1,3 +1,4 @@ +//! Structures and implementations related to spellchecking. use fst::{IntoStreamer, Map, MapBuilder, Streamer}; use serde::{Deserialize, Serialize}; use std::{ @@ -19,6 +20,10 @@ mod spell_int { use serde::{Deserialize, Serialize}; + /// Encodes information about a valid word in a `u64` for storage as value in an FST. + /// Currently: + /// - the bottom 8 bits encode the frequency + /// - the other 56 bits act as flags for the variants e.g. bit 10 and 12 are set if the word exists in the the second and fourth variant. #[derive(Debug, Clone, Default, Copy, Serialize, Deserialize)] pub(crate) struct SpellInt(pub(super) u64); @@ -91,7 +96,7 @@ mod spell_int { pub(crate) use spell_int::SpellInt; #[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] -pub struct Candidate { +struct Candidate { score: f32, distance: usize, freq: usize, @@ -110,36 +115,27 @@ impl Ord for Candidate { } } -impl Candidate { - pub fn score(&self) -> f32 { - self.score - } - - pub fn freq(&self) -> usize { - self.freq - } - - pub fn distance(&self) -> usize { - self.distance - } - - pub fn term(&self) -> &str { - self.term.as_str() - } -} - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[serde(default)] -/// TODO +/// Options to configure the spellchecker. pub struct SpellOptions { + /// The language variant to use. Setting this to `None` disables spellchecking. pub variant: Option, + /// The maximum edit distance to consider for corrections. Currently Optimal String Alignment distance is used. pub max_distance: usize, + /// A fixed prefix length for which to consider only edits with a distance of 1. This speeds up the search by pruning the tree early. pub prefix_length: usize, + /// How high to weigh the frequency of a word compared to the edit distance when ranking correction candidates. + /// Setting this to `x` makes the frequency make a difference of at most `x` edit distance. 
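Putting the new accessors together, enabling the spellchecker from user code looks roughly like the sketch below. The paths are placeholders, the whitelist entry is only an example, and the call pattern follows the doctests added later in this series:

```rust
use nlprule::{Rules, Tokenizer};

fn main() -> Result<(), nlprule::Error> {
    let tokenizer = Tokenizer::new("path/to/en_tokenizer.bin")?;
    let mut rules = Rules::new("path/to/en_rules.bin", tokenizer.into())?;

    // spellchecking is disabled by default; selecting a variant enables it
    rules.spell_mut().options_mut().variant = Some(rules.spell().variant("en_GB")?);

    // the options guard re-ingests the options when it is dropped,
    // so this change takes effect right after the statement
    rules.spell_mut().options_mut().whitelist.insert("nlprule".to_owned());

    println!("{}", rules.correct("A sentance with a speling error."));
    Ok(())
}
```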
pub freq_weight: f32, + /// The maximum number of correction candidates to return. pub top_n: usize, + /// A set of words to ignore. Can also contain phrases delimited by a space. pub whitelist: HashSet, } +/// A guard around the [SpellOptions]. Makes sure the spellchecker is updated once this is dropped. +/// Implements `Deref` and `DerefMut` to the [SpellOptions]. pub struct SpellOptionsGuard<'a> { spell: &'a mut Spell, } @@ -166,7 +162,7 @@ impl<'a> Drop for SpellOptionsGuard<'a> { #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)] pub(crate) struct SpellLangOptions { - /// Variants of the language (e. g. "en_US", "en_GB") to consider for spellchecking. + /// Variants of the language (e.g. "en_US", "en_GB") to consider for spellchecking. pub variants: Vec, pub split_hyphens: bool, } @@ -184,16 +180,19 @@ impl Default for SpellOptions { } } +/// A valid language variant. Obtained by [Spell::variant]. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)] #[serde(transparent)] pub struct Variant(String); impl Variant { + /// Gets the language code of this variant. pub fn as_str(&self) -> &str { self.0.as_str() } } +/// Spellchecker logic for one variant. Does the actual work. #[derive(Debug, Clone)] struct VariantChecker { variant: Variant, @@ -207,22 +206,32 @@ struct VariantChecker { } impl VariantChecker { + /// Checks the validity of one word. + /// NB: The ordering of this chain of `||` operators is somewhat nontrivial. Could potentially be improved by benchmarking. + /// If this is true, the token is always correct. The converse is not true because e.g. multiwords are checked separately. fn check_word(&self, word: &str, recurse: bool) -> bool { word.is_empty() + || self.set.contains(word) || word .chars() .all(|x| x.is_symbol() || x.is_punctuation() || x.is_numeric()) - || self.set.contains(word) || (recurse + // for title case words, it is enough if the lowercase variant is known. + // it is possible that `is_title_case` is still true for word where `.to_lowercase()` was called so we need a `recurse` parameter. && is_title_case(word) && self.check_word(&apply_to_first(word, |x| x.to_lowercase().collect()), false)) } + /// Populates `correct_mask` according to the correctness of the given zeroth token. + /// - `correct_mask[0]` is `true` if the zeroth token is correct, `false` if it is not correct. + /// - Indices `1..n` of `correct_mask` are `true` if the `n`th token is also definitely correct. + /// If they are `false`, they need to be checked separately. fn check(&self, tokens: &[Token], correct_mask: &mut [bool]) { let word = tokens[0].word.text.as_ref(); let mut word_is_correct = self.check_word(word, true); if !word_is_correct && self.lang_options.split_hyphens { + // there exist multiple valid hyphens, see https://jkorpela.fi/dashes.html let hyphens = &['-', '\u{2010}', '\u{2011}'][..]; if word.contains(hyphens) && word.split(hyphens).all(|x| self.check_word(x, true)) { @@ -261,7 +270,6 @@ impl VariantChecker { let mut stream = used_fst.search_with_state(query).into_stream(); while let Some((k, v, s)) = stream.next() { let state = s.expect("matching levenshtein state is always `Some`."); - assert!(state.dist() > 0); let id = SpellInt(v); @@ -284,14 +292,19 @@ impl VariantChecker { } #[derive(Debug, Clone, Default, Serialize, Deserialize)] +/// A spellchecker implementing the algorithm described in [Error-tolerant Finite State Recognition](https://www.aclweb.org/anthology/1995.iwpt-1.24/) with some extensions. 
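The `SpellInt` layout described above can be illustrated with a standalone sketch. This is not the nlprule implementation: the exact bit positions and the choice to keep the larger of two frequencies in `update_freq` are assumptions made for the example.

```rust
// Sketch of packing a word's metadata into a u64:
// bits 0..8 hold the frequency, bit 8 + i is set if the word exists in variant i.
#[derive(Debug, Default, Clone, Copy)]
struct PackedWord(u64);

impl PackedWord {
    fn update_freq(&mut self, freq: usize) {
        assert!(freq < 256);
        // assumption: keep the highest frequency seen so far
        if freq > self.freq() {
            self.0 = (self.0 & !0xFF) | freq as u64;
        }
    }

    fn add_variant(&mut self, index: usize) {
        self.0 |= 1u64 << (8 + index);
    }

    fn contains_variant(&self, index: usize) -> bool {
        (self.0 & (1u64 << (8 + index))) != 0
    }

    fn freq(&self) -> usize {
        (self.0 & 0xFF) as usize
    }
}

fn main() {
    let mut word = PackedWord::default();
    word.update_freq(100);
    word.add_variant(1);
    word.add_variant(10);

    assert!(word.contains_variant(1) && word.contains_variant(10));
    assert!(!word.contains_variant(2));
    assert_eq!(word.freq(), 100);
}
```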
pub struct Spell { + /// An FST mapping valid words (always single tokens!) to a [SpellInt]. fst: Vec, + /// Known *multiwords* i. e. phrases. Can also validly contain single words if they should not be part of the FST (e.g. words in the whitelist). multiwords: DefaultHashMap, SpellInt)>>, + /// The maximum occured word frequency. Used to normalize. max_freq: usize, + /// A map of `wrong->right`. `wrong` must always be exactly one token. map: DefaultHashMap, lang_options: SpellLangOptions, options: SpellOptions, - // the `variant_checker` is computed based on the selected variant + /// The structure containing the actual spellchecking logic. Computed based on the selected variant. #[serde(skip)] variant_checker: Option, } @@ -317,18 +330,24 @@ impl Spell { spell } + /// Gets the options. pub fn options(&self) -> &SpellOptions { &self.options } + /// Mutably gets the options. pub fn options_mut(&mut self) -> SpellOptionsGuard { SpellOptionsGuard { spell: self } } + /// Returns all known variants. pub fn variants(&self) -> &[Variant] { self.lang_options.variants.as_slice() } + /// Returns the variant for a language code e.g. `"en_GB"`. + /// # Errors + /// - If no variant exists for the language code. pub fn variant(&self, variant: &str) -> Result { self.lang_options .variants @@ -338,11 +357,7 @@ impl Spell { .ok_or_else(|| Error::UnknownVariant(variant.to_owned())) } - pub(crate) fn ingest_options(&mut self) { - if self.variant_checker.as_ref().map(|x| &x.variant) == self.options.variant.as_ref() { - return; - } - + fn ingest_options(&mut self) { let variant = if let Some(variant) = self.options.variant.as_ref() { variant.clone() } else { @@ -350,30 +365,55 @@ impl Spell { return; }; - let mut used_fst_builder = MapBuilder::memory(); - let mut set = DefaultHashSet::new(); - - let fst = Map::new(&self.fst).expect("serialized fst must be valid."); - let mut stream = fst.into_stream(); - let variant_index = self .variants() .iter() .position(|x| *x == variant) .expect("only valid variants are created."); - while let Some((k, v)) = stream.next() { - if SpellInt(v).contains_variant(variant_index) { - set.insert(String::from_utf8(k.to_vec()).expect("fst keys must be valid utf-8.")); - used_fst_builder - .insert(k, v) - .expect("fst stream returns values in lexicographic order."); + let mut checker = match self.variant_checker.take() { + // if the variant checker exists and uses the correct variant, we don't need to rebuild + Some(checker) if checker.variant == variant => checker, + _ => { + let mut used_fst_builder = MapBuilder::memory(); + let mut set = DefaultHashSet::new(); + + let fst = Map::new(&self.fst).expect("serialized fst must be valid."); + let mut stream = fst.into_stream(); + + while let Some((k, v)) = stream.next() { + if SpellInt(v).contains_variant(variant_index) { + set.insert( + String::from_utf8(k.to_vec()).expect("fst keys must be valid utf-8."), + ); + used_fst_builder + .insert(k, v) + .expect("fst stream returns values in lexicographic order."); + } + } + + let fst = used_fst_builder + .into_inner() + .expect("subset of valid fst must be valid."); + + VariantChecker { + variant, + fst, + multiwords: DefaultHashMap::new(), + set, + map: self.map.clone(), + max_freq: self.max_freq, + options: self.options.clone(), + lang_options: self.lang_options.clone(), + } } - } + }; + + // `multiwords` depend on the whitelist. For convenience we always rebuild this. + // the whitelist could be separated into a new structure for a speedup. 
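The reason `ingest_options` runs at the right time without the caller having to remember it is the guard returned by `options_mut()`: mutations go through a value that implements `Deref`/`DerefMut` to the options and triggers the rebuild in `Drop`. A minimal, self-contained sketch of that pattern (types simplified, not the nlprule structs):

```rust
use std::ops::{Deref, DerefMut};

#[derive(Default)]
struct Options {
    variant: Option<String>,
}

#[derive(Default)]
struct Checker {
    options: Options,
    rebuilds: usize,
}

// Guard handed out by `options_mut`: the expensive rebuild runs once, on drop.
struct OptionsGuard<'a> {
    checker: &'a mut Checker,
}

impl Checker {
    fn options_mut(&mut self) -> OptionsGuard<'_> {
        OptionsGuard { checker: self }
    }

    fn ingest_options(&mut self) {
        // stands in for rebuilding the per-variant FST subset and multiword map
        self.rebuilds += 1;
    }
}

impl Deref for OptionsGuard<'_> {
    type Target = Options;
    fn deref(&self) -> &Options {
        &self.checker.options
    }
}

impl DerefMut for OptionsGuard<'_> {
    fn deref_mut(&mut self) -> &mut Options {
        &mut self.checker.options
    }
}

impl Drop for OptionsGuard<'_> {
    fn drop(&mut self) {
        self.checker.ingest_options();
    }
}

fn main() {
    let mut checker = Checker::default();
    // the guard is dropped at the end of the statement, triggering one rebuild
    checker.options_mut().variant = Some("en_GB".into());
    assert_eq!(checker.rebuilds, 1);
}
```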
+ // We can revisit this if performance becomes an issue, it should still be quite fast as implemented now. - let fst = used_fst_builder - .into_inner() - .expect("subset of valid fst must be valid."); + // selects only the multiwords which exist for the selected variant let mut multiwords: DefaultHashMap<_, _> = self .multiwords .iter() @@ -393,6 +433,8 @@ impl Spell { }) .collect(); + // adds words from the user-set whitelist + // careful: words in the `whitelist` are set by the user, so this must never fail! for phrase in self .options .whitelist @@ -416,18 +458,11 @@ impl Spell { .push(parts.map(|x| x.to_owned()).collect()); } - self.variant_checker = Some(VariantChecker { - variant, - fst, - multiwords, - set, - map: self.map.clone(), - max_freq: self.max_freq, - options: self.options.clone(), - lang_options: self.lang_options.clone(), - }) + checker.multiwords = multiwords; + self.variant_checker = Some(checker); } + /// Runs the spellchecking algorithm on all tokens and returns suggestions. pub fn suggest(&self, tokens: &[Token]) -> Vec { let variant_checker = if let Some(checker) = self.variant_checker.as_ref() { checker diff --git a/nlprule/src/tokenizer.rs b/nlprule/src/tokenizer.rs index 8a3d2c5..de5834a 100644 --- a/nlprule/src/tokenizer.rs +++ b/nlprule/src/tokenizer.rs @@ -128,7 +128,7 @@ impl Tokenizer { Ok(bincode::deserialize_from(reader)?) } - /// TODO + /// Serializes the tokenizer to a writer. pub fn to_writer(&self, writer: &mut W) -> Result<(), Error> { Ok(bincode::serialize_into(writer, &self)?) } diff --git a/python/src/lib.rs b/python/src/lib.rs index c88c019..c0aba6d 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -2,7 +2,7 @@ use flate2::read::GzDecoder; use nlprule::{ rule::{id::Selector, Example, Rule}, rules::{apply_suggestions, Rules}, - spellcheck::{Candidate, Spell}, + spell::Spell, tokenizer::tag::Tagger, tokenizer::Tokenizer, types::*, @@ -654,34 +654,6 @@ impl PySpellOptions { } } -#[pyclass(name = "Candidate", module = "nlprule.spell")] -struct PyCandidate { - candidate: Candidate, -} - -#[pymethods] -impl PyCandidate { - #[getter] - fn score(&self) -> f32 { - self.candidate.score() - } - - #[getter] - fn distance(&self) -> usize { - self.candidate.distance() - } - - #[getter] - fn freq(&self) -> usize { - self.candidate.freq() - } - - #[getter] - fn term(&self) -> &str { - self.candidate.term() - } -} - #[pyclass(name = "Spell", module = "nlprule.spell")] struct PySpell { rules: Arc>, From 0a738db03e98144f4b4c177eb398cdd4bdfa2a04 Mon Sep 17 00:00:00 2001 From: Benjamin Minixhofer Date: Tue, 16 Mar 2021 18:27:39 +0100 Subject: [PATCH 13/16] reduce repetition, fix warnings --- nlprule/src/compile/impls.rs | 179 +++++++++++++++++++++-------------- nlprule/src/lib.rs | 4 +- nlprule/src/spell/mod.rs | 50 ++++------ nlprule/src/tokenizer/tag.rs | 1 + 4 files changed, 129 insertions(+), 105 deletions(-) diff --git a/nlprule/src/compile/impls.rs b/nlprule/src/compile/impls.rs index 1d8b635..c1b41bd 100644 --- a/nlprule/src/compile/impls.rs +++ b/nlprule/src/compile/impls.rs @@ -38,6 +38,92 @@ use crate::{ use super::{parse_structure::BuildInfo, Error}; impl Spell { + fn new( + fst: Vec, + multiwords: DefaultHashMap, SpellInt)>>, + max_freq: usize, + map: DefaultHashMap, + lang_options: SpellLangOptions, + ) -> Self { + let mut spell = Spell { + fst, + multiwords, + max_freq, + map, + lang_options, + ..Default::default() + }; + spell.ingest_options(); + spell + } + + #[allow(clippy::clippy::too_many_arguments)] // lots of arguments here but not easily 
avoidable + fn add_line( + word: &str, + freq: usize, + words: &mut DefaultHashMap, + multiwords: &mut DefaultHashMap, SpellInt)>>, + variant_index: usize, + // in some LT lists an underline denotes a prefix such that e.g. "hin_reiten" also adds "hingeritten" + underline_denotes_prefix: bool, + build_info: &mut BuildInfo, + tokenizer: &Tokenizer, + ) { + let tokens = tokenizer.get_token_strs(word); + + if tokens.len() > 1 { + assert!(!(underline_denotes_prefix && word.contains('_'))); // not supported in entries spanning multiple tokens + + let mut int = SpellInt::default(); + int.add_variant(variant_index); + + // we do not add the frequency for multiwords - they are not used for suggestions, just to check validity + + multiwords + .entry(tokens[0].to_owned()) + .or_insert_with(Vec::new) + .push(( + tokens[1..] + .iter() + .filter(|x| !x.trim().is_empty()) + .map(|x| (*x).to_owned()) + .collect(), + int, + )); + } else if word.contains('_') && underline_denotes_prefix { + assert!(!word.contains('\\')); // escaped underlines are not supported + let mut parts = word.split('_'); + + let prefix = parts.next().unwrap(); + let suffix = parts.next().unwrap(); + + // this will presumably always be covered by the extra suffixes, but add it just to make sure + let value = words.entry(format!("{}{}", prefix, suffix)).or_default(); + value.add_variant(variant_index); + value.update_freq(freq); + + let replacer = PosReplacer { + matcher: PosMatcher::new( + Matcher::new_regex(Regex::new("^VER:.*".into()), false, true), + build_info, + ), + }; + + for new_suffix in replacer.apply(suffix, tokenizer) { + let new_word = format!("{}{}", prefix, new_suffix); + + let value = words.entry(new_word).or_default(); + value.add_variant(variant_index); + value.update_freq(freq); + } + } else { + let value = words.entry((*word).to_owned()).or_default(); + + value.update_freq(freq); + value.add_variant(variant_index); + } + } + pub(in crate::compile) fn from_dumps( spell_dir_path: impl AsRef, map_path: impl AsRef, @@ -68,33 +154,18 @@ impl Spell { [freq, word] => { // frequency is denoted as letters from A to Z in LanguageTool where A is the least frequent. let freq = freq.chars().next().expect("freq must have one char - would not have been yielded by split_whitespace otherwise.") as usize - 'A' as usize; - let value = words.entry((*word).to_owned()).or_default(); - - let tokens = tokenizer.get_token_strs(word); - - if tokens.len() > 1 { - let mut int = SpellInt::default(); - int.add_variant(i); - - multiwords - .entry(tokens[0].to_owned()) - .or_insert_with(Vec::new) - .push(( - tokens[1..] - .iter() - .filter(|x| !x.trim().is_empty()) - .map(|x| (*x).to_owned()) - .collect(), - int, - )); - - continue; - } + Spell::add_line( + word, + freq, + &mut words, + &mut multiwords, + i, + false, + build_info, + tokenizer, + ); max_freq = cmp::max(max_freq, freq); - - value.update_freq(freq); - value.add_variant(i); } _ => continue, } @@ -108,54 +179,16 @@ impl Spell { let line = line?; let word = line.trim(); - let tokens = tokenizer.get_token_strs(word); - - if tokens.len() > 1 { - let mut int = SpellInt::default(); - int.add_variant(i); - multiwords - .entry(tokens[0].to_owned()) - .or_insert_with(Vec::new) - .push(( - tokens[1..] 
- .iter() - .filter(|x| !x.trim().is_empty()) - .map(|x| (*x).to_owned()) - .collect(), - int, - )); - - continue; - } - - if word.contains('_') { - assert!(!word.contains('\\')); // escaped underlines are not supported - let mut parts = word.split('_'); - - let prefix = parts.next().unwrap(); - let suffix = parts.next().unwrap(); - - // this will presumably always be covered by the extra suffixes, but add it just to make sure - words - .entry(format!("{}{}", prefix, suffix)) - .or_default() - .add_variant(i); - - let replacer = PosReplacer { - matcher: PosMatcher::new( - Matcher::new_regex(Regex::new("^VER:.*".into()), false, true), - build_info, - ), - }; - - for new_suffix in replacer.apply(suffix, tokenizer) { - let new_word = format!("{}{}", prefix, new_suffix); - - words.entry(new_word).or_default().add_variant(i); - } - } - - words.entry((*word).to_owned()).or_default().add_variant(i); + Spell::add_line( + word, + 0, + &mut words, + &mut multiwords, + i, + true, + build_info, + tokenizer, + ); } } let mut words: Vec<_> = words diff --git a/nlprule/src/lib.rs b/nlprule/src/lib.rs index 32750b7..60536fd 100644 --- a/nlprule/src/lib.rs +++ b/nlprule/src/lib.rs @@ -78,8 +78,8 @@ pub enum Error { Io(#[from] io::Error), #[error("deserialization error: {0}")] Deserialization(#[from] bincode::Error), - #[error("unknown language variant: {0}")] - UnknownVariant(String), + #[error("unknown language variant: \"{0}\". known variants are: {1:?}.")] + UnknownVariant(String, Vec), } /// Gets the canonical filename for the tokenizer binary for a language code in ISO 639-1 (two-letter) format. diff --git a/nlprule/src/spell/mod.rs b/nlprule/src/spell/mod.rs index 45ba5cf..f487406 100644 --- a/nlprule/src/spell/mod.rs +++ b/nlprule/src/spell/mod.rs @@ -33,6 +33,7 @@ mod spell_int { std::mem::size_of::() * 8 } + #[allow(dead_code)] // some methods are only needed for compilation - kept here for clarity impl SpellInt { pub fn as_u64(&self) -> u64 { self.0 @@ -194,7 +195,7 @@ impl Variant { /// Spellchecker logic for one variant. Does the actual work. #[derive(Debug, Clone)] -struct VariantChecker { +pub(crate) struct VariantChecker { variant: Variant, fst: Vec, max_freq: usize, @@ -295,41 +296,21 @@ impl VariantChecker { /// A spellchecker implementing the algorithm described in [Error-tolerant Finite State Recognition](https://www.aclweb.org/anthology/1995.iwpt-1.24/) with some extensions. pub struct Spell { /// An FST mapping valid words (always single tokens!) to a [SpellInt]. - fst: Vec, + pub(crate) fst: Vec, /// Known *multiwords* i. e. phrases. Can also validly contain single words if they should not be part of the FST (e.g. words in the whitelist). - multiwords: DefaultHashMap, SpellInt)>>, + pub(crate) multiwords: DefaultHashMap, SpellInt)>>, /// The maximum occured word frequency. Used to normalize. - max_freq: usize, + pub(crate) max_freq: usize, /// A map of `wrong->right`. `wrong` must always be exactly one token. - map: DefaultHashMap, - lang_options: SpellLangOptions, - options: SpellOptions, + pub(crate) map: DefaultHashMap, + pub(crate) lang_options: SpellLangOptions, + pub(crate) options: SpellOptions, /// The structure containing the actual spellchecking logic. Computed based on the selected variant. 
#[serde(skip)] - variant_checker: Option, + pub(crate) variant_checker: Option, } impl Spell { - pub(crate) fn new( - fst: Vec, - multiwords: DefaultHashMap, SpellInt)>>, - max_freq: usize, - map: DefaultHashMap, - lang_options: SpellLangOptions, - ) -> Self { - let mut spell = Spell { - fst, - multiwords, - max_freq, - map, - lang_options, - options: SpellOptions::default(), - ..Default::default() - }; - spell.ingest_options(); - spell - } - /// Gets the options. pub fn options(&self) -> &SpellOptions { &self.options @@ -354,10 +335,19 @@ impl Spell { .iter() .find(|x| x.as_str() == variant) .cloned() - .ok_or_else(|| Error::UnknownVariant(variant.to_owned())) + .ok_or_else(|| { + Error::UnknownVariant( + variant.to_owned(), + self.lang_options + .variants + .iter() + .map(|x| x.as_str().to_owned()) + .collect(), + ) + }) } - fn ingest_options(&mut self) { + pub(crate) fn ingest_options(&mut self) { let variant = if let Some(variant) = self.options.variant.as_ref() { variant.clone() } else { diff --git a/nlprule/src/tokenizer/tag.rs b/nlprule/src/tokenizer/tag.rs index 2693541..dd3879c 100644 --- a/nlprule/src/tokenizer/tag.rs +++ b/nlprule/src/tokenizer/tag.rs @@ -199,6 +199,7 @@ impl Tagger { &self.tag_store } + #[allow(dead_code)] // used by compile module pub(crate) fn word_store(&self) -> &BiMap { &self.word_store } From f474d91b09b070b74adb24874121a22a7674998a Mon Sep 17 00:00:00 2001 From: Benjamin Minixhofer Date: Tue, 16 Mar 2021 19:51:12 +0100 Subject: [PATCH 14/16] fix test, qol improvement in dev scripts --- nlprule/src/spell/mod.rs | 2 +- scripts/build_and_test.sh | 23 ++++++++++++++++++++--- scripts/maturin.sh | 20 +++++++++++++++----- 3 files changed, 36 insertions(+), 9 deletions(-) diff --git a/nlprule/src/spell/mod.rs b/nlprule/src/spell/mod.rs index f487406..86de746 100644 --- a/nlprule/src/spell/mod.rs +++ b/nlprule/src/spell/mod.rs @@ -89,7 +89,7 @@ mod spell_int { assert!(int.contains_variant(1)); assert!(int.contains_variant(10)); assert!(!int.contains_variant(2)); - assert!(int.freq() == 10); + assert!(int.freq() == 100); } } } diff --git a/scripts/build_and_test.sh b/scripts/build_and_test.sh index 6f43f9f..d060156 100755 --- a/scripts/build_and_test.sh +++ b/scripts/build_and_test.sh @@ -1,6 +1,23 @@ # this script assumes the build directories are in data/ # only for convenience mkdir -p storage -RUST_LOG=INFO cargo run --all-features --bin compile -- --build-dir data/$1 --tokenizer-out storage/$1_tokenizer.bin --rules-out storage/$1_rules.bin -RUST_LOG=WARN cargo run --all-features --bin test_disambiguation -- --tokenizer storage/$1_tokenizer.bin -RUST_LOG=WARN cargo run --all-features --bin test -- --tokenizer storage/$1_tokenizer.bin --rules storage/$1_rules.bin \ No newline at end of file + +# x-- => only compile +# -xx => test_disambiguation and test +# xxx or flags not set => everything +flags=${2:-"xxx"} + +if [ "${flags:0:1}" == "x" ] +then + RUST_LOG=INFO cargo run --all-features --bin compile -- --build-dir data/$1 --tokenizer-out storage/$1_tokenizer.bin --rules-out storage/$1_rules.bin +fi + +if [ "${flags:1:1}" == "x" ] +then + RUST_LOG=WARN cargo run --all-features --bin test_disambiguation -- --tokenizer storage/$1_tokenizer.bin +fi + +if [ "${flags:2:1}" == "x" ] +then + RUST_LOG=WARN cargo run --all-features --bin test -- --tokenizer storage/$1_tokenizer.bin --rules storage/$1_rules.bin +fi \ No newline at end of file diff --git a/scripts/maturin.sh b/scripts/maturin.sh index 7fa1844..348c684 100755 --- a/scripts/maturin.sh +++ 
b/scripts/maturin.sh @@ -22,13 +22,23 @@ build_change build/Cargo.toml build_change Cargo.toml cd python + +trap cleanup INT + +function cleanup() { + # this is a bit hacky, assume we are in python/ dir + cd .. + + mv python/.Cargo.toml.bak python/Cargo.toml + mv nlprule/.Cargo.toml.bak nlprule/Cargo.toml + mv build/.Cargo.toml.bak build/Cargo.toml + mv .Cargo.toml.bak Cargo.toml + exit +} + maturin $@ exit_code=$? -cd .. -mv python/.Cargo.toml.bak python/Cargo.toml -mv nlprule/.Cargo.toml.bak nlprule/Cargo.toml -mv build/.Cargo.toml.bak build/Cargo.toml -mv .Cargo.toml.bak Cargo.toml +cleanup() exit $exit_code \ No newline at end of file From f1fdee76d830da32076868d2ab722fa23f978031 Mon Sep 17 00:00:00 2001 From: Benjamin Minixhofer Date: Tue, 16 Mar 2021 20:54:42 +0100 Subject: [PATCH 15/16] fix doctests --- nlprule/src/lib.rs | 39 ++++++++++++++++++++++++++++++++------- nlprule/src/rules.rs | 21 +++++++++++++-------- nlprule/tests/tests.rs | 2 +- python/src/lib.rs | 5 +++-- 4 files changed, 49 insertions(+), 18 deletions(-) diff --git a/nlprule/src/lib.rs b/nlprule/src/lib.rs index 60536fd..6b2b047 100644 --- a/nlprule/src/lib.rs +++ b/nlprule/src/lib.rs @@ -5,32 +5,35 @@ //! - A [Tokenizer][tokenizer::Tokenizer] to split a text into tokens and analyze it by chunking, lemmatizing and part-of-speech tagging. Can also be used independently of the grammatical rules. //! - A [Rules][rules::Rules] structure containing a set of grammatical error correction rules. //! -//! # Example: correct a text +//! # Examples //! +//! Correct a text: //! ```no_run //! use nlprule::{Tokenizer, Rules}; //! //! let tokenizer = Tokenizer::new("path/to/en_tokenizer.bin")?; -//! let rules = Rules::new("path/to/en_rules.bin")?; +//! let mut rules = Rules::new("path/to/en_rules.bin", tokenizer.into())?; +//! // enable spellchecking +//! rules.spell_mut().options_mut().variant = Some(rules.spell().variant("en_GB")?); //! //! assert_eq!( -//! rules.correct("She was not been here since Monday.", &tokenizer), -//! String::from("She was not here since Monday.") +//! rules.correct("I belive she was not been here since Monday."), +//! String::from("I believe she was not here since Monday.") //! ); //! # Ok::<(), nlprule::Error>(()) //! ``` //! -//! # Example: get suggestions and correct a text +//! Get suggestions and correct a text: //! //! ```no_run //! use nlprule::{Tokenizer, Rules, types::Suggestion, rules::apply_suggestions}; //! //! let tokenizer = Tokenizer::new("path/to/en_tokenizer.bin")?; -//! let rules = Rules::new("path/to/en_rules.bin")?; +//! let rules = Rules::new("path/to/en_rules.bin", tokenizer.into())?; //! //! let text = "She was not been here since Monday."; //! -//! let suggestions = rules.suggest(text, &tokenizer); +//! let suggestions = rules.suggest(text); //! assert_eq!( //! suggestions, //! vec![Suggestion { @@ -48,6 +51,28 @@ //! # Ok::<(), nlprule::Error>(()) //! ``` //! +//! Tokenize & analyze a text: +//! +//! ```no_run +//! use nlprule::Tokenizer; +//! +//! let tokenizer = Tokenizer::new("path/to/en_tokenizer.bin")?; +//! +//! let text = "A brief example is shown."; +//! +//! // returns a vector over sentences +//! // we assume this is one sentence so we take the first element +//! let tokens = tokenizer.pipe(text).remove(0); +//! +//! println!("{:#?}", tokens); +//! // token at index zero is the special SENT_START token - generally not interesting +//! assert_eq!(tokens[2].word.text.as_ref(), "brief"); +//! assert_eq!(tokens[2].word.tags[0].pos.as_ref(), "JJ"); +//! 
assert_eq!(tokens[2].chunks, vec!["I-NP-singular"]); +//! // some other information like char / byte span, lemmas etc. is also set! +//! # Ok::<(), nlprule::Error>(()) +//! ``` +//! --- //! Binaries are distributed with [Github releases](https://github.com/bminixhofer/nlprule/releases). //! //! # The 't lifetime diff --git a/nlprule/src/rules.rs b/nlprule/src/rules.rs index 63d7cb9..a20e673 100644 --- a/nlprule/src/rules.rs +++ b/nlprule/src/rules.rs @@ -202,19 +202,24 @@ impl Rules { } /// Correct a text by applying suggestions to it. -/// In the case of multiple possible replacements, always chooses the first one. +/// - In case of multiple possible replacements, always chooses the first one. +/// - In case of a suggestion without any replacements, ignores the suggestion. pub fn apply_suggestions(text: &str, suggestions: &[Suggestion]) -> String { let mut offset: isize = 0; let mut chars: Vec<_> = text.chars().collect(); for suggestion in suggestions { - let replacement: Vec<_> = suggestion.replacements[0].chars().collect(); - chars.splice( - (suggestion.start as isize + offset) as usize - ..(suggestion.end as isize + offset) as usize, - replacement.iter().cloned(), - ); - offset = offset + replacement.len() as isize - (suggestion.end - suggestion.start) as isize; + if let Some(replacement) = suggestion.replacements.get(0) { + let replacement_chars: Vec<_> = replacement.chars().collect(); + + chars.splice( + (suggestion.start as isize + offset) as usize + ..(suggestion.end as isize + offset) as usize, + replacement_chars.iter().cloned(), + ); + offset = offset + replacement_chars.len() as isize + - (suggestion.end - suggestion.start) as isize; + } } chars.into_iter().collect() diff --git a/nlprule/tests/tests.rs b/nlprule/tests/tests.rs index 0990939..7cca1d5 100644 --- a/nlprule/tests/tests.rs +++ b/nlprule/tests/tests.rs @@ -57,7 +57,7 @@ fn spellchecker_works() -> Result<(), Error> { let mut rules = Rules::new(RULES_PATH, TOKENIZER.clone()).unwrap(); rules.spell_mut().options_mut().variant = Some(rules.spell().variant("en_GB")?); - println!("{:?}", rules.suggest("Unicode punctuation: —")); + assert_eq!(rules.correct("color spellhceking"), "colour spellchecking"); Ok(()) } diff --git a/python/src/lib.rs b/python/src/lib.rs index c0aba6d..fd1b905 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -806,8 +806,9 @@ impl PyRules { }) } - /// Convenience method to apply suggestions to the given text. - /// Always uses the first element of `suggestion.replacements` as replacement. + /// Correct a text by applying suggestions to it. + /// - In case of multiple possible replacements, always chooses the first one. + /// - In case of a suggestion without any replacements, ignores the suggestion. /// /// Arguments: /// text (str): The input text. 
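The corrected `apply_suggestions` works on char indices and keeps a running offset so later spans stay valid after earlier replacements change the text length; suggestions without replacements are skipped instead of panicking on `replacements[0]`. A standalone sketch of that mechanism, with a `Fix` struct standing in for `Suggestion`:

```rust
// Stand-in for `Suggestion`: spans are char indices into the original text.
struct Fix {
    start: usize,
    end: usize,
    replacements: Vec<String>,
}

fn apply(text: &str, fixes: &[Fix]) -> String {
    let mut offset: isize = 0;
    let mut chars: Vec<char> = text.chars().collect();

    for fix in fixes {
        // a fix without replacements is skipped, mirroring the patched behaviour
        if let Some(replacement) = fix.replacements.get(0) {
            let replacement: Vec<char> = replacement.chars().collect();
            let start = (fix.start as isize + offset) as usize;
            let end = (fix.end as isize + offset) as usize;

            chars.splice(start..end, replacement.iter().cloned());
            offset += replacement.len() as isize - (fix.end - fix.start) as isize;
        }
    }

    chars.into_iter().collect()
}

fn main() {
    let fixes = [
        // "Shee" -> "She"
        Fix { start: 0, end: 4, replacements: vec!["She".into()] },
        // "not been" -> "not"; the span refers to the *original* text
        Fix { start: 9, end: 17, replacements: vec!["not".into()] },
        // no replacements: this one is ignored
        Fix { start: 18, end: 22, replacements: vec![] },
    ];

    assert_eq!(apply("Shee was not been here.", &fixes), "She was not here.");
}
```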
From b0348bbe0109bf6059a36daa282358b849fa032e Mon Sep 17 00:00:00 2001 From: Benjamin Minixhofer Date: Tue, 16 Mar 2021 21:22:16 +0100 Subject: [PATCH 16/16] fix python ci --- python/test.py | 3 ++- scripts/maturin.sh | 10 +++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/python/test.py b/python/test.py index bc397c6..799216d 100644 --- a/python/test.py +++ b/python/test.py @@ -157,4 +157,5 @@ def test_spell_options_can_be_set(tokenizer_and_rules): def test_spellchecker_works(tokenizer_and_rules): (tokenizer, rules) = tokenizer_and_rules - print(rules.spell.search("lämp")) \ No newline at end of file + # TODO + # print(rules.spell.search("lämp")) \ No newline at end of file diff --git a/scripts/maturin.sh b/scripts/maturin.sh index 348c684..cc322dd 100755 --- a/scripts/maturin.sh +++ b/scripts/maturin.sh @@ -23,7 +23,12 @@ build_change Cargo.toml cd python -trap cleanup INT +trap ctrl_c INT + +function ctrl_c() { + cleanup + exit +} function cleanup() { # this is a bit hacky, assume we are in python/ dir @@ -33,12 +38,11 @@ function cleanup() { mv nlprule/.Cargo.toml.bak nlprule/Cargo.toml mv build/.Cargo.toml.bak build/Cargo.toml mv .Cargo.toml.bak Cargo.toml - exit } maturin $@ exit_code=$? -cleanup() +cleanup exit $exit_code \ No newline at end of file