diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index e7b3ff76a8d8c..1acbfd781820f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -141,14 +141,19 @@ class RegexTokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
 
   setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+", toLowercase -> true)
 
-  override protected def createTransformFunc: String => Seq[String] = { originStr =>
+  override protected def createTransformFunc: String => Seq[String] = {
     val re = $(pattern).r
-    // scalastyle:off caselocale
-    val str = if ($(toLowercase)) originStr.toLowerCase() else originStr
-    // scalastyle:on caselocale
-    val tokens = if ($(gaps)) re.split(str).toImmutableArraySeq else re.findAllIn(str).toSeq
-    val minLength = $(minTokenLength)
-    tokens.filter(_.length >= minLength)
+    val localToLowercase = $(toLowercase)
+    val localGaps = $(gaps)
+    val localMinTokenLength = $(minTokenLength)
+
+    (originStr: String) => {
+      // scalastyle:off caselocale
+      val str = if (localToLowercase) originStr.toLowerCase() else originStr
+      // scalastyle:on caselocale
+      val tokens = if (localGaps) re.split(str).toImmutableArraySeq else re.findAllIn(str).toSeq
+      tokens.filter(_.length >= localMinTokenLength)
+    }
   }
 
   override protected def validateInputType(inputType: DataType): Unit = {
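The diff hoists all `$(param)` lookups out of the per-record closure. Before the change, the regex was recompiled and `toLowercase`, `gaps`, and `minTokenLength` were re-read through the Params machinery for every input string, and the closure captured the enclosing `RegexTokenizer` (since `$(...)` is a call on `this`). After the change, parameters are read once when `createTransformFunc` builds the function, and the returned closure captures only a handful of local values.

Below is a minimal, self-contained sketch of that closure-hoisting pattern, assuming a hypothetical `Config` case class in place of Spark's ML Params machinery and plain `.toSeq` in place of Spark's `toImmutableArraySeq` helper; it illustrates the idea and is not Spark code:

object TokenizerClosureSketch {

  // Hypothetical config holder standing in for the ML Params machinery.
  final case class Config(
      pattern: String,
      toLowercase: Boolean,
      gaps: Boolean,
      minTokenLength: Int)

  def makeTokenizer(config: Config): String => Seq[String] = {
    // Evaluated once, when the function is constructed: the analogue of the
    // hoisted `val re` / `localToLowercase` / `localGaps` / `localMinTokenLength`.
    val re = config.pattern.r
    val toLowercase = config.toLowercase
    val gaps = config.gaps
    val minTokenLength = config.minTokenLength

    // The returned closure captures only the four locals above; it does not
    // reference `config` (or, in the Spark case, the transformer instance).
    (input: String) => {
      val str = if (toLowercase) input.toLowerCase else input
      val tokens = if (gaps) re.split(str).toSeq else re.findAllIn(str).toSeq
      tokens.filter(_.length >= minTokenLength)
    }
  }

  def main(args: Array[String]): Unit = {
    val tokenize = makeTokenizer(
      Config("\\s+", toLowercase = true, gaps = true, minTokenLength = 2))
    // Prints the tokens of length >= 2: spark, tokenizes, this
    println(tokenize("Spark tokenizes THIS a b").mkString(", "))
  }
}

Built this way, `tokenize` does no per-call parameter lookups or regex compilation, and because it no longer closes over the enclosing object, serializing it for distributed execution does not drag that object along.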