-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ML-177][Native Bayes] Fix error when converting Vector to CSRNumeric…
…Table (#176) * 1. fix naiveBayes bug 2. add unit test for converting vector to CSRNumericTable Signed-off-by: minmingzhu <minming.zhu@intel.com> * Update OneDAL.scala * add unit test to test.sh Signed-off-by: minmingzhu <minming.zhu@intel.com> * fix comments Signed-off-by: minmingzhu <minming.zhu@intel.com> * modify code indent Signed-off-by: minmingzhu <minming.zhu@intel.com> * Update mllib-dal/src/main/scala/com/intel/oap/mllib/OneDAL.scala * Update mllib-dal/src/main/scala/com/intel/oap/mllib/OneDAL.scala Co-authored-by: Xiaochang Wu <xiaochang.wu@intel.com>
- Loading branch information
1 parent
0810769
commit 3a68664
Showing
3 changed files
with
55 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
40 changes: 40 additions & 0 deletions
40
mllib-dal/src/test/scala/org/apache/spark/ml/oneDALSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
package org.apache.spark.ml | ||
|
||
import com.intel.oap.mllib.OneDAL | ||
import org.apache.spark.internal.Logging | ||
import org.apache.spark.ml.linalg.{Matrices, Vector, Vectors} | ||
import org.apache.spark.sql.Row | ||
|
||
class oneDALSuite extends FunctionsSuite with Logging { | ||
|
||
import testImplicits._ | ||
|
||
test("test sparse vector to CSRNumericTable") { | ||
val data = Seq( | ||
Vectors.sparse(3, Seq((0, 1.0), (1, 2.0), (2, 3.0))), | ||
Vectors.sparse(3, Seq((0, 10.0), (1, 20.0), (2, 30.0))), | ||
Vectors.sparse(3, Seq.empty), | ||
Vectors.sparse(3, Seq.empty), | ||
Vectors.sparse(3, Seq((0, 1.0), (1, 2.0))), | ||
Vectors.sparse(3, Seq((0, 10.0), (2, 20.0))), | ||
) | ||
val df = data.map(Tuple1.apply).toDF("features") | ||
df.show() | ||
val rowsRDD = df.rdd.map { | ||
case Row(features: Vector) => features | ||
} | ||
val results = rowsRDD.coalesce(1).mapPartitions { it: Iterator[Vector] => | ||
val vectors: Array[Vector] = it.toArray | ||
val numColumns = vectors(0).size | ||
val CSRNumericTable = { | ||
OneDAL.vectorsToSparseNumericTable(vectors, numColumns) | ||
} | ||
Iterator(CSRNumericTable.getCNumericTable) | ||
}.collect() | ||
val csr = OneDAL.makeNumericTable(results(0)) | ||
val resultMatrix = OneDAL.numericTableToMatrix(csr) | ||
val matrix = Matrices.fromVectors(data) | ||
|
||
assert((resultMatrix.toArray sameElements matrix.toArray) === true) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters