[SPARK-48892][ML] Avoid per-row param read in Tokenizer #47342

Closed
wants to merge 2 commits into apache:master from zhengruifeng:opt_tokenizer

Conversation

zhengruifeng (Contributor)

What changes were proposed in this pull request?

Inspired by #47258, I checked the other ML implementations and found that `Tokenizer` can be optimized in the same way.

Why are the changes needed?

The function `createTransformFunc` builds the UDF used by `UnaryTransformer.transform`:

```scala
val transformUDF = udf(this.createTransformFunc)
```

The existing implementation reads the params for each row.
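
For illustration, here is a minimal, self-contained Scala sketch of the difference; `param` below is a hypothetical stand-in for a `$(...)` param read, not Spark's actual API:

```scala
object ParamReadSketch {
  // Hypothetical stand-in for a Spark ML Param read such as $(minTokenLength);
  // imagine it performs a lookup plus default resolution on every call.
  private val params = Map("minTokenLength" -> 1)
  private def param(name: String): Int = params(name)

  // Per-row read: param(...) is evaluated inside the lambda, once per input row.
  val perRow: String => Seq[String] =
    str => str.split("\\s+").toSeq.filter(_.length >= param("minTokenLength"))

  // Hoisted read: the value is resolved once when the function is built;
  // the returned lambda closes over a plain local instead.
  val hoisted: String => Seq[String] = {
    val minLength = param("minTokenLength")
    str => str.split("\\s+").toSeq.filter(_.length >= minLength)
  }
}
```

Both functions produce the same output; the only difference is how often the param is read.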

Does this PR introduce any user-facing change?

No

How was this patch tested?

CI and manual tests:

create test dataset

```scala
spark.range(1000000).select(uuid().as("uuid")).write.mode("overwrite").parquet("/tmp/regex_tokenizer.parquet")
```

duration

```scala
val df = spark.read.parquet("/tmp/regex_tokenizer.parquet")
import org.apache.spark.ml.feature._
val tokenizer = new RegexTokenizer().setPattern("-").setInputCol("uuid")
Seq.range(0, 1000).foreach(i => tokenizer.transform(df).count()) // warm up
val tic = System.currentTimeMillis; Seq.range(0, 1000).foreach(i => tokenizer.transform(df).count()); System.currentTimeMillis - tic
```

result (before this PR)

```
scala> val tic = System.currentTimeMillis; Seq.range(0, 1000).foreach(i => tokenizer.transform(df).count()); System.currentTimeMillis - tic
val tic: Long = 1720613235068
val res5: Long = 50397
```

result (after this PR)

```
scala> val tic = System.currentTimeMillis; Seq.range(0, 1000).foreach(i => tokenizer.transform(df).count()); System.currentTimeMillis - tic
val tic: Long = 1720612871256
val res5: Long = 43748
```
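
For reference, a quick check of the relative improvement from the two totals above (simple arithmetic, nothing Spark-specific):

```scala
// Wall-clock totals (ms) for the 1000 transform-and-count iterations reported above.
val before = 50397.0
val after = 43748.0
val reduction = (before - after) / before // ≈ 0.132, i.e. roughly 13% less wall-clock time
```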

Was this patch authored or co-authored using generative AI tooling?

No

@srowen srowen (Member) left a comment

Looks fine; it adds some complexity, but not much.

@JoshRosen JoshRosen (Contributor) left a comment

I'll wager that the expensive part was probably the configuration check itself plus the regex compilation, but not the branching (since those would predict well). Therefore I predict that you can get almost the whole speedup if you did something like

```scala
override protected def createTransformFunc: String => Seq[String] = {
  // Resolve the params once when the transform function is created, not once per row.
  // (toImmutableArraySeq comes from org.apache.spark.util.ArrayImplicits._)
  val re = $(pattern).r
  val _toLowercase = $(toLowercase)
  val _gaps = $(gaps)
  val minLength = $(minTokenLength)

  originStr => {
    // scalastyle:off caselocale
    val str = if (_toLowercase) originStr.toLowerCase() else originStr
    // scalastyle:on caselocale
    val tokens = if (_gaps) re.split(str).toImmutableArraySeq else re.findAllIn(str).toSeq
    tokens.filter(_.length >= minLength)
  }
}
```

Basically, I think it might be overkill or unnecessary to fully inline and expand the cross product like this. My suggested approach is easier to understand and probably nearly equivalent in performance.

@zhengruifeng (Contributor, Author)

> I'll wager that the expensive part was probably the configuration check itself plus the regex compilation […] My suggested approach is easier to understand and probably nearly equivalent in performance.

Makes sense; let me simplify the changes.

@zhengruifeng (Contributor, Author)

merged to master

@zhengruifeng zhengruifeng deleted the opt_tokenizer branch July 16, 2024 23:18
jingz-db pushed a commit to jingz-db/spark that referenced this pull request Jul 22, 2024
attilapiros pushed a commit to attilapiros/spark that referenced this pull request Oct 4, 2024