From 3aae4f0897050aac0792724d52406d7dd9e1070e Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 16 Jul 2024 08:32:32 +0800 Subject: [PATCH] address comments: collapse RegexTokenizer.createTransformFunc into a single closure --- .../apache/spark/ml/feature/Tokenizer.scala | 25 ++++++------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 5022ad12d9bc5..1acbfd781820f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -143,25 +143,16 @@ class RegexTokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) override protected def createTransformFunc: String => Seq[String] = { val re = $(pattern).r + val localToLowercase = $(toLowercase) + val localGaps = $(gaps) val localMinTokenLength = $(minTokenLength) - ($(toLowercase), $(gaps)) match { - case (true, true) => - (originStr: String) => - re.split(originStr.toLowerCase()).toImmutableArraySeq - .filter(_.length >= localMinTokenLength) - - case (true, false) => - (originStr: String) => re.findAllIn(originStr.toLowerCase()).toSeq - .filter(_.length >= localMinTokenLength) - - case (false, true) => - (originStr: String) => re.split(originStr).toImmutableArraySeq - .filter(_.length >= localMinTokenLength) - - case (false, false) => - (originStr: String) => re.findAllIn(originStr).toSeq - .filter(_.length >= localMinTokenLength) + (originStr: String) => { + // scalastyle:off caselocale + val str = if (localToLowercase) originStr.toLowerCase() else originStr + // scalastyle:on caselocale + val tokens = if (localGaps) re.split(str).toImmutableArraySeq else re.findAllIn(str).toSeq + tokens.filter(_.length >= localMinTokenLength) } }