Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
init
  • Loading branch information
zhengruifeng committed Jul 16, 2024
1 parent 23080ac commit cacf9d6
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,28 @@ class RegexTokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)

setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+", toLowercase -> true)

override protected def createTransformFunc: String => Seq[String] = { originStr =>
override protected def createTransformFunc: String => Seq[String] = {
val re = $(pattern).r
// scalastyle:off caselocale
val str = if ($(toLowercase)) originStr.toLowerCase() else originStr
// scalastyle:on caselocale
val tokens = if ($(gaps)) re.split(str).toImmutableArraySeq else re.findAllIn(str).toSeq
val minLength = $(minTokenLength)
tokens.filter(_.length >= minLength)
val localMinTokenLength = $(minTokenLength)

($(toLowercase), $(gaps)) match {
case (true, true) =>
(originStr: String) =>
re.split(originStr.toLowerCase()).toImmutableArraySeq
.filter(_.length >= localMinTokenLength)

case (true, false) =>
(originStr: String) => re.findAllIn(originStr.toLowerCase()).toSeq
.filter(_.length >= localMinTokenLength)

case (false, true) =>
(originStr: String) => re.split(originStr).toImmutableArraySeq
.filter(_.length >= localMinTokenLength)

case (false, false) =>
(originStr: String) => re.findAllIn(originStr).toSeq
.filter(_.length >= localMinTokenLength)
}
}

override protected def validateInputType(inputType: DataType): Unit = {
Expand Down

0 comments on commit cacf9d6

Please sign in to comment.