Minor fixes on SyncADMM.
root committed Jul 20, 2014
1 parent ef569b3 commit db14f64
Showing 3 changed files with 38 additions and 24 deletions.
1 change: 0 additions & 1 deletion conf/log4j.properties.template
@@ -7,6 +7,5 @@ log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}:

# Settings to quiet third party logs that are too verbose
log4j.logger.org.eclipse.jetty=WARN
log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR
log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
11 changes: 9 additions & 2 deletions conf/slaves
@@ -1,2 +1,9 @@
-# A Spark Worker will be started on each of the machines listed below.
-localhost
+ec2-54-203-223-46.us-west-2.compute.amazonaws.com
+ec2-54-202-226-28.us-west-2.compute.amazonaws.com
+ec2-54-203-142-34.us-west-2.compute.amazonaws.com
+ec2-54-203-215-74.us-west-2.compute.amazonaws.com
+ec2-54-184-210-61.us-west-2.compute.amazonaws.com
+ec2-54-184-189-175.us-west-2.compute.amazonaws.com
+ec2-54-203-167-251.us-west-2.compute.amazonaws.com
+ec2-54-244-104-241.us-west-2.compute.amazonaws.com
+ec2-54-184-164-7.us-west-2.compute.amazonaws.com
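For context: in Spark's standalone deploy mode, sbin/start-slaves.sh reads conf/slaves and launches a Worker on each listed host over SSH, so this change moves the cluster from a single local Worker to nine EC2 instances. A minimal driver-side sketch for connecting to such a cluster follows; the master URL is an assumption (the standalone master host is not recorded in this file, and 7077 is only the default port):

import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical driver setup against the standalone cluster whose Workers
// are listed in conf/slaves. The master hostname and port are assumptions;
// substitute the host where start-master.sh was actually run.
val conf = new SparkConf()
  .setMaster("spark://ec2-54-203-223-46.us-west-2.compute.amazonaws.com:7077")
  .setAppName("SynchronousADMMTests")
val sc = new SparkContext(conf)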
@@ -15,57 +15,65 @@ import org.apache.spark.mllib.optimization.{SquaredL2Updater, L1Updater}
 
 
 
-class DataLoaders(val sc: SparkContext) {
-  def loadBismark(filename: String): RDD[LabeledPoint] = {
+object DataLoaders {
+  def loadBismark(sc: SparkContext, filename: String): RDD[LabeledPoint] = {
     val data = sc.textFile(filename)
       .filter(s => !s.isEmpty && s(0) == '{')
       .map(s => s.split('\t'))
       .map { case Array(x, y) =>
         val features = x.stripPrefix("{").stripSuffix("}").split(',').map(xi => xi.toDouble)
-        val label = y.toDouble
+        val label = if (y.toDouble > 0) 1 else 0
         LabeledPoint(label, new DenseVector(features))
       }.cache()
 
     data
   }
 
-  def loadFlights(filename: String): RDD[LabeledPoint] = {
+  def makeDictionary(colId: Int, tbl: RDD[Array[String]]): Map[String, Int] = {
+    tbl.map(row => row(colId)).distinct.collect.zipWithIndex.toMap
+  }
+  def makeBinary(value: String, dict: Map[String, Int]): Array[Double] = {
+    val array = new Array[Double](dict.size)
+    array(dict(value)) = 1.0
+    array
+  }
+
+
+  def loadFlights(sc: SparkContext, filename: String): RDD[LabeledPoint] = {
     val labels = Array("Year", "Month", "DayOfMonth", "DayOfWeek", "DepTime", "CRSDepTime", "ArrTime",
       "CRSArrTime", "UniqueCarrier", "FlightNum", "TailNum", "ActualElapsedTime", "CRSElapsedTime",
       "AirTime", "ArrDelay", "DepDelay", "Origin", "Dest", "Distance", "TaxiIn", "TaxiOut",
       "Cancelled", "CancellationCode", "Diverted", "CarrierDelay", "WeatherDelay",
       "NASDelay", "SecurityDelay", "LateAircraftDelay").zipWithIndex.toMap
     println("Loading data")
     val rawData = sc.textFile(filename, 128).
       filter(s => !s.contains("Year")).
       map(s => s.split(",")).cache()
 
-    def makeDictionary(colId: Int, tbl: RDD[Array[String]]): Map[String, Int] = {
-      tbl.map(row => row(colId)).distinct.collect.zipWithIndex.toMap
-    }
-    def makeBinary(value: String, dict: Map[String, Int]): Array[Double] = {
-      val array = new Array[Double](dict.size)
-      array(dict(value)) = 1.0
-      array
-    }
     val carrierDict = makeDictionary(labels("UniqueCarrier"), rawData)
     val flightNumDict = makeDictionary(labels("FlightNum"), rawData)
     val tailNumDict = makeDictionary(labels("TailNum"), rawData)
     val originDict = makeDictionary(labels("Origin"), rawData)
     val destDict = makeDictionary(labels("Dest"), rawData)
 
     val data = rawData.map { row =>
-      val firstFiveFeatures = (row.view(0, 5) ++ row.view(6, 7)).map(_.toDouble).toArray
+      val firstFiveFeatures = (row.view(0, 5) ++ row.view(6, 7)).map{ x =>
+        if(x == "NA") 0.0 else x.toDouble
+      }
       val carrierFeatures = makeBinary(row(labels("UniqueCarrier")), carrierDict)
       val flightFeatures = makeBinary(row(labels("FlightNum")), flightNumDict)
       val tailNumFeatures = makeBinary(row(labels("TailNum")), tailNumDict)
       val originFeatures = makeBinary(row(labels("Origin")), originDict)
       val destFeatures = makeBinary(row(labels("Dest")), destDict)
-      val features: Array[Double] = firstFiveFeatures ++ carrierFeatures ++ flightFeatures ++
-        tailNumFeatures ++ originFeatures ++ destFeatures
-      val label = if (row(labels("ArrDelay")).toDouble > 0) 1.0 else 0.0
+      val features: Array[Double] = (firstFiveFeatures ++ carrierFeatures ++ flightFeatures ++
+        tailNumFeatures ++ originFeatures ++ destFeatures).toArray
+      val delay = row(labels("ArrDelay"))
+      val label = if (delay != "NA" && delay.toDouble > 0) 1.0 else 0.0
       LabeledPoint(label, new DenseVector(features))
     }.cache()
 
+    data.count
+    println("FINISHED LOADING SUCKA")
+    println(s"THIS MANY PLUSES SUCKA ${data.filter(x => x.label == 1).count/data.count.toDouble}")
     data
   }
 }
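Two things change in loadBismark: the SparkContext becomes an explicit parameter, and the label is binarized (any positive value maps to 1, everything else to 0), matching the flights loader. The expected input format is one record per line: a brace-wrapped, comma-separated feature vector, a tab, then a numeric label. A sketch of the same parse steps on a single hypothetical input line:

// Hypothetical Bismark-format record: "{features}\tlabel".
val line = "{1.5,0.0,2.25}\t-3.0"

// The same steps loadBismark applies per line:
val Array(x, y) = line.split('\t')
val features = x.stripPrefix("{").stripSuffix("}").split(',').map(_.toDouble)
val label = if (y.toDouble > 0) 1 else 0
// => label = 0, features = Array(1.5, 0.0, 2.25)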
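The makeDictionary/makeBinary pair, hoisted from inside loadFlights to the enclosing object, implements one-hot encoding: makeDictionary maps each distinct value of a categorical column to an index, and makeBinary emits a vector that is all zeros except for a 1.0 at that value's index. A local sketch of the same idea, over a plain Seq instead of an RDD and with hypothetical carrier codes:

// One-hot encoding sketch mirroring makeDictionary/makeBinary above,
// on a plain Scala collection (no SparkContext needed).
val carriers = Seq("AA", "UA", "WN", "AA", "DL") // hypothetical column values

// Like makeDictionary: distinct value -> column index.
val dict: Map[String, Int] = carriers.distinct.zipWithIndex.toMap

// Like makeBinary: all zeros except a 1.0 at the value's index.
def oneHot(value: String, dict: Map[String, Int]): Array[Double] = {
  val arr = new Array[Double](dict.size)
  arr(dict(value)) = 1.0
  arr
}

oneHot("UA", dict) // Array(0.0, 1.0, 0.0, 0.0) given the order above

Note that the encoding is dense: high-cardinality columns such as FlightNum and TailNum make every feature vector as wide as their dictionaries, which is inherent to the Array[Double]-backed DenseVector used here.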
@@ -149,14 +157,12 @@ object SynchronousADMMTests {
 
     Logger.getRootLogger.setLevel(Level.WARN)
 
-    val dataloader = new DataLoaders(sc)
-
     val examples = if(params.format == "lisbsvm") {
       MLUtils.loadLibSVMFile(sc, params.input).cache()
     } else if (params.format == "bismarck") {
-      dataloader.loadBismark(params.input).cache()
+      DataLoaders.loadBismark(sc, params.input).cache()
     } else if (params.format == "flights") {
-      dataloader.loadFlights(params.input).cache()
+      DataLoaders.loadFlights(sc, params.input).cache()
     } else {
       throw new RuntimeException("F off")
     }
@@ -174,6 +180,8 @@ object SynchronousADMMTests {
 
     examples.unpersist(blocking = false)
 
+    println("STARTING SUCKA")
+
     val updater = params.regType match {
       case L1 => new L1Updater()
       case L2 => new SquaredL2Updater()
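For reference, L1Updater and SquaredL2Updater are MLlib's pluggable regularization steps for gradient-descent optimizers. A sketch of how the selected updater would typically be wired in; LogisticRegressionWithSGD, the training-data name, and the hyperparameters are assumptions, since the actual solver code is below the fold of this diff:

import org.apache.spark.mllib.classification.LogisticRegressionWithSGD

// Hedged sketch: plug the chosen updater (L1 or squared-L2 penalty) into an
// SGD-based learner. The algorithm, regParam, and iteration count are
// assumptions; the solver actually used here is not shown in this excerpt.
val algorithm = new LogisticRegressionWithSGD()
algorithm.optimizer
  .setUpdater(updater)
  .setRegParam(0.1)      // hypothetical regularization strength
  .setNumIterations(100) // hypothetical iteration budget
val model = algorithm.run(trainingData) // trainingData: RDD[LabeledPoint], assumed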
