Skip to content

Commit

Permalink
complete suite tests
Browse files Browse the repository at this point in the history
  • Loading branch information
selvinsource committed Oct 29, 2014
1 parent 8e71b8d commit 1433b11
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,30 @@ class ModelExportFactorySuite extends FunSuite{

test("ModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") {

//arrange
val clusterCenters = Array(
Vectors.dense(1.0, 2.0, 6.0),
Vectors.dense(1.0, 3.0, 0.0),
Vectors.dense(1.0, 4.0, 6.0)
)

val kmeansModel = new KMeansModel(clusterCenters);

//act
val modelExport = ModelExportFactory.createModelExport(kmeansModel, ModelExportType.PMML)

//assert
assert(modelExport.isInstanceOf[KMeansPMMLModelExport])

}

test("ModelExportFactory throws IllegalArgumentException when passing an unsupported model") {
test("ModelExportFactory throw IllegalArgumentException when passing an unsupported model") {

//arrange
val invalidModel = new Object;

//assert
intercept[IllegalArgumentException] {
//act
ModelExportFactory.createModelExport(invalidModel, ModelExportType.PMML)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,29 +22,38 @@ import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.export.ModelExportFactory
import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.export.ModelExportType
import org.dmg.pmml.ClusteringModel
import javax.xml.parsers.DocumentBuilderFactory
import java.io.ByteArrayOutputStream

class KMeansPMMLModelExportSuite extends FunSuite{

test("KMeansPMMLModelExport generate PMML format") {

//arrange model to test
val clusterCenters = Array(
Vectors.dense(1.0, 2.0, 6.0),
Vectors.dense(1.0, 3.0, 0.0),
Vectors.dense(1.0, 4.0, 6.0)
)

val kmeansModel = new KMeansModel(clusterCenters);

//act by exporting the model to the PMML format
val modelExport = ModelExportFactory.createModelExport(kmeansModel, ModelExportType.PMML)


//assert that the PMML format is as expected
assert(modelExport.isInstanceOf[PMMLModelExport])
var pmml = modelExport.asInstanceOf[PMMLModelExport].getPmml()
assert(pmml.getHeader().getDescription() === "k-means clustering")
//check that the number of fields match the single vector size
assert(pmml.getDataDictionary().getNumberOfFields() === clusterCenters(0).size)
//this verify that there is a model attached to the pmml object and the model is a clustering one
//it also verifies that the pmml model has the same number of clusters of the spark model
assert(pmml.getModels().get(0).asInstanceOf[ClusteringModel].getNumberOfClusters() === clusterCenters.size)

//TODO: asserts
//compare pmml fields to strings
modelExport.asInstanceOf[PMMLModelExport].getPmml()
//use document builder to load the xml generated and validated the notes by looking for them
modelExport.asInstanceOf[PMMLModelExport].save(System.out)
//saveLocalFile too??? search how to unit test file creating in java
//manual checking
//modelExport.asInstanceOf[PMMLModelExport].save(System.out)
//modelExport.asInstanceOf[PMMLModelExport].saveLocalFile("/tmp/kmeans.xml")

}

Expand Down

0 comments on commit 1433b11

Please sign in to comment.