diff --git a/cli/src/main/scala/com/ibm/sparktc/sparkbench/datageneration/mlgenerator/KMeansDataGen.scala b/cli/src/main/scala/com/ibm/sparktc/sparkbench/datageneration/mlgenerator/KMeansDataGen.scala index 117b103e..b090ee6f 100644 --- a/cli/src/main/scala/com/ibm/sparktc/sparkbench/datageneration/mlgenerator/KMeansDataGen.scala +++ b/cli/src/main/scala/com/ibm/sparktc/sparkbench/datageneration/mlgenerator/KMeansDataGen.scala @@ -67,7 +67,7 @@ case class KMeansDataGen( } val (convertTime, dataDF) = time { - val schemaString = data.first().indices.map(_.toString).mkString(" ") + val schemaString = data.first().indices.map(i => "c" + i.toString).mkString(" ") val fields = schemaString.split(" ").map(fieldName => StructField(fieldName, DoubleType, nullable = false)) val schema = StructType(fields) val rowRDD = data.map(arr => Row(arr:_*)) diff --git a/cli/src/test/scala/com/ibm/sparktc/sparkbench/datageneration/KMeansDataGenTest.scala b/cli/src/test/scala/com/ibm/sparktc/sparkbench/datageneration/KMeansDataGenTest.scala index 6aa0e489..0ab8efcf 100644 --- a/cli/src/test/scala/com/ibm/sparktc/sparkbench/datageneration/KMeansDataGenTest.scala +++ b/cli/src/test/scala/com/ibm/sparktc/sparkbench/datageneration/KMeansDataGenTest.scala @@ -21,40 +21,41 @@ import java.io.File import com.ibm.sparktc.sparkbench.datageneration.mlgenerator.KMeansDataGen import com.ibm.sparktc.sparkbench.testfixtures.{BuildAndTeardownData, SparkSessionProvider} -import org.scalatest.{BeforeAndAfterEach, FlatSpec, Matchers} +import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers} import scala.io.Source -class KMeansDataGenTest extends FlatSpec with Matchers with BeforeAndAfterEach { +class KMeansDataGenTest extends FlatSpec with Matchers with BeforeAndAfterAll { val cool = new BuildAndTeardownData("kmeans-data-gen") - val fileName = s"${cool.sparkBenchTestFolder}/${java.util.UUID.randomUUID.toString}.csv" + val fileName = s"${cool.sparkBenchTestFolder}/${java.util.UUID.randomUUID.toString}" var file: File = _ - override def beforeEach() { + override def beforeAll() { cool.createFolders() - file = new File(fileName) } - override def afterEach() { + override def afterAll() { cool.deleteFolders() } - "KMeansDataGeneration" should "generate data correctly" in { + "KMeansDataGeneration" should "generate a csv correctly" in { + + val csvFile = s"$fileName.csv" val m = Map( "name" -> "kmeans", "rows" -> 10, "cols" -> 10, - "output" -> fileName + "output" -> csvFile ) val generator = KMeansDataGen(m) - generator.doWorkload(spark = SparkSessionProvider.spark) + file = new File(csvFile) val fileList = file.listFiles().toList.filter(_.getName.startsWith("part")) @@ -74,4 +75,34 @@ class KMeansDataGenTest extends FlatSpec with Matchers with BeforeAndAfterEach { */ length shouldBe generator.numRows + fileList.length } + + it should "generate an ORC file correctly" in { + val spark = SparkSessionProvider.spark + + val orcFile = s"$fileName.orc" + + val m = Map( + "name" -> "kmeans", + "rows" -> 10, + "cols" -> 10, + "output" -> orcFile + ) + + val generator = KMeansDataGen(m) + + generator.doWorkload(spark = spark) + + file = new File(orcFile) + + val list = file.listFiles().toList + val fileList = list.filter(_.getName.startsWith("part")) + + fileList.length should be > 0 + + println(s"reading file $orcFile") + + val fromDisk = spark.read.orc(orcFile) + val rows = fromDisk.count() + rows shouldBe 10 + } } \ No newline at end of file