From 08e2f1171b86d90abbd5ae7bda584eb05905a51c Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Wed, 17 Jul 2024 07:18:18 +0800
Subject: [PATCH] [SPARK-48892][ML] Avoid per-row param read in `Tokenizer`

### What changes were proposed in this pull request?
Inspired by https://github.com/apache/spark/pull/47258, I checked other ML implementations and found that `Tokenizer` can be optimized in the same way.

### Why are the changes needed?
The function `createTransformFunc` builds the UDF used by `UnaryTransformer.transform`:
https://github.com/apache/spark/blob/d679dabdd1b5ad04b8c7deb1c06ce886a154a928/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala#L118

The existing implementation reads the params for each row.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI and manual tests:

create test dataset
```
spark.range(1000000).select(uuid().as("uuid")).write.mode("overwrite").parquet("/tmp/regex_tokenizer.parquet")
```

duration
```
val df = spark.read.parquet("/tmp/regex_tokenizer.parquet")

import org.apache.spark.ml.feature._
val tokenizer = new RegexTokenizer().setPattern("-").setInputCol("uuid")

Seq.range(0, 1000).foreach(i => tokenizer.transform(df).count()) // warm up
val tic = System.currentTimeMillis; Seq.range(0, 1000).foreach(i => tokenizer.transform(df).count()); System.currentTimeMillis - tic
```

result (before this PR)
```
scala> val tic = System.currentTimeMillis; Seq.range(0, 1000).foreach(i => tokenizer.transform(df).count()); System.currentTimeMillis - tic
val tic: Long = 1720613235068
val res5: Long = 50397
```

result (after this PR)
```
scala> val tic = System.currentTimeMillis; Seq.range(0, 1000).foreach(i => tokenizer.transform(df).count()); System.currentTimeMillis - tic
val tic: Long = 1720612871256
val res5: Long = 43748
```

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #47342 from zhengruifeng/opt_tokenizer.
Authored-by: Ruifeng Zheng
Signed-off-by: Ruifeng Zheng
---
 .../apache/spark/ml/feature/Tokenizer.scala   | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index e7b3ff76a8d8c..1acbfd781820f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -141,14 +141,19 @@ class RegexTokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
 
   setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+", toLowercase -> true)
 
-  override protected def createTransformFunc: String => Seq[String] = { originStr =>
+  override protected def createTransformFunc: String => Seq[String] = {
     val re = $(pattern).r
-    // scalastyle:off caselocale
-    val str = if ($(toLowercase)) originStr.toLowerCase() else originStr
-    // scalastyle:on caselocale
-    val tokens = if ($(gaps)) re.split(str).toImmutableArraySeq else re.findAllIn(str).toSeq
-    val minLength = $(minTokenLength)
-    tokens.filter(_.length >= minLength)
+    val localToLowercase = $(toLowercase)
+    val localGaps = $(gaps)
+    val localMinTokenLength = $(minTokenLength)
+
+    (originStr: String) => {
+      // scalastyle:off caselocale
+      val str = if (localToLowercase) originStr.toLowerCase() else originStr
+      // scalastyle:on caselocale
+      val tokens = if (localGaps) re.split(str).toImmutableArraySeq else re.findAllIn(str).toSeq
+      tokens.filter(_.length >= localMinTokenLength)
+    }
   }
 
   override protected def validateInputType(inputType: DataType): Unit = {
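
For reference (not part of the patch): a minimal, self-contained sketch of the pattern applied above. The `param` helper below is a hypothetical stand-in for Spark ML's `$(...)` accessor, which does a param-map lookup with default resolution on every call; the sketch only illustrates why hoisting those reads out of the per-row closure helps.

```
object PerRowParamReadSketch {
  // Hypothetical param store; stands in for Spark ML's ParamMap + defaults.
  private val params: Map[String, Any] = Map("minTokenLength" -> 1)
  private def param[T](name: String): T = params(name).asInstanceOf[T]

  // Before: the param is read inside the returned function, i.e. once per row.
  def transformFuncBefore: String => Seq[String] = { str =>
    val minLength = param[Int]("minTokenLength") // read on every call
    str.split("\\s+").toSeq.filter(_.length >= minLength)
  }

  // After: read the param once while building the function, then close over
  // the captured local; the per-row path only touches a plain val.
  def transformFuncAfter: String => Seq[String] = {
    val minLength = param[Int]("minTokenLength") // read once
    (str: String) => str.split("\\s+").toSeq.filter(_.length >= minLength)
  }

  def main(args: Array[String]): Unit = {
    val rows = Seq("spark ml tokenizer", "a bc def")
    println(rows.map(transformFuncBefore)) // List(List(spark, ml, tokenizer), List(a, bc, def))
    println(rows.map(transformFuncAfter))  // same output, params read only once
  }
}
```

In the benchmark above, avoiding the per-row param reads reduces the runtime from 50397 ms to 43748 ms, roughly a 13% improvement.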