make sure there is at most one spark context inside the same jvm
mengxr committed Nov 12, 2014 · 1 parent aa43a8d · commit 913d48d
Showing 3 changed files with 38 additions and 12 deletions.
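
The fix follows one pattern in all three files: contexts that used to be created at suite construction time (so several suites in the same test JVM could hold live contexts at once) are now created in beforeAll and torn down in afterAll. A condensed sketch of that lifecycle, using illustrative names that are not from the patch:

import org.scalatest.{BeforeAndAfterAll, Suite}

import org.apache.spark.{SparkConf, SparkContext}

// Sketch only: the context exists strictly between beforeAll and afterAll,
// so suites that run one after another never hold two live contexts.
trait OneContextPerSuite extends BeforeAndAfterAll { self: Suite =>
  @transient var sc: SparkContext = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("sketch"))
  }

  override def afterAll(): Unit = {
    if (sc != null) {
      sc.stop() // release the context before the next suite's beforeAll
    }
    super.afterAll()
  }
}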
LogisticRegressionSuite.scala
@@ -21,15 +21,23 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.mllib.util.LocalSparkContext
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.{SQLContext, SchemaRDD}
 
 class LogisticRegressionSuite extends FunSuite with LocalSparkContext {
 
-  import sqlContext._
+  @transient var sqlContext: SQLContext = _
+  @transient var dataset: SchemaRDD = _
 
-  val dataset: SchemaRDD = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sqlContext = new SQLContext(sc)
+    dataset = sqlContext.createSchemaRDD(
+      sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
+  }
 
   test("logistic regression") {
+    val sqlContext = this.sqlContext
+    import sqlContext._
     val lr = new LogisticRegression
     val model = lr.fit(dataset)
     model.transform(dataset)
@@ -38,6 +46,8 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext {
   }
 
   test("logistic regression with setters") {
+    val sqlContext = this.sqlContext
+    import sqlContext._
     val lr = new LogisticRegression()
       .setMaxIter(10)
       .setRegParam(1.0)
@@ -48,6 +58,8 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext {
   }
 
   test("logistic regression fit and transform with varargs") {
+    val sqlContext = this.sqlContext
+    import sqlContext._
     val lr = new LogisticRegression
     val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0)
     model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability")
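A note on the repeated val sqlContext = this.sqlContext / import sqlContext._ pairs above: Scala only allows importing members from a stable identifier (a val or object), and sqlContext is now a var because beforeAll reassigns it. Copying the var into a local val restores a legal import of the SQLContext implicits, such as createSchemaRDD. A minimal illustration, with a hypothetical class name:

import org.apache.spark.sql.SQLContext

class ImportFromVar {
  @transient var sqlContext: SQLContext = _

  def demo(): Unit = {
    // import sqlContext._          // rejected: "stable identifier required"
    val sqlContext = this.sqlContext // shadow the var with a stable val
    import sqlContext._              // legal: SQLContext implicits now in scope
  }
}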
CrossValidatorSuite.scala
@@ -23,13 +23,18 @@ import org.apache.spark.ml.classification.LogisticRegression
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
 import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.mllib.util.LocalSparkContext
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.{SQLContext, SchemaRDD}
 
 class CrossValidatorSuite extends FunSuite with LocalSparkContext {
 
-  import sqlContext._
+  @transient var dataset: SchemaRDD = _
 
-  val dataset: SchemaRDD = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val sqlContext = new SQLContext(sc)
+    dataset = sqlContext.createSchemaRDD(
+      sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
+  }
 
   test("cross validation with logistic regression") {
     val lr = new LogisticRegression
LocalSparkContext.scala
@@ -17,17 +17,26 @@
 
 package org.apache.spark.mllib.util
 
-import org.scalatest.{BeforeAndAfterAll, Suite}
+import org.scalatest.Suite
+import org.scalatest.BeforeAndAfterAll
 
-import org.apache.spark.SparkContext
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.{SparkConf, SparkContext}
 
 trait LocalSparkContext extends BeforeAndAfterAll { self: Suite =>
-  @transient val sc = new SparkContext("local", "test")
-  @transient lazy val sqlContext = new SQLContext(sc)
+  @transient var sc: SparkContext = _
+
+  override def beforeAll() {
+    super.beforeAll()
+    val conf = new SparkConf()
+      .setMaster("local[2]")
+      .setAppName("MLlibUnitTest")
+    sc = new SparkContext(conf)
+  }
 
   override def afterAll() {
-    sc.stop()
+    if (sc != null) {
+      sc.stop()
+    }
     super.afterAll()
   }
 }
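
For reference, a hypothetical consumer of the reworked trait; sc is created before the first test and stopped after the last one, so tests touch it only inside test(...) bodies:

import org.scalatest.FunSuite

import org.apache.spark.mllib.util.LocalSparkContext

// Hypothetical suite: WordCountSuite is not part of this commit.
class WordCountSuite extends FunSuite with LocalSparkContext {
  test("word count on the shared context") {
    val counts = sc.parallelize(Seq("a", "b", "a"), 2)
      .map(word => (word, 1))
      .reduceByKey(_ + _)
      .collectAsMap()
    assert(counts("a") === 2)
    assert(counts("b") === 1)
  }
}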
