Skip to content

Commit

Permalink
preliminary changes addressing code review
Browse files Browse the repository at this point in the history
  • Loading branch information
brkyvz committed Jan 21, 2015
1 parent 1a63b20 commit eebbdf7
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ private[mllib] class GridPartitioner(
override val numPartitions: Int) extends Partitioner {

/**
* Returns the index of the partition the SubMatrix belongs to.
* Returns the index of the partition the SubMatrix belongs to. Tries to achieve block wise
* partitioning.
*
* @param key The key for the SubMatrix. Can be its position in the grid (its column major index)
* or a tuple of three integers that are the final row index after the multiplication,
Expand All @@ -51,22 +52,25 @@ private[mllib] class GridPartitioner(
* @return The index of the partition, which the SubMatrix belongs to.
*/
override def getPartition(key: Any): Int = {
val sqrtPartition = math.round(math.sqrt(numPartitions)).toInt
// numPartitions may not be the square of a number, it can even be a prime number

key match {
case (rowIndex: Int, colIndex: Int) =>
Utils.nonNegativeMod(rowIndex + colIndex * numRowBlocks, numPartitions)
case (rowIndex: Int, innerIndex: Int, colIndex: Int) =>
Utils.nonNegativeMod(rowIndex + colIndex * numRowBlocks, numPartitions)
case (blockRowIndex: Int, blockColIndex: Int) =>
Utils.nonNegativeMod(blockRowIndex + blockColIndex * numRowBlocks, numPartitions)
case (blockRowIndex: Int, innerIndex: Int, blockColIndex: Int) =>
Utils.nonNegativeMod(blockRowIndex + blockColIndex * numRowBlocks, numPartitions)
case _ =>
throw new IllegalArgumentException("Unrecognized key")
throw new IllegalArgumentException(s"Unrecognized key. key: $key")
}
}

/** Checks whether the partitioners have the same characteristics */
override def equals(obj: Any): Boolean = {
obj match {
case r: GridPartitioner =>
(this.numPartitions == r.numPartitions) && (this.rowsPerBlock == r.rowsPerBlock) &&
(this.colsPerBlock == r.colsPerBlock)
(this.numRowBlocks == r.numRowBlocks) && (this.numColBlocks == r.numColBlocks)
(this.rowsPerBlock == r.rowsPerBlock) && (this.colsPerBlock == r.colsPerBlock)
case _ =>
false
}
Expand All @@ -85,7 +89,7 @@ class BlockMatrix(
val numColBlocks: Int,
val rdd: RDD[((Int, Int), Matrix)]) extends DistributedMatrix with Logging {

type SubMatrix = ((Int, Int), Matrix) // ((blockRowIndex, blockColIndex), matrix)
private type SubMatrix = ((Int, Int), Matrix) // ((blockRowIndex, blockColIndex), matrix)

/**
* Alternate constructor for BlockMatrix without the input of a partitioner. Will use a Grid
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package org.apache.spark.mllib.linalg.distributed

import org.scalatest.FunSuite

import breeze.linalg.{DenseMatrix => BDM}

import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix}
Expand Down

0 comments on commit eebbdf7

Please sign in to comment.