Skip to content

Commit

Permalink
Adding segmenter model option to datagen (#3669)
Browse files Browse the repository at this point in the history
  • Loading branch information
robertbastian authored Jul 22, 2023
1 parent 105e7ec commit 10ab02f
Show file tree
Hide file tree
Showing 10 changed files with 173 additions and 1,003,610 deletions.
18 changes: 18 additions & 0 deletions provider/datagen/src/bin/datagen/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,13 @@ pub struct Cli {

#[arg(short, long, value_enum, default_value_t = Fallback::Legacy)]
fallback: Fallback,

#[arg(long, num_args = 0.., default_value = "recommended")]
#[arg(
help = "Include these segmenter models in the output. Accepts multiple arguments. \
Defaults to 'recommended' for the recommended set of models. Use 'none' for no models"
)]
segmenter_models: Vec<String>,
}

impl Cli {
Expand Down Expand Up @@ -307,6 +314,7 @@ impl Cli {
.iter()
.map(|c| c.to_datagen_value().to_owned())
.collect(),
segmenter_models: self.make_segmenter_models()?,
export: self.make_exporter()?,
fallback: match self.fallback {
Fallback::Legacy => config::FallbackMode::Legacy,
Expand Down Expand Up @@ -405,6 +413,16 @@ impl Cli {
})
}

fn make_segmenter_models(&self) -> eyre::Result<options::SegmenterModelInclude> {
Ok(if self.segmenter_models.as_slice() == ["none"] {
config::SegmenterModelInclude::None
} else if self.segmenter_models.as_slice() == ["recommended"] {
config::SegmenterModelInclude::Recommended
} else {
config::SegmenterModelInclude::Explicit(self.segmenter_models.clone())
})
}

fn make_exporter(&self) -> eyre::Result<config::Export> {
match self.format {
v @ (Format::Dir | Format::DeprecatedDefault) => {
Expand Down
2 changes: 2 additions & 0 deletions provider/datagen/src/bin/datagen/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ pub struct Config {
pub collation_han_database: CollationHanDatabase,
#[serde(default, skip_serializing_if = "is_default")]
pub collations: HashSet<String>,
#[serde(default, skip_serializing_if = "is_default")]
pub segmenter_models: SegmenterModelInclude,
pub export: Export,
#[serde(default, skip_serializing_if = "is_default")]
pub fallback: FallbackMode,
Expand Down
20 changes: 20 additions & 0 deletions provider/datagen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,26 @@ pub fn datagen(
)
})
.unwrap_or(options::LocaleInclude::All),
segmenter_models: match locales {
None => options::SegmenterModelInclude::Recommended,
Some(list) => options::SegmenterModelInclude::Explicit({
let mut models = vec![];
for locale in list {
let locale = locale.into();
if let Some(model) =
transform::segmenter::lstm::data_locale_to_model_name(&locale)
{
models.push(model.into());
}
if let Some(model) =
transform::segmenter::dictionary::data_locale_to_model_name(&locale)
{
models.push(model.into());
}
}
models
}),
},
..source.options.clone()
},
{
Expand Down
42 changes: 42 additions & 0 deletions provider/datagen/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ pub struct Options {
/// The type of fallback that the data should be generated for. If locale fallback is
/// used at runtime, smaller data can be generated.
pub fallback: FallbackMode,
/// The segmentation models to include
pub segmenter_models: SegmenterModelInclude,
}

/// Defines the locales to include
Expand Down Expand Up @@ -138,3 +140,43 @@ impl Default for TrieType {
Self::Small
}
}

#[non_exhaustive]
#[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)]
/// The segmentation models to include
pub enum SegmenterModelInclude {
/// Include the recommended set of models. This will cover all languages supported
/// by ICU4X: Thai, Burmese, Khmer, Lao, Chinese, and Japanese. Both dictionary
/// and LSTM models will be included, to the extent required by the chosen data keys.
Recommended,
/// Include no dictionary or LSTM models. This will make line and word segmenters
/// behave like simple rule-based segmenters, which will be incorrect when handling text
/// that contains Thai, Burmese, Khmer, Lao, Chinese, or Japanese.
None,
/// Include an explicit list of LSTM or dictionary models, to the extent required by the
/// chosen data keys.
///
/// The currently supported dictionary models are
/// * `cjdict`
/// * `burmesedict`
/// * `khmerdict`
/// * `laodict`
/// * `thaidict`
///
/// The currently supported LSTM models are
/// * `Burmese_codepoints_exclusive_model4_heavy`
/// * `Khmer_codepoints_exclusive_model4_heavy`
/// * `Lao_codepoints_exclusive_model4_heavy`
/// * `Thai_codepoints_exclusive_model4_heavy`
///
/// If a model is not included, the resulting line or word segmenter will apply rule-based
/// segmentation when encountering text in a script that requires the model, which will be
/// incorrect.
Explicit(Vec<String>),
}

impl Default for SegmenterModelInclude {
fn default() -> Self {
Self::Recommended
}
}
73 changes: 45 additions & 28 deletions provider/datagen/src/transform/segmenter/dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,46 @@ struct SegmenterDictionaryData {
trie_data: Vec<u16>,
}

fn model_name_to_data_locale(name: &str) -> Option<DataLocale> {
match name {
"khmerdict" => Some(langid!("km").into()),
"cjdict" => Some(langid!("ja").into()),
"laodict" => Some(langid!("lo").into()),
"burmesedict" => Some(langid!("my").into()),
"thaidict" => Some(langid!("th").into()),
_ => None,
}
}

pub(crate) fn data_locale_to_model_name(locale: &DataLocale) -> Option<&'static str> {
match locale.get_langid() {
id if id == langid!("km") => Some("khmerdict"),
id if id == langid!("ja") => Some("cjdict"),
id if id == langid!("lo") => Some("laodict"),
id if id == langid!("my") => Some("burmesedict"),
id if id == langid!("th") => Some("thaidict"),
_ => None,
}
}

impl crate::DatagenProvider {
fn load_dictionary_data(
&self,
req: DataRequest,
) -> Result<UCharDictionaryBreakDataV1<'static>, DataError> {
let filename = if req.locale.get_langid() == langid!("km") {
"segmenter/dictionary/khmerdict.toml"
} else if req.locale.get_langid() == langid!("ja") {
"segmenter/dictionary/cjdict.toml"
} else if req.locale.get_langid() == langid!("lo") {
"segmenter/dictionary/laodict.toml"
} else if req.locale.get_langid() == langid!("my") {
"segmenter/dictionary/burmesedict.toml"
} else if req.locale.get_langid() == langid!("th") {
"segmenter/dictionary/thaidict.toml"
} else {
Err(DataErrorKind::MissingLocale.into_error())?
};
let model = data_locale_to_model_name(req.locale)
.ok_or(DataErrorKind::MissingLocale.into_error())?;

let filename = format!("segmenter/dictionary/{model}.toml");

let toml_data: &SegmenterDictionaryData = self
.source
.icuexport()
.and_then(|e| e.read_and_parse_toml(filename))
.and_then(|e| e.read_and_parse_toml(&filename))
.or_else(|e| {
self.source
.icuexport_fallback()
.read_and_parse_toml(filename)
.read_and_parse_toml(&filename)
.map_err(|_| e)
})?;

Expand All @@ -51,12 +64,9 @@ impl crate::DatagenProvider {
}

macro_rules! implement {
($marker:ident, $($locale:literal),*) => {
($marker:ident, $supported:expr) => {
impl DataProvider<$marker> for crate::DatagenProvider {
fn load(
&self,
req: DataRequest,
) -> Result<DataResponse<$marker>, DataError> {
fn load(&self, req: DataRequest) -> Result<DataResponse<$marker>, DataError> {
self.check_req::<$marker>(req)?;
let data = self.load_dictionary_data(req)?;
Ok(DataResponse {
Expand All @@ -67,19 +77,26 @@ macro_rules! implement {
}

impl IterableDataProvider<$marker> for crate::DatagenProvider {
// TODO(#3408): Do we actually want to filter these by the user-selected locales?
fn supported_locales(&self) -> Result<Vec<DataLocale>, DataError> {
Ok(self.filter_data_locales(vec![$(locale!($locale).into()),*]))
Ok(match &self.source.options.segmenter_models {
crate::options::SegmenterModelInclude::Recommended => $supported
.into_iter()
.filter_map(model_name_to_data_locale)
.collect(),
crate::options::SegmenterModelInclude::None => Vec::new(),
crate::options::SegmenterModelInclude::Explicit(list) => $supported
.into_iter()
.filter(|&model| list.iter().any(|x| x == model))
.filter_map(model_name_to_data_locale)
.collect(),
})
}
}
}
};
}

implement!(DictionaryForWordOnlyAutoV1Marker, "ja");
implement!(DictionaryForWordOnlyAutoV1Marker, ["cjdict"]);
implement!(
DictionaryForWordLineExtendedV1Marker,
"th",
"km",
"lo",
"my"
["khmerdict", "laodict", "burmesedict", "thaidict"]
);
60 changes: 40 additions & 20 deletions provider/datagen/src/transform/segmenter/lstm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,25 +181,35 @@ convert!(ndarray_to_lstm_matrix1, LstmMatrix1, 1);
convert!(ndarray_to_lstm_matrix2, LstmMatrix2, 2);
convert!(ndarray_to_lstm_matrix3, LstmMatrix3, 3);

fn model_name_to_data_locale(name: &str) -> Option<DataLocale> {
match name {
"Burmese_codepoints_exclusive_model4_heavy" => Some(langid!("my").into()),
"Khmer_codepoints_exclusive_model4_heavy" => Some(langid!("km").into()),
"Lao_codepoints_exclusive_model4_heavy" => Some(langid!("lo").into()),
"Thai_codepoints_exclusive_model4_heavy" => Some(langid!("th").into()),
_ => None,
}
}

pub(crate) fn data_locale_to_model_name(locale: &DataLocale) -> Option<&'static str> {
match locale.get_langid() {
id if id == langid!("my") => Some("Burmese_codepoints_exclusive_model4_heavy"),
id if id == langid!("km") => Some("Khmer_codepoints_exclusive_model4_heavy"),
id if id == langid!("lo") => Some("Lao_codepoints_exclusive_model4_heavy"),
id if id == langid!("th") => Some("Thai_codepoints_exclusive_model4_heavy"),
_ => None,
}
}

impl DataProvider<LstmForWordLineAutoV1Marker> for crate::DatagenProvider {
fn load(
&self,
req: DataRequest,
) -> Result<DataResponse<LstmForWordLineAutoV1Marker>, DataError> {
self.check_req::<LstmForWordLineAutoV1Marker>(req)?;
let model = if req.locale.language() == langid!("th").language {
"Thai_codepoints_exclusive_model4_heavy"
} else if req.locale.language() == langid!("my").language {
"Burmese_codepoints_exclusive_model4_heavy"
} else if req.locale.language() == langid!("lo").language {
"Lao_codepoints_exclusive_model4_heavy"
} else if req.locale.language() == langid!("km").language {
"Khmer_codepoints_exclusive_model4_heavy"
} else {
return Err(
DataErrorKind::MissingLocale.with_req(LstmForWordLineAutoV1Marker::KEY, req)
);
};

let model = data_locale_to_model_name(req.locale)
.ok_or(DataErrorKind::MissingLocale.with_req(LstmForWordLineAutoV1Marker::KEY, req))?;

let lstm_data = self
.source
Expand All @@ -218,13 +228,23 @@ impl DataProvider<LstmForWordLineAutoV1Marker> for crate::DatagenProvider {

impl IterableDataProvider<LstmForWordLineAutoV1Marker> for crate::DatagenProvider {
fn supported_locales(&self) -> Result<Vec<DataLocale>, DataError> {
// TODO(#3408): Do we actually want to filter these by the user-selected locales?
Ok(self.filter_data_locales(vec![
langid!("km").into(),
langid!("lo").into(),
langid!("my").into(),
langid!("th").into(),
]))
Ok(match &self.source.options.segmenter_models {
crate::options::SegmenterModelInclude::Recommended => [
"Burmese_codepoints_exclusive_model4_heavy",
"Khmer_codepoints_exclusive_model4_heavy",
"Lao_codepoints_exclusive_model4_heavy",
"Thai_codepoints_exclusive_model4_heavy",
]
.into_iter()
.filter_map(model_name_to_data_locale)
.collect(),
crate::options::SegmenterModelInclude::None => Vec::new(),
crate::options::SegmenterModelInclude::Explicit(list) => list
.iter()
.map(core::ops::Deref::deref)
.filter_map(model_name_to_data_locale)
.collect(),
})
}
}

Expand Down
4 changes: 2 additions & 2 deletions provider/datagen/src/transform/segmenter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use icu_segmenter::symbols::*;
use std::fmt::Debug;
use zerovec::ZeroVec;

mod dictionary;
mod lstm;
pub(crate) mod dictionary;
pub(crate) mod lstm;

// state machine name define by builtin name
// [[tables]]
Expand Down
Loading

0 comments on commit 10ab02f

Please sign in to comment.