[SPARK-48892][ML] Avoid per-row param read in Tokenizer
### What changes were proposed in this pull request?
Inspired by apache#47258, I checked other ML implementations and found that `Tokenizer` can be optimized in the same way.

### Why are the changes needed?
The function `createTransformFunc` builds the UDF used by `UnaryTransformer.transform`:
https://github.com/apache/spark/blob/d679dabdd1b5ad04b8c7deb1c06ce886a154a928/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala#L118

The existing implementation reads the params for every row. This PR reads them once, when the transform function is created, and lets the returned closure capture the resolved values.
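To illustrate the pattern in isolation, here is a minimal standalone sketch (not the actual `Tokenizer` code; `readParam` is a hypothetical stand-in for the `$(...)` param lookup, and the real change is in the diff below):

```scala
object PerRowParamDemo {
  // Hypothetical stand-in for Spark ML's `$(param)` lookup, which resolves
  // the value from the param map (or the default map) on every call.
  private val params = Map("pattern" -> "\\s+")
  def readParam(name: String): String = params(name)

  // Before: the param is resolved again for every row the closure processes.
  val perRowRead: String => Seq[String] =
    (s: String) => s.split(readParam("pattern")).toSeq

  // After: the param is resolved once while building the function,
  // and the returned closure only captures the resolved value.
  val readOnce: String => Seq[String] = {
    val localPattern = readParam("pattern")
    (s: String) => s.split(localPattern).toSeq
  }

  def main(args: Array[String]): Unit = {
    println(perRowRead("a b c")) // tokens: a, b, c
    println(readOnce("a b c"))   // tokens: a, b, c
  }
}
```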

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI and manual tests:

Create the test dataset:
```
spark.range(1000000).select(uuid().as("uuid")).write.mode("overwrite").parquet("/tmp/regex_tokenizer.parquet")
```

Measure the duration:
```
val df = spark.read.parquet("/tmp/regex_tokenizer.parquet")
import org.apache.spark.ml.feature._
val tokenizer = new RegexTokenizer().setPattern("-").setInputCol("uuid")
Seq.range(0, 1000).foreach(i => tokenizer.transform(df).count()) // warm up
val tic = System.currentTimeMillis; Seq.range(0, 1000).foreach(i => tokenizer.transform(df).count()); System.currentTimeMillis - tic
```

Result (before this PR):
```
scala> val tic = System.currentTimeMillis; Seq.range(0, 1000).foreach(i => tokenizer.transform(df).count()); System.currentTimeMillis - tic
val tic: Long = 1720613235068
val res5: Long = 50397
```

Result (after this PR):
```
scala> val tic = System.currentTimeMillis; Seq.range(0, 1000).foreach(i => tokenizer.transform(df).count()); System.currentTimeMillis - tic
val tic: Long = 1720612871256
val res5: Long = 43748
```
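In this benchmark the total time drops from 50397 ms to 43748 ms, roughly a 13% improvement.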

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#47342 from zhengruifeng/opt_tokenizer.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
zhengruifeng authored and attilapiros committed Oct 4, 2024
1 parent 3ec85f9 commit 08e2f11
Showing 1 changed file with 12 additions and 7 deletions.
mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala:

```diff
@@ -141,14 +141,19 @@ class RegexTokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
 
   setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+", toLowercase -> true)
 
-  override protected def createTransformFunc: String => Seq[String] = { originStr =>
+  override protected def createTransformFunc: String => Seq[String] = {
     val re = $(pattern).r
-    // scalastyle:off caselocale
-    val str = if ($(toLowercase)) originStr.toLowerCase() else originStr
-    // scalastyle:on caselocale
-    val tokens = if ($(gaps)) re.split(str).toImmutableArraySeq else re.findAllIn(str).toSeq
-    val minLength = $(minTokenLength)
-    tokens.filter(_.length >= minLength)
+    val localToLowercase = $(toLowercase)
+    val localGaps = $(gaps)
+    val localMinTokenLength = $(minTokenLength)
+
+    (originStr: String) => {
+      // scalastyle:off caselocale
+      val str = if (localToLowercase) originStr.toLowerCase() else originStr
+      // scalastyle:on caselocale
+      val tokens = if (localGaps) re.split(str).toImmutableArraySeq else re.findAllIn(str).toSeq
+      tokens.filter(_.length >= localMinTokenLength)
+    }
   }
 
   override protected def validateInputType(inputType: DataType): Unit = {
```
