Skip to content

Commit

Permalink
[SPARK-4717][MLlib] Optimize BLAS library to avoid de-reference multi…
Browse files Browse the repository at this point in the history
…ple times in loop

Keep local references to the `values` and `indices` arrays of the `Vector` object
so the JVM can locate each value with a single operation inside the loop. See `SPARK-4581`
for a similar optimization and the bytecode analysis.

Author: DB Tsai <dbtsai@alpinenow.com>

Closes apache#3577 from dbtsai/blasopt and squashes the following commits:

62d38c4 [DB Tsai] formatting
0316cef [DB Tsai] first commit
  • Loading branch information
DB Tsai authored and mengxr committed Dec 3, 2014
1 parent 7fc49ed commit d005429
Showing 1 changed file with 60 additions and 39 deletions.
99 changes: 60 additions & 39 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,21 @@ private[spark] object BLAS extends Serializable with Logging {
* y += a * x
*/
private def axpy(a: Double, x: SparseVector, y: DenseVector): Unit = {
val nnz = x.indices.size
val xValues = x.values
val xIndices = x.indices
val yValues = y.values
val nnz = xIndices.size

if (a == 1.0) {
var k = 0
while (k < nnz) {
y.values(x.indices(k)) += x.values(k)
yValues(xIndices(k)) += xValues(k)
k += 1
}
} else {
var k = 0
while (k < nnz) {
y.values(x.indices(k)) += a * x.values(k)
yValues(xIndices(k)) += a * xValues(k)
k += 1
}
}
Expand Down Expand Up @@ -119,11 +123,15 @@ private[spark] object BLAS extends Serializable with Logging {
* dot(x, y)
*/
private def dot(x: SparseVector, y: DenseVector): Double = {
val nnz = x.indices.size
val xValues = x.values
val xIndices = x.indices
val yValues = y.values
val nnz = xIndices.size

var sum = 0.0
var k = 0
while (k < nnz) {
sum += x.values(k) * y.values(x.indices(k))
sum += xValues(k) * yValues(xIndices(k))
k += 1
}
sum
Expand All @@ -133,19 +141,24 @@ private[spark] object BLAS extends Serializable with Logging {
* dot(x, y)
*/
private def dot(x: SparseVector, y: SparseVector): Double = {
val xValues = x.values
val xIndices = x.indices
val yValues = y.values
val yIndices = y.indices
val nnzx = xIndices.size
val nnzy = yIndices.size

var kx = 0
val nnzx = x.indices.size
var ky = 0
val nnzy = y.indices.size
var sum = 0.0
// y catching x
while (kx < nnzx && ky < nnzy) {
val ix = x.indices(kx)
while (ky < nnzy && y.indices(ky) < ix) {
val ix = xIndices(kx)
while (ky < nnzy && yIndices(ky) < ix) {
ky += 1
}
if (ky < nnzy && y.indices(ky) == ix) {
sum += x.values(kx) * y.values(ky)
if (ky < nnzy && yIndices(ky) == ix) {
sum += xValues(kx) * yValues(ky)
ky += 1
}
kx += 1
Expand All @@ -163,21 +176,25 @@ private[spark] object BLAS extends Serializable with Logging {
case dy: DenseVector =>
x match {
case sx: SparseVector =>
val sxIndices = sx.indices
val sxValues = sx.values
val dyValues = dy.values
val nnz = sxIndices.size

var i = 0
var k = 0
val nnz = sx.indices.size
while (k < nnz) {
val j = sx.indices(k)
val j = sxIndices(k)
while (i < j) {
dy.values(i) = 0.0
dyValues(i) = 0.0
i += 1
}
dy.values(i) = sx.values(k)
dyValues(i) = sxValues(k)
i += 1
k += 1
}
while (i < n) {
dy.values(i) = 0.0
dyValues(i) = 0.0
i += 1
}
case dx: DenseVector =>
Expand Down Expand Up @@ -311,6 +328,8 @@ private[spark] object BLAS extends Serializable with Logging {
s"The columns of C don't match the columns of B. C: ${C.numCols}, A: $nB")

val Avals = A.values
val Bvals = B.values
val Cvals = C.values
val Arows = if (!transA) A.rowIndices else A.colPtrs
val Acols = if (!transA) A.colPtrs else A.rowIndices

Expand All @@ -327,11 +346,11 @@ private[spark] object BLAS extends Serializable with Logging {
val indEnd = Arows(rowCounterForA + 1)
var sum = 0.0
while (i < indEnd) {
sum += Avals(i) * B.values(Bstart + Acols(i))
sum += Avals(i) * Bvals(Bstart + Acols(i))
i += 1
}
val Cindex = Cstart + rowCounterForA
C.values(Cindex) = beta * C.values(Cindex) + sum * alpha
Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha
rowCounterForA += 1
}
colCounterForB += 1
Expand All @@ -349,15 +368,15 @@ private[spark] object BLAS extends Serializable with Logging {
i += 1
}
val Cindex = Cstart + rowCounter
C.values(Cindex) = beta * C.values(Cindex) + sum * alpha
Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha
rowCounter += 1
}
colCounterForB += 1
}
}
} else {
// Scale matrix first if `beta` is not equal to 0.0
if (beta != 0.0){
if (beta != 0.0) {
f2jBLAS.dscal(C.values.length, beta, C.values, 1)
}
// Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of
Expand All @@ -371,9 +390,9 @@ private[spark] object BLAS extends Serializable with Logging {
while (colCounterForA < kA) {
var i = Acols(colCounterForA)
val indEnd = Acols(colCounterForA + 1)
val Bval = B.values(Bstart + colCounterForA) * alpha
while (i < indEnd){
C.values(Cstart + Arows(i)) += Avals(i) * Bval
val Bval = Bvals(Bstart + colCounterForA) * alpha
while (i < indEnd) {
Cvals(Cstart + Arows(i)) += Avals(i) * Bval
i += 1
}
colCounterForA += 1
Expand All @@ -384,12 +403,12 @@ private[spark] object BLAS extends Serializable with Logging {
while (colCounterForB < nB) {
var colCounterForA = 0 // The column of A to multiply with the row of B
val Cstart = colCounterForB * mA
while (colCounterForA < kA){
while (colCounterForA < kA) {
var i = Acols(colCounterForA)
val indEnd = Acols(colCounterForA + 1)
val Bval = B(colCounterForB, colCounterForA) * alpha
while (i < indEnd){
C.values(Cstart + Arows(i)) += Avals(i) * Bval
while (i < indEnd) {
Cvals(Cstart + Arows(i)) += Avals(i) * Bval
i += 1
}
colCounterForA += 1
Expand Down Expand Up @@ -484,41 +503,43 @@ private[spark] object BLAS extends Serializable with Logging {
beta: Double,
y: DenseVector): Unit = {

val mA: Int = if(!trans) A.numRows else A.numCols
val nA: Int = if(!trans) A.numCols else A.numRows
val xValues = x.values
val yValues = y.values

val mA: Int = if (!trans) A.numRows else A.numCols
val nA: Int = if (!trans) A.numCols else A.numRows

val Avals = A.values
val Arows = if (!trans) A.rowIndices else A.colPtrs
val Acols = if (!trans) A.colPtrs else A.rowIndices

// Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices
if (trans){
if (trans) {
var rowCounter = 0
while (rowCounter < mA){
while (rowCounter < mA) {
var i = Arows(rowCounter)
val indEnd = Arows(rowCounter + 1)
var sum = 0.0
while(i < indEnd){
sum += Avals(i) * x.values(Acols(i))
while (i < indEnd) {
sum += Avals(i) * xValues(Acols(i))
i += 1
}
y.values(rowCounter) = beta * y.values(rowCounter) + sum * alpha
yValues(rowCounter) = beta * yValues(rowCounter) + sum * alpha
rowCounter += 1
}
} else {
// Scale vector first if `beta` is not equal to 0.0
if (beta != 0.0){
if (beta != 0.0) {
scal(beta, y)
}
// Perform matrix-vector multiplication and add to y
var colCounterForA = 0
while (colCounterForA < nA){
while (colCounterForA < nA) {
var i = Acols(colCounterForA)
val indEnd = Acols(colCounterForA + 1)
val xVal = x.values(colCounterForA) * alpha
while (i < indEnd){
val xVal = xValues(colCounterForA) * alpha
while (i < indEnd) {
val rowIndex = Arows(i)
y.values(rowIndex) += Avals(i) * xVal
yValues(rowIndex) += Avals(i) * xVal
i += 1
}
colCounterForA += 1
Expand Down

0 comments on commit d005429

Please sign in to comment.