Commit b55ac5c
addressed code review v1
brkyvz committed Jan 30, 2015
1 parent 25f083b · commit b55ac5c
Showing 2 changed files with 29 additions and 70 deletions.
mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
@@ -158,11 +158,13 @@ class BlockMatrix(
   private[mllib] var partitioner: GridPartitioner =
     GridPartitioner(numRowBlocks, numColBlocks, suggestedNumPartitions = blocks.partitions.size)
 
+  private lazy val blockInfo = blocks.mapValues(block => (block.numRows, block.numCols)).cache()
+
   /** Estimates the dimensions of the matrix. */
   private def estimateDim(): Unit = {
-    val (rows, cols) = blocks.map { case ((blockRowIndex, blockColIndex), mat) =>
-      (blockRowIndex.toLong * rowsPerBlock + mat.numRows,
-        blockColIndex.toLong * colsPerBlock + mat.numCols)
+    val (rows, cols) = blockInfo.map { case ((blockRowIndex, blockColIndex), (m, n)) =>
+      (blockRowIndex.toLong * rowsPerBlock + m,
+        blockColIndex.toLong * colsPerBlock + n)
     }.reduce { (x0, x1) =>
       (math.max(x0._1, x1._1), math.max(x0._2, x1._2))
     }
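
The cached blockInfo RDD above keeps only the per-block dimensions, so the several passes made below (a countByKey plus a full scan in validate(), and the map/reduce in estimateDim()) reuse small cached pairs instead of re-materializing the matrix blocks. A minimal, self-contained sketch of the same pattern, with made-up data that is not part of this commit:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.SparkContext._

    object BlockInfoSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("blockInfo"))
        // Stand-in for blockInfo: ((blockRowIndex, blockColIndex), (numRows, numCols)),
        // for a matrix split into 2 x 2 blocks with a short last block-row.
        val blockInfo = sc.parallelize(Seq(
          ((0, 0), (2, 2)), ((0, 1), (2, 2)), ((1, 0), (1, 2)), ((1, 1), (1, 2)))).cache()
        // Pass 1: duplicate-index check, as in validate().
        val duplicates = blockInfo.countByKey().filter { case (_, cnt) => cnt > 1 }
        // Pass 2: dimension estimate, as in estimateDim() (rowsPerBlock = colsPerBlock = 2).
        val (rows, cols) = blockInfo.map { case ((i, j), (m, n)) =>
          (i.toLong * 2 + m, j.toLong * 2 + n)
        }.reduce { (x0, x1) => (math.max(x0._1, x1._1), math.max(x0._2, x1._2)) }
        println(s"duplicates: $duplicates, estimated dimensions: $rows x $cols")  // 3 x 4
        sc.stop()
      }
    }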
@@ -172,59 +174,31 @@ class BlockMatrix(
     assert(cols <= nCols, s"The number of columns $cols is more than claimed $nCols.")
   }
 
-  def validate: Unit = {
+  def validate(): Unit = {
     logDebug("Validating BlockMatrix...")
     // check if the matrix is larger than the claimed dimensions
-    try {
-      estimateDim()
-      logDebug("BlockMatrix dimensions are okay...")
-    } catch {
-      case exc: AssertionError => throw new SparkException(s"$exc\nPlease instantiate a " +
-        s"new BlockMatrix with the correct dimensions.")
-      case e: Exception =>
-        throw new SparkException(s"${e.getMessage}\n${e.getStackTraceString}")
-    }
+    estimateDim()
+    logDebug("BlockMatrix dimensions are okay...")
 
     // Check if there are multiple MatrixBlocks with the same index.
-    val indexCounts = blocks.countByKey().filter(p => p._2 > 1)
-    if (indexCounts.size > 50) {
-      throw new SparkException(s"There are ${indexCounts.size} MatrixBlocks with duplicate " +
-        s"indices. Please remove blocks with duplicate indices. You may call reduceByKey on " +
-        s"the underlying RDD and sum the duplicates. You may convert the matrices to Breeze " +
-        s"before summing them up.")
-    } else if (indexCounts.size > 0) {
-      var errorMsg = s"The following indices have more than one Matrix:\n"
-      indexCounts.foreach(index => errorMsg += s"Index: ${index._1}, count: ${index._2}\n")
-      errorMsg += "Please remove these blocks with duplicate indices. You may call " +
-        "reduceByKey on the underlying RDD and sum the duplicates. You may convert the " +
-        "matrices to Breeze before summing them up."
-      throw new SparkException(errorMsg)
-    }
+    blockInfo.countByKey().foreach { case (key, cnt) =>
+      if (cnt > 1) {
+        throw new SparkException(s"There are MatrixBlocks with duplicate indices. Please remove " +
+          s"blocks with duplicate indices. You may call reduceByKey on the underlying RDD and " +
+          s"sum the duplicates. You may convert the matrices to Breeze before summing them up.")
+      }
+    }
     logDebug("MatrixBlock indices are okay...")
     // Check if each MatrixBlock (except edges) has the dimensions rowsPerBlock x colsPerBlock
-    // The first tuple is the index and the second tuple is the dimensions of the MatrixBlock
-    val blockDimensionMismatches = blocks.filter { case ((blockRowIndex, blockColIndex), block) =>
-      if ((blockRowIndex == numRowBlocks - 1) || (blockColIndex == numColBlocks - 1)) {
-        false // neglect edge blocks
-      } else {
-        // include it if the dimensions don't match
-        !(block.numRows == rowsPerBlock && block.numCols == colsPerBlock)
-      }
-    }.map { case ((blockRowIndex, blockColIndex), mat) =>
-      ((blockRowIndex, blockColIndex), (mat.numRows, mat.numCols))
-    }
-    val dimensionMismatchCount = blockDimensionMismatches.count()
-    // Don't send whole list if there are more than 50 matrices with the wrong dimensions
-    if (dimensionMismatchCount > 50) {
-      throw new SparkException(s"There are $dimensionMismatchCount MatrixBlocks with dimensions " +
-        s"different than rowsPerBlock: $rowsPerBlock, and colsPerBlock: $colsPerBlock. You may " +
-        s"use the repartition method to fix this issue.")
-    } else if (dimensionMismatchCount > 0) {
-      val mismatches = blockDimensionMismatches.collect()
-      var errorMsg = s"The following MatrixBlocks have dimensions different than " +
-        s"(rowsPerBlock, colsPerBlock): ($rowsPerBlock, $colsPerBlock)\n"
-      mismatches.foreach(index => errorMsg += s"Index: ${index._1}, dimensions: ${index._2}\n")
-      errorMsg += "You may use the repartition method to fix this issue."
-      throw new SparkException(errorMsg)
-    }
+    blockInfo.foreach { case ((blockRowIndex, blockColIndex), (m, n)) =>
+      if (m != rowsPerBlock || n != colsPerBlock) {
+        // Only blocks in the last block-row or block-column may be smaller than the rest.
+        if (blockRowIndex != numRowBlocks - 1 && blockColIndex != numColBlocks - 1) {
+          throw new SparkException(s"There are MatrixBlocks with dimensions different than " +
+            s"rowsPerBlock: $rowsPerBlock, and colsPerBlock: $colsPerBlock. You may " +
+            s"use the repartition method to fix this issue.")
+        }
+      }
+    }
     logDebug("MatrixBlock dimensions are okay...")
     logDebug("BlockMatrix is valid!")
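
For orientation, a hedged sketch of how the reworked validate() behaves from a caller's perspective. It assumes an existing SparkContext named sc; the block layout is made up, but the constructors mirror the test suite below:

    import org.apache.spark.SparkException
    import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix}
    import org.apache.spark.mllib.linalg.distributed.BlockMatrix

    val blocks: Seq[((Int, Int), Matrix)] = Seq(
      ((0, 0), DenseMatrix.eye(2)),
      ((1, 1), DenseMatrix.eye(2)))
    val rdd = sc.parallelize(blocks, 2)

    new BlockMatrix(rdd, 2, 2).validate()    // consistent: returns silently

    try {
      new BlockMatrix(rdd, 3, 3).validate()  // claims 3x3 blocks, but (0, 0) is 2x2
    } catch {
      case e: SparkException => println(e.getMessage)  // dimension-mismatch message
    }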
mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala
@@ -17,13 +17,12 @@
 
 package org.apache.spark.mllib.linalg.distributed
 
-import org.apache.spark.SparkException
-
 import scala.util.Random
 
 import breeze.linalg.{DenseMatrix => BDM}
 import org.scalatest.FunSuite
 
+import org.apache.spark.SparkException
 import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 
@@ -151,8 +150,7 @@ class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext {
 
   test("validate") {
     // No error
-    gridBasedMat.validate
-
+    gridBasedMat.validate()
     // Wrong MatrixBlock dimensions
     val blocks: Seq[((Int, Int), Matrix)] = Seq(
       ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))),
@@ -164,28 +162,20 @@ class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext {
     val wrongRowPerParts = new BlockMatrix(rdd, rowPerPart + 1, colPerPart)
     val wrongColPerParts = new BlockMatrix(rdd, rowPerPart, colPerPart + 1)
     intercept[SparkException] {
-      wrongRowPerParts.validate
+      wrongRowPerParts.validate()
     }
     intercept[SparkException] {
-      wrongColPerParts.validate
-    }
-    // Large number of mismatching MatrixBlock dimensions
-    val manyBlocks = for (i <- 0 until 60) yield ((i, 0), DenseMatrix.eye(1))
-    val manyWrongDims = new BlockMatrix(sc.parallelize(manyBlocks, numPartitions), 2, 2, 140, 4)
-    intercept[SparkException] {
-      manyWrongDims.validate
+      wrongColPerParts.validate()
     }
-
     // Wrong BlockMatrix dimensions
     val wrongRowSize = new BlockMatrix(rdd, rowPerPart, colPerPart, 4, 4)
     intercept[SparkException] {
-      wrongRowSize.validate
+      wrongRowSize.validate()
     }
     val wrongColSize = new BlockMatrix(rdd, rowPerPart, colPerPart, 5, 2)
     intercept[SparkException] {
-      wrongColSize.validate
+      wrongColSize.validate()
     }
-
     // Duplicate indices
     val duplicateBlocks: Seq[((Int, Int), Matrix)] = Seq(
       ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))),
@@ -195,12 +185,7 @@ class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext {
       ((2, 1), new DenseMatrix(1, 2, Array(1.0, 5.0))))
     val dupMatrix = new BlockMatrix(sc.parallelize(duplicateBlocks, numPartitions), 2, 2)
     intercept[SparkException] {
-      dupMatrix.validate
-    }
-    val duplicateBlocks2 = for (i <- 0 until 110) yield ((i / 2, i / 2), DenseMatrix.eye(1))
-    val largeDupMatrix = new BlockMatrix(sc.parallelize(duplicateBlocks2, numPartitions), 1, 1)
-    intercept[SparkException] {
-      largeDupMatrix.validate
+      dupMatrix.validate()
     }
   }
 }
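
The duplicate-index message recommends summing duplicates with reduceByKey before validating. A sketch of one way to do that; the helper names are hypothetical, and it uses the public toArray instead of the private[mllib] Breeze conversion:

    import org.apache.spark.SparkContext._
    import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix}
    import org.apache.spark.rdd.RDD

    // Element-wise sum of two equally sized blocks, in column-major order like toArray.
    def sumBlocks(a: Matrix, b: Matrix): Matrix = {
      val values = a.toArray.zip(b.toArray).map { case (x, y) => x + y }
      new DenseMatrix(a.numRows, a.numCols, values)
    }

    def dedupBlocks(blocks: RDD[((Int, Int), Matrix)]): RDD[((Int, Int), Matrix)] =
      blocks.reduceByKey(sumBlocks)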
