From d00542987ed80635782dcc826fc0bdbf434fff10 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 3 Dec 2014 22:31:39 +0800 Subject: [PATCH] [SPARK-4717][MLlib] Optimize BLAS library to avoid de-reference multiple times in loop Have a local reference to `values` and `indices` array in the `Vector` object so JVM can locate the value with one operation call. See `SPARK-4581` for similar optimization, and the bytecode analysis. Author: DB Tsai Closes #3577 from dbtsai/blasopt and squashes the following commits: 62d38c4 [DB Tsai] formating 0316cef [DB Tsai] first commit --- .../org/apache/spark/mllib/linalg/BLAS.scala | 99 +++++++++++-------- 1 file changed, 60 insertions(+), 39 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 89539e600f48c..8c4c9c6cf6ae2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -72,17 +72,21 @@ private[spark] object BLAS extends Serializable with Logging { * y += a * x */ private def axpy(a: Double, x: SparseVector, y: DenseVector): Unit = { - val nnz = x.indices.size + val xValues = x.values + val xIndices = x.indices + val yValues = y.values + val nnz = xIndices.size + if (a == 1.0) { var k = 0 while (k < nnz) { - y.values(x.indices(k)) += x.values(k) + yValues(xIndices(k)) += xValues(k) k += 1 } } else { var k = 0 while (k < nnz) { - y.values(x.indices(k)) += a * x.values(k) + yValues(xIndices(k)) += a * xValues(k) k += 1 } } @@ -119,11 +123,15 @@ private[spark] object BLAS extends Serializable with Logging { * dot(x, y) */ private def dot(x: SparseVector, y: DenseVector): Double = { - val nnz = x.indices.size + val xValues = x.values + val xIndices = x.indices + val yValues = y.values + val nnz = xIndices.size + var sum = 0.0 var k = 0 while (k < nnz) { - sum += x.values(k) * y.values(x.indices(k)) + sum += xValues(k) * yValues(xIndices(k)) k += 1 } sum @@ -133,19 +141,24 @@ private[spark] object BLAS extends Serializable with Logging { * dot(x, y) */ private def dot(x: SparseVector, y: SparseVector): Double = { + val xValues = x.values + val xIndices = x.indices + val yValues = y.values + val yIndices = y.indices + val nnzx = xIndices.size + val nnzy = yIndices.size + var kx = 0 - val nnzx = x.indices.size var ky = 0 - val nnzy = y.indices.size var sum = 0.0 // y catching x while (kx < nnzx && ky < nnzy) { - val ix = x.indices(kx) - while (ky < nnzy && y.indices(ky) < ix) { + val ix = xIndices(kx) + while (ky < nnzy && yIndices(ky) < ix) { ky += 1 } - if (ky < nnzy && y.indices(ky) == ix) { - sum += x.values(kx) * y.values(ky) + if (ky < nnzy && yIndices(ky) == ix) { + sum += xValues(kx) * yValues(ky) ky += 1 } kx += 1 @@ -163,21 +176,25 @@ private[spark] object BLAS extends Serializable with Logging { case dy: DenseVector => x match { case sx: SparseVector => + val sxIndices = sx.indices + val sxValues = sx.values + val dyValues = dy.values + val nnz = sxIndices.size + var i = 0 var k = 0 - val nnz = sx.indices.size while (k < nnz) { - val j = sx.indices(k) + val j = sxIndices(k) while (i < j) { - dy.values(i) = 0.0 + dyValues(i) = 0.0 i += 1 } - dy.values(i) = sx.values(k) + dyValues(i) = sxValues(k) i += 1 k += 1 } while (i < n) { - dy.values(i) = 0.0 + dyValues(i) = 0.0 i += 1 } case dx: DenseVector => @@ -311,6 +328,8 @@ private[spark] object BLAS extends Serializable with Logging { s"The columns of C don't match the columns of B. C: ${C.numCols}, A: $nB") val Avals = A.values + val Bvals = B.values + val Cvals = C.values val Arows = if (!transA) A.rowIndices else A.colPtrs val Acols = if (!transA) A.colPtrs else A.rowIndices @@ -327,11 +346,11 @@ private[spark] object BLAS extends Serializable with Logging { val indEnd = Arows(rowCounterForA + 1) var sum = 0.0 while (i < indEnd) { - sum += Avals(i) * B.values(Bstart + Acols(i)) + sum += Avals(i) * Bvals(Bstart + Acols(i)) i += 1 } val Cindex = Cstart + rowCounterForA - C.values(Cindex) = beta * C.values(Cindex) + sum * alpha + Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha rowCounterForA += 1 } colCounterForB += 1 @@ -349,7 +368,7 @@ private[spark] object BLAS extends Serializable with Logging { i += 1 } val Cindex = Cstart + rowCounter - C.values(Cindex) = beta * C.values(Cindex) + sum * alpha + Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha rowCounter += 1 } colCounterForB += 1 @@ -357,7 +376,7 @@ private[spark] object BLAS extends Serializable with Logging { } } else { // Scale matrix first if `beta` is not equal to 0.0 - if (beta != 0.0){ + if (beta != 0.0) { f2jBLAS.dscal(C.values.length, beta, C.values, 1) } // Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of @@ -371,9 +390,9 @@ private[spark] object BLAS extends Serializable with Logging { while (colCounterForA < kA) { var i = Acols(colCounterForA) val indEnd = Acols(colCounterForA + 1) - val Bval = B.values(Bstart + colCounterForA) * alpha - while (i < indEnd){ - C.values(Cstart + Arows(i)) += Avals(i) * Bval + val Bval = Bvals(Bstart + colCounterForA) * alpha + while (i < indEnd) { + Cvals(Cstart + Arows(i)) += Avals(i) * Bval i += 1 } colCounterForA += 1 @@ -384,12 +403,12 @@ private[spark] object BLAS extends Serializable with Logging { while (colCounterForB < nB) { var colCounterForA = 0 // The column of A to multiply with the row of B val Cstart = colCounterForB * mA - while (colCounterForA < kA){ + while (colCounterForA < kA) { var i = Acols(colCounterForA) val indEnd = Acols(colCounterForA + 1) val Bval = B(colCounterForB, colCounterForA) * alpha - while (i < indEnd){ - C.values(Cstart + Arows(i)) += Avals(i) * Bval + while (i < indEnd) { + Cvals(Cstart + Arows(i)) += Avals(i) * Bval i += 1 } colCounterForA += 1 @@ -484,41 +503,43 @@ private[spark] object BLAS extends Serializable with Logging { beta: Double, y: DenseVector): Unit = { - val mA: Int = if(!trans) A.numRows else A.numCols - val nA: Int = if(!trans) A.numCols else A.numRows + val xValues = x.values + val yValues = y.values + + val mA: Int = if (!trans) A.numRows else A.numCols + val nA: Int = if (!trans) A.numCols else A.numRows val Avals = A.values val Arows = if (!trans) A.rowIndices else A.colPtrs val Acols = if (!trans) A.colPtrs else A.rowIndices - // Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices - if (trans){ + if (trans) { var rowCounter = 0 - while (rowCounter < mA){ + while (rowCounter < mA) { var i = Arows(rowCounter) val indEnd = Arows(rowCounter + 1) var sum = 0.0 - while(i < indEnd){ - sum += Avals(i) * x.values(Acols(i)) + while (i < indEnd) { + sum += Avals(i) * xValues(Acols(i)) i += 1 } - y.values(rowCounter) = beta * y.values(rowCounter) + sum * alpha + yValues(rowCounter) = beta * yValues(rowCounter) + sum * alpha rowCounter += 1 } } else { // Scale vector first if `beta` is not equal to 0.0 - if (beta != 0.0){ + if (beta != 0.0) { scal(beta, y) } // Perform matrix-vector multiplication and add to y var colCounterForA = 0 - while (colCounterForA < nA){ + while (colCounterForA < nA) { var i = Acols(colCounterForA) val indEnd = Acols(colCounterForA + 1) - val xVal = x.values(colCounterForA) * alpha - while (i < indEnd){ + val xVal = xValues(colCounterForA) * alpha + while (i < indEnd) { val rowIndex = Arows(i) - y.values(rowIndex) += Avals(i) * xVal + yValues(rowIndex) += Avals(i) * xVal i += 1 } colCounterForA += 1