Skip to content

Commit

Permalink
Examples #15, used Logging interface
Browse files Browse the repository at this point in the history
  • Loading branch information
vsuthichai committed Aug 8, 2016
1 parent 5ea4eb9 commit 5e69df4
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ sealed trait StopStrategy extends Serializable {
def shouldStop(trialsSoFar: Long, timeSinceFirstTrial: Duration): Boolean
}

// TODO
// TODO: Finish this.
case class StopContext[P, L](foo: Any)

class MaxTrialsStop(maxTrials: Long) extends StopStrategy {
Expand Down Expand Up @@ -45,6 +45,7 @@ object OptimizerFinishes extends StopStrategy {
override def shouldStop(trialsSoFar: Long, durationSinceFirstTrial: Duration): Boolean = false
}

// TODO: Finish this.
class StopStrategyPredicate[P, L](f: (StopContext[P, L]) => Boolean) {
def shouldStop(stopContext: StopContext[P, L]) = f(stopContext)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package com.eharmony.spotz.optimizer.grid
import com.eharmony.spotz.backend.{BackendFunctions, ParallelFunctions, SparkFunctions}
import com.eharmony.spotz.objective.Objective
import com.eharmony.spotz.optimizer.AbstractOptimizer
import com.eharmony.spotz.util.{DurationUtils, Logger}
import com.eharmony.spotz.util.{DurationUtils, Logging}
import org.apache.spark.SparkContext
import org.joda.time.{DateTime, Duration}

Expand All @@ -18,9 +18,8 @@ abstract class GridSearch[P, L]
(paramSpace: Map[String, Iterable[_]], trialBatchSize: Int)
(implicit ord: Ordering[(P, L)], factory: Map[String, _] => P)
extends AbstractOptimizer[P, L, GridSearchResult[P, L]]
with BackendFunctions {

val LOG = Logger[this.type]()
with BackendFunctions
with Logging {

def minimize(objective: Objective[P, L], space: Map[String, Iterable[_]])
(implicit c: ClassTag[P], p: ClassTag[L]): GridSearchResult[P, L] = {
Expand Down Expand Up @@ -52,7 +51,7 @@ extends AbstractOptimizer[P, L, GridSearchResult[P, L]]
val endTime = DateTime.now()
val elapsedTime = new Duration(startTime, endTime)

LOG.info(s"Best point and loss after $trialsSoFar trials and ${DurationUtils.format(elapsedTime)} : $bestPointSoFar loss: $bestLossSoFar")
info(s"Best point and loss after $trialsSoFar trials and ${DurationUtils.format(elapsedTime)} : $bestPointSoFar loss: $bestLossSoFar")

trialsSoFar >= space.length match {
case true =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package com.eharmony.spotz.optimizer.random
import com.eharmony.spotz.backend.{BackendFunctions, ParallelFunctions, SparkFunctions}
import com.eharmony.spotz.objective.Objective
import com.eharmony.spotz.optimizer._
import com.eharmony.spotz.util.{DurationUtils, Logger}
import com.eharmony.spotz.util.{DurationUtils, Logging}
import org.apache.spark.SparkContext
import org.joda.time.{DateTime, Duration}

Expand All @@ -18,9 +18,8 @@ abstract class RandomSearch[P, L]
(paramSpace: Map[String, RandomSampler[_]], stopStrategy: StopStrategy, trialBatchSize: Int, seed: Int = 0)
(implicit val ord: Ordering[(P, L)], factory: Map[String, _] => P)
extends AbstractOptimizer[P, L, RandomSearchResult[P, L]]
with BackendFunctions {

val LOG = Logger[this.type]()
with BackendFunctions
with Logging {

override def optimize(objective: Objective[P, L],
reducer: Reducer[(P, L)])
Expand All @@ -31,7 +30,8 @@ abstract class RandomSearch[P, L]
val firstLoss = objective(firstPoint)

// Last three arguments maintain the best point and loss and the trial count
randomSearch(objective, space, reducer, startTime, firstPoint, firstLoss, 1)
randomSearch(objective = objective, space = space, reducer = reducer, startTime = startTime,
bestPointSoFar = firstPoint, bestLossSoFar = firstLoss, trialsSoFar = 1)
}

@tailrec
Expand All @@ -46,15 +46,18 @@ abstract class RandomSearch[P, L]
val endTime = DateTime.now()
val elapsedTime = new Duration(startTime, endTime)

LOG.info(s"Best point and loss after $trialsSoFar trials and ${DurationUtils.format(elapsedTime)} : $bestPointSoFar loss: $bestLossSoFar")
info(s"Best point and loss after $trialsSoFar trials and ${DurationUtils.format(elapsedTime)} : $bestPointSoFar loss: $bestLossSoFar")

stopStrategy.shouldStop(trialsSoFar, elapsedTime) match {
case true =>
// Base case, End recursion
new RandomSearchResult[P, L](bestPointSoFar, bestLossSoFar, startTime, endTime, trialsSoFar, elapsedTime)

case false =>
val batchSize = scala.math.min(stopStrategy.getMaxTrials - trialsSoFar, trialBatchSize).toInt
val batchSize = scala.math.min(stopStrategy.getMaxTrials - trialsSoFar, trialBatchSize)
// TODO: Adaptive batch sizing
//val batchSize = nextBatchSize(None, elapsedTime, currentBatchSize, trialsSoFar, null, stopStrategy.getMaxTrials)

val (bestPoint, bestLoss) = reducer((bestPointSoFar, bestLossSoFar),
bestRandomPoint(trialsSoFar, batchSize, objective, space, reducer))

Expand Down
4 changes: 4 additions & 0 deletions examples/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.hamcrest</groupId>
<artifactId>hamcrest-all</artifactId>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package com.eharmony.spotz.examples
import com.eharmony.spotz.Preamble.Point
import com.eharmony.spotz.objective.Objective
import com.eharmony.spotz.optimizer.{OptimizerResult, StopStrategy, UniformDouble}
import org.joda.time.Duration

import scala.math._

Expand All @@ -22,7 +21,7 @@ class AckleyObjective extends Objective[Point, Double] {

trait AckleyExample {
val objective = new AckleyObjective
val stop = StopStrategy.stopAfterMaxDuration(Duration.standardSeconds(5))
val stop = StopStrategy.stopAfterMaxTrials(5000000)
val numBatchTrials = 500000

def apply(): OptimizerResult[Point, Double]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class BraninObjective extends Objective[Point, Double] {

trait BraninExample {
val objective = new BraninObjective
val stop = StopStrategy.stopAfterMaxDuration(Duration.standardSeconds(5))
val stop = StopStrategy.stopAfterMaxTrials(5000000)
val numBatchTrials = 500000

def apply(): OptimizerResult[Point, Double]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package com.eharmony.spotz.examples

import com.eharmony.spotz.Preamble.Point
import org.junit.Assert._
import org.junit.{Ignore, Test}
import org.junit.Test

/**
* @author vsuthichai
Expand All @@ -10,41 +11,31 @@ class AckleyTest {
@Test
def testAckleyParRandomSearch() {
val result = AckleyParRandomSearch()
val point = result.bestPoint

assertEquals(result.bestLoss, 0.0, 0.003)
assertEquals(point.get[Double]("x"), 0.0, 0.001)
assertEquals(point.get[Double]("y"), 0.0, 0.001)
checkAssertions(result.bestPoint, result.bestLoss)
}

@Test
def testAckleyParGridSearch() {
val result = AckleyParGridSearch()
val point = result.bestPoint

assertEquals(result.bestLoss, 0.0, 0.003)
assertEquals(point.get[Double]("x"), 0.0, 0.001)
assertEquals(point.get[Double]("y"), 0.0, 0.001)
checkAssertions(result.bestPoint, result.bestLoss)
}

@Test
def testAckleySparkRandomSearch() {
val result = AckleySparkRandomSearch()
val point = result.bestPoint

assertEquals(result.bestLoss, 0.0, 0.003)
assertEquals(point.get[Double]("x"), 0.0, 0.001)
assertEquals(point.get[Double]("y"), 0.0, 0.001)
checkAssertions(result.bestPoint, result.bestLoss)
}

@Test
def testAckleySparkGridSearch() {
val result = AckleySparkGridSearch()
val point = result.bestPoint
checkAssertions(result.bestPoint, result.bestLoss)
}

assertEquals(result.bestLoss, 0.0, 0.003)
assertEquals(point.get[Double]("x"), 0.0, 0.001)
assertEquals(point.get[Double]("y"), 0.0, 0.001)
def checkAssertions(p: Point, l: Double) {
assertEquals(l, 0.0, 0.01)
assertEquals(0.0, p.get[Double]("x"), 0.01)
assertEquals(0.0, p.get[Double]("y"), 0.01)
}
}

Original file line number Diff line number Diff line change
@@ -1,33 +1,66 @@
package com.eharmony.spotz.examples

import com.eharmony.spotz.Preamble.Point
import org.junit.Assert._
import org.junit.Test

import scala.math.Pi

/**
* @author vsuthichai
*/
class BraninTest {
@Test
def testBraninParRandomSearch() {
val result = BraninParRandomSearch()
assertEquals(result.bestLoss, 0.397887, 0.001)
checkAssertions(result.bestPoint, result.bestLoss)
}

@Test
def testBraninParGridSearch() {
val result = BraninParGridSearch()
assertEquals(result.bestLoss, 0.397887, 0.001)
checkAssertions(result.bestPoint, result.bestLoss)
}

@Test
def testBraninSparkRandomSearch() {
val result = BraninSparkRandomSearch()
assertEquals(result.bestLoss, 0.397887, 0.001)
checkAssertions(result.bestPoint, result.bestLoss)
}

@Test
def testBraninSparkGridSearch() {
val result = BraninSparkGridSearch()
assertEquals(result.bestLoss, 0.397887, 0.001)
checkAssertions(result.bestPoint, result.bestLoss)
}

def checkAssertions(p: Point, l: Double) {
assertEquals(0.397887, l, 0.001)

try {
checkSolution1(p)
} catch {
case _: AssertionError =>
try {
checkSolution2(p)
} catch {
case _: AssertionError => checkSolution3(p)
}
}
}

def checkSolution1(p: Point) {
assertEquals(-Pi, p.get[Double]("x1"), 0.01)
assertEquals(12.275, p.get[Double]("x2"), 0.01)
}

def checkSolution2(p: Point) {
assertEquals(Pi, p.get[Double]("x1"), 0.01)
assertEquals(2.275, p.get[Double]("x2"), 0.01)
}

def checkSolution3(p: Point) {
assertEquals(9.42478, p.get[Double]("x1"), 0.01)
assertEquals(2.475, p.get[Double]("x2"), 0.01)
}
}
8 changes: 8 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
<slf4j.version>1.7.21</slf4j.version>
<logback.classic.version>1.1.7</logback.classic.version>
<junit.version>4.11</junit.version>
<hamcrest.version>1.3</hamcrest.version>
<scalatest.version>3.0.0</scalatest.version>
<scala.compat.version>${scala.major.version}</scala.compat.version>
</properties>
Expand Down Expand Up @@ -170,6 +171,13 @@
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.hamcrest</groupId>
<artifactId>hamcrest-all</artifactId>
<version>${hamcrest.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</dependencyManagement>
Expand Down

0 comments on commit 5e69df4

Please sign in to comment.