[SPARK-33882][ML] Add a vectorized BLAS implementation #30810

Closed. Wants to merge 20 commits; showing changes from 16 commits.
35 changes: 35 additions & 0 deletions mllib-local/pom.xml
@@ -68,6 +68,13 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<type>test-jar</type>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>
<profile>
@@ -81,6 +88,34 @@
</dependency>
</dependencies>
</profile>
<profile>
<id>vectorized</id>
<properties>
<extra.source.dir>src/vectorized/java</extra.source.dir>
</properties>
<build>
<plugins>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>build-helper-maven-plugin</artifactId>
<executions>
<execution>
<id>add-vectorized-sources</id>
<phase>generate-sources</phase>
<goals>
<goal>add-source</goal>
</goals>
<configuration>
<sources>
<source>${extra.source.dir}</source>
</sources>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</profile>
</profiles>
<build>
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
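The `vectorized` profile above only adds `src/vectorized/java` to the compile path when explicitly enabled. As a sketch, a build that compiles the vectorized sources might look like this (the `-Pvectorized` flag comes from the profile id in the diff; the module selection and the JDK requirement are assumptions, since the Vector API sources themselves are not shown in this hunk):

```shell
# Hypothetical: build mllib-local with the vectorized BLAS sources included.
# Assumes a JDK that ships the incubating Vector API (JDK 16+).
./build/mvn -Pvectorized -pl mllib-local -am -DskipTests package
```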
51 changes: 33 additions & 18 deletions mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala
@@ -18,28 +18,51 @@
package org.apache.spark.ml.linalg

import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS}
import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS}

/**
* BLAS routines for MLlib's vectors and matrices.
*/
private[spark] object BLAS extends Serializable {

@transient private var _f2jBLAS: NetlibBLAS = _
@transient private var _javaBLAS: NetlibBLAS = _
@transient private var _nativeBLAS: NetlibBLAS = _
private val nativeL1Threshold: Int = 256
Contributor: According to the performance tests, I think we can increase nativeL1Threshold to 512?

Author: I would go even as far as using nativeBLAS exclusively for level-3 operations, and never for level-1 and level-2. The cost of copying the data from managed memory to native memory (necessary to pass the array to native code) is too great relative to the small speed-up of native for the level-1 and level-2 routines.

Contributor: netlib-java does not copy memory when using the native backend; it uses memory pinning (which has its own problems). Please provide benchmarks to show any degradation.

Contributor: On the "small speed-up of native for the level-1 and level-2 routines": I think you need to do some more analysis on this. Native can be 10x faster than the JVM for reasonably sized matrices. However, as shown in https://github.com/fommil/matrix-toolkits-java, the EJML and commons-math projects are faster for matrices of 10x10 or smaller. If you want to heavily optimise for those use cases, then swap to using EJML, which is heavily optimised for that use case (not just "something on the JVM").

// For level-1 function dspmv, use f2jBLAS for better performance.
private[ml] def f2jBLAS: NetlibBLAS = {
if (_f2jBLAS == null) {
_f2jBLAS = new F2jBLAS
// For level-1 function dspmv, use javaBLAS for better performance.
private[ml] def javaBLAS: NetlibBLAS = {
if (_javaBLAS == null) {
_javaBLAS =
try {
// scalastyle:off classforname
Class.forName("org.apache.spark.ml.linalg.VectorizedBLAS", true,
Option(Thread.currentThread().getContextClassLoader)
.getOrElse(getClass.getClassLoader))
.newInstance()
.asInstanceOf[NetlibBLAS]
// scalastyle:on classforname
} catch {
case _: Throwable => new F2jBLAS
}
}
_javaBLAS
}
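The `javaBLAS` accessor above loads `VectorizedBLAS` reflectively so that the class may be absent at runtime (it only exists when the `vectorized` profile is built), falling back to `F2jBLAS` otherwise. A minimal, self-contained sketch of that load-or-fallback pattern (the class `OptionalImpl` and its method name are hypothetical, not part of this PR):

```java
import java.util.function.Supplier;

public class OptionalImpl {
    // Hypothetical helper mirroring the pattern in javaBLAS: try to
    // instantiate an optional implementation by name, and fall back to a
    // default when the class is missing or fails to load.
    public static <T> T loadOrFallback(String className, Class<T> iface, Supplier<T> fallback) {
        try {
            ClassLoader cl = Thread.currentThread().getContextClassLoader();
            if (cl == null) {
                cl = OptionalImpl.class.getClassLoader();
            }
            return iface.cast(Class.forName(className, true, cl)
                    .getDeclaredConstructor().newInstance());
        } catch (Throwable t) {
            return fallback.get();
        }
    }
}
```

Catching `Throwable` (as the diff does) covers both `ClassNotFoundException` and linkage errors, e.g. when the class was compiled against a Vector API that the running JDK does not provide.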

// For level-3 routines, we use the native BLAS.
private[ml] def nativeBLAS: NetlibBLAS = {
if (_nativeBLAS == null) {
_nativeBLAS =
if (NetlibBLAS.getInstance.isInstanceOf[F2jBLAS]) {
javaBLAS
} else {
NetlibBLAS.getInstance
}
}
_f2jBLAS
_nativeBLAS
}

private[ml] def getBLAS(vectorSize: Int): NetlibBLAS = {
if (vectorSize < nativeL1Threshold) {
f2jBLAS
javaBLAS
} else {
nativeBLAS
}
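`getBLAS` routes level-1 calls by vector size: below `nativeL1Threshold` (256 in this revision) the JVM implementation is preferred because per-call native overhead dominates on small arrays. A toy sketch of the dispatch (labels stand in for the actual `NetlibBLAS` instances; `BlasDispatch` is hypothetical):

```java
public class BlasDispatch {
    // Threshold taken from this revision of the diff; the review thread
    // above discusses raising it to 512.
    static final int NATIVE_L1_THRESHOLD = 256;

    // Returns a label standing in for the NetlibBLAS instance that the
    // real getBLAS(vectorSize) would return.
    static String getBLAS(int vectorSize) {
        return vectorSize < NATIVE_L1_THRESHOLD ? "javaBLAS" : "nativeBLAS";
    }
}
```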
@@ -235,14 +258,6 @@ private[spark] object BLAS extends Serializable {
}
}

// For level-3 routines, we use the native BLAS.
private[ml] def nativeBLAS: NetlibBLAS = {
if (_nativeBLAS == null) {
_nativeBLAS = NativeBLAS
}
_nativeBLAS
}

/**
* Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's ?SPR.
*
@@ -267,7 +282,7 @@
x: DenseVector,
beta: Double,
y: DenseVector): Unit = {
f2jBLAS.dspmv("U", n, alpha, A.values, x.values, 1, beta, y.values, 1)
javaBLAS.dspmv("U", n, alpha, A.values, x.values, 1, beta, y.values, 1)
}
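`dspmv` computes `y := alpha * A * x + beta * y` for a symmetric matrix `A` stored packed by its upper triangle. A plain-Java sketch of the semantics, for illustration only (not the netlib code path):

```java
public class PackedSymmetric {
    // Sketch of dspmv("U", n, alpha, a, x, 1, beta, y, 1):
    // y := alpha * A * x + beta * y, where the n x n symmetric matrix A is
    // packed column-major by its upper triangle:
    // a[i + j*(j+1)/2] = A(i, j) for i <= j.
    static void dspmv(int n, double alpha, double[] a, double[] x, double beta, double[] y) {
        for (int i = 0; i < n; i++) {
            y[i] *= beta;
        }
        for (int j = 0; j < n; j++) {
            for (int i = 0; i <= j; i++) {
                double aij = a[i + j * (j + 1) / 2];
                y[i] += alpha * aij * x[j];
                if (i != j) {
                    y[j] += alpha * aij * x[i]; // symmetric counterpart A(j, i)
                }
            }
        }
    }
}
```

For example, with `A = [[1, 2], [2, 3]]` packed as `{1, 2, 3}`, `x = {1, 1}`, `alpha = 1`, and `beta = 0`, the result is `y = {3, 5}`.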

/**
@@ -279,7 +294,7 @@
val n = v.size
v match {
case DenseVector(values) =>
NativeBLAS.dspr("U", n, alpha, values, 1, U)
nativeBLAS.dspr("U", n, alpha, values, 1, U)
case SparseVector(size, indices, values) =>
val nnz = indices.length
var colStartIdx = 0
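The dense branch above calls `dspr`, the symmetric rank-1 update `U := alpha * x * x^T + U`, on the same packed-upper layout. A plain-Java sketch for illustration:

```java
public class PackedRank1 {
    // Sketch of dspr("U", n, alpha, x, 1, u):
    // U := alpha * x * x^T + U, with U packed column-major by its upper
    // triangle: u[i + j*(j+1)/2] = U(i, j) for i <= j.
    static void dspr(int n, double alpha, double[] x, double[] u) {
        for (int j = 0; j < n; j++) {
            for (int i = 0; i <= j; i++) {
                u[i + j * (j + 1) / 2] += alpha * x[i] * x[j];
            }
        }
    }
}
```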
@@ -22,7 +22,6 @@ import java.util.{Arrays, Random}
import scala.collection.mutable.{ArrayBuffer, ArrayBuilder => MArrayBuilder, HashSet => MHashSet}

import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM}
import com.github.fommil.netlib.BLAS.{getInstance => blas}

import org.apache.spark.annotation.Since

@@ -457,7 +456,7 @@ class DenseMatrix @Since("2.0.0") (
if (isTransposed) {
Iterator.tabulate(numCols) { j =>
val col = new Array[Double](numRows)
blas.dcopy(numRows, values, j, numCols, col, 0, 1)
BLAS.nativeBLAS.dcopy(numRows, values, j, numCols, col, 0, 1)
new DenseVector(col)
}
} else {
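The changed line above gathers column `j` of a row-major (`isTransposed`) matrix with a strided `dcopy`: source offset `j`, source stride `numCols`, unit destination stride. A plain-Java sketch of that strided copy:

```java
public class StridedCopy {
    // Sketch of dcopy(n, src, srcOff, srcInc, dst, dstOff, dstInc):
    // copies n elements with independent strides on each side.
    static void dcopy(int n, double[] src, int srcOff, int srcInc,
                      double[] dst, int dstOff, int dstInc) {
        for (int k = 0; k < n; k++) {
            dst[dstOff + k * dstInc] = src[srcOff + k * srcInc];
        }
    }
}
```

For a 2x3 matrix stored row-major as `{1, 2, 3, 4, 5, 6}`, column 1 is gathered with `dcopy(2, values, 1, 3, col, 0, 1)`, yielding `{2, 5}`.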