Skip to content

Commit

Permalink
added scala suite tests
Browse files Browse the repository at this point in the history
added saveLocalFile to ModelExport trait
  • Loading branch information
selvinsource committed Oct 26, 2014
1 parent 226e184 commit 9bc494f
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,21 @@
package org.apache.spark.mllib.export

import java.io.OutputStream
import java.io.FileOutputStream
import java.io.File

trait ModelExport {

/**
* Write the exported model to the output stream specified
*/
def save(outputStream: OutputStream): Unit

/**
* Write the exported model to the local file specified
*/
def saveLocalFile(path: String): Unit = {
save(new FileOutputStream(new File(path)));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.export

import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.linalg.Vectors
import org.scalatest.FunSuite
import org.apache.spark.mllib.export.pmml.KMeansPMMLModelExport

class ModelExportFactorySuite extends FunSuite{

test("ModelExportFactory create KMeansPMMLModelExport when passing a KMeansModel") {

val clusterCenters = Array(
Vectors.dense(1.0, 2.0, 6.0),
Vectors.dense(1.0, 3.0, 0.0),
Vectors.dense(1.0, 4.0, 6.0)
)

val kmeansModel = new KMeansModel(clusterCenters);

val modelExport = ModelExportFactory.createModelExport(kmeansModel, ModelExportType.PMML)

assert(modelExport.isInstanceOf[KMeansPMMLModelExport])

}

test("ModelExportFactory generate IllegalArgumentException when passing an unsupported model") {

val invalidModel = new Object;

intercept[IllegalArgumentException] {
ModelExportFactory.createModelExport(invalidModel, ModelExportType.PMML)
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.export.pmml

import org.scalatest.FunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.export.ModelExportFactory
import org.apache.spark.mllib.clustering.KMeansModel
import org.apache.spark.mllib.export.ModelExportType

class KMeansPMMLModelExportSuite extends FunSuite{

test("KMeansPMMLModelExport generate PMML format") {

val clusterCenters = Array(
Vectors.dense(1.0, 2.0, 6.0),
Vectors.dense(1.0, 3.0, 0.0),
Vectors.dense(1.0, 4.0, 6.0)
)

val kmeansModel = new KMeansModel(clusterCenters);

val modelExport = ModelExportFactory.createModelExport(kmeansModel, ModelExportType.PMML)

assert(modelExport.isInstanceOf[PMMLModelExport])

//TODO: asserts
//compare pmml fields to strings
modelExport.asInstanceOf[PMMLModelExport].getPmml()
//use document builder to load the xml generated and validated the notes by looking for them
modelExport.asInstanceOf[PMMLModelExport].save(System.out)
//saveLocalFile too??? search how to unit test file creating in java

}

}

0 comments on commit 9bc494f

Please sign in to comment.