-
Notifications
You must be signed in to change notification settings - Fork 8
/
CensusIncomeDecisionTree.scala
255 lines (185 loc) · 7.71 KB
/
CensusIncomeDecisionTree.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
// Databricks notebook source
import org.apache.spark.sql.{DataFrame, functions}
/**
 * Cleans a raw census-income DataFrame read from CSV:
 *   - trims leading whitespace from every value (the raw file uses ", value" separators),
 *   - renames the positional columns (_c0, _c1, ...) to their field names,
 *   - renames the column following the last feature to "label",
 *   - casts the continuous columns from string to double,
 *   - strips the '.' that the test file appends to labels (e.g. "<=50K.").
 *
 * @param df                     raw DataFrame with positional _cN columns
 * @param fields                 feature column names, in file order
 * @param continuousFieldIndexes indexes into `fields` of the numeric columns
 * @return the cleaned DataFrame
 */
def formatData(df: DataFrame, fields: Seq[String], continuousFieldIndexes: Seq[Int]): DataFrame = {
  // Trim leading spaces from every column's data
  val trimmed = df.columns.foldLeft(df) { (d, colName) =>
    d.withColumn(colName, functions.ltrim(functions.col(colName)))
  }
  // Positional names -> field names; the column after the last feature is the label.
  // (Was hard-coded "_c14"; deriving it from fields.length keeps the helper correct
  // for any feature count while behaving identically for this 14-field dataset.)
  val named = fields.indices
    .foldLeft(trimmed)((d, i) => d.withColumnRenamed("_c" + i, fields(i)))
    .withColumnRenamed("_c" + fields.length, "label")
  // Continuous values arrive as strings; cast them to double for the ML stages
  val typed = continuousFieldIndexes.foldLeft(named) { (d, i) =>
    d.withColumn(fields(i), functions.col(fields(i)).cast("double"))
  }
  // Remove '.' character from label so train and test labels agree
  typed.withColumn("label", functions.regexp_replace(functions.col("label"), "\\.", ""))
}
/**
 * For each categorical column, prints the distinct (index, value) pairs produced
 * by its StringIndexer stage (the "<name>Indexed" companion column), sorted by index.
 */
def showCategories(df: DataFrame, fields: Seq[String], categoricalFieldIndexes: Seq[Int]): Unit = {
  categoricalFieldIndexes.foreach { i =>
    val colName = fields(i)
    val indexedName = colName + "Indexed"
    df.select(indexedName, colName)
      .distinct()
      .sort(indexedName)
      .show(100)
  }
}
// COMMAND ----------
// Column names for the 14 feature columns of the UCI Census Income ("Adult") dataset,
// in the order they appear in the raw file.
val fields: Seq[String] = Seq(
  "age", "workclass", "fnlwgt", "education", "education-num",
  "marital-status", "occupation", "relationship", "race", "sex",
  "capital-gain", "capital-loss", "hours-per-week", "native-country"
)

// Indexes into `fields` of the string-valued (categorical) columns.
val categoricalFieldIndexes: Seq[Int] = Seq(1, 3, 5, 6, 7, 8, 9, 13)

// Indexes into `fields` of the numeric (continuous) columns.
val continuousFieldIndexes: Seq[Int] = Seq(0, 2, 4, 10, 11, 12)
// COMMAND ----------
// Create dataframe to hold census income training data
// Data retrieved from http://archive.ics.uci.edu/ml/datasets/Census+Income
val trainingUrl = "https://raw.githubusercontent.com/aosama/MachineLearningSamples/master/src/main/resources/adult.data"
// Source.fromURL holds an open stream; read it once and close it explicitly
// (the original `fromURL(url).mkString` leaked the underlying connection).
val trainingSource = scala.io.Source.fromURL(trainingUrl)
val trainingContent = try trainingSource.mkString finally trainingSource.close()
// Drop empty lines, then parallelize into a Dataset so spark.read.csv can parse it
val trainingList = trainingContent.split("\n").filter(_ != "")
val trainingDs = sc.parallelize(trainingList).toDS()
var trainingData = spark.read.csv(trainingDs).cache
// COMMAND ----------
// Create dataframe to hold census income test data
// Data retrieved from http://archive.ics.uci.edu/ml/datasets/Census+Income
val testUrl = "https://raw.githubusercontent.com/aosama/MachineLearningSamples/master/src/main/resources/adult.test"
// Source.fromURL holds an open stream; read it once and close it explicitly
// (the original `fromURL(url).mkString` leaked the underlying connection).
val testSource = scala.io.Source.fromURL(testUrl)
val testContent = try testSource.mkString finally testSource.close()
// Drop empty lines, then parallelize into a Dataset so spark.read.csv can parse it
val testList = testContent.split("\n").filter(_ != "")
val testDs = sc.parallelize(testList).toDS()
var testData = spark.read.csv(testDs).cache
// COMMAND ----------
// Format the data
// Both splits must be normalized identically so the indexers fitted on the
// training split apply cleanly to the test split.
trainingData = formatData(trainingData, fields, continuousFieldIndexes)
testData = formatData(testData, fields, continuousFieldIndexes)
// COMMAND ----------
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorAssembler}

// One StringIndexer per categorical column, mapping "<name>" -> "<name>Indexed".
// NOTE(review): these are fit inside the pipeline on training data only; a test row
// containing a category value unseen in training would make transform throw — confirm
// the test split introduces none, or consider setHandleInvalid.
val categoricalIndexerArray = categoricalFieldIndexes.map { i =>
  new StringIndexer()
    .setInputCol(fields(i))
    .setOutputCol(fields(i) + "Indexed")
}

// Indexer for the target: "label" -> "indexedLabel", fit eagerly on training data
// so its label list is available to the converter below.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(trainingData)

// Feature vector = the indexed categorical columns followed by the raw continuous columns.
val featureColumns =
  categoricalFieldIndexes.map(i => fields(i) + "Indexed") ++
    continuousFieldIndexes.map(fields)
val vectorAssembler = new VectorAssembler()
  .setInputCols(featureColumns.toArray)
  .setOutputCol("features")

// Maps the numeric "prediction" column back to a readable "predictedLabel".
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)
// COMMAND ----------
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier

// Create decision tree trained on the indexed label and assembled feature vector
val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("features")
  .setMaxBins(50) // Since feature "native-country" contains 42 distinct values, need to increase max bins.
  .setMaxDepth(6)

// Array of stages to run in pipeline: label indexer first, then the categorical
// indexers, the vector assembler, the tree itself, and finally the index->label
// converter so predictions come out as readable strings.
// (stageArray's layout is relied on below: the tree sits at stageArray.length - 2.)
val indexerArray = Array(labelIndexer) ++ categoricalIndexerArray
val stageArray = indexerArray ++ Array(vectorAssembler, dt, labelConverter)

val pipeline = new Pipeline()
  .setStages(stageArray)

// Train the model (fits every indexer and the tree on the training split)
val model = pipeline.fit(trainingData)

// Test the model (adds rawPrediction/probability/prediction/predictedLabel columns)
val predictions = model.transform(testData)
// COMMAND ----------
// Show every test-set prediction alongside the original feature columns.
display(predictions.select("label", Seq("predictedLabel" ,"indexedLabel", "prediction") ++ fields:_*))

// COMMAND ----------

// Rows the model got wrong: the indexed label differs from the predicted index.
val wrongPredictions = predictions
  .select("label", Seq("predictedLabel" ,"indexedLabel", "prediction") ++ fields:_*)
  .where("indexedLabel != prediction")
display(wrongPredictions)
// COMMAND ----------
// Show the label and all the categorical features mapped to indexes.
// Only the indexer stages are fit here, so the raw-value -> index mappings
// can be inspected without running the full pipeline.
val indexedData = {
  val indexerPipeline = new Pipeline().setStages(indexerArray)
  indexerPipeline.fit(trainingData).transform(trainingData)
}

// One row per distinct label with its numeric index
indexedData.select("indexedLabel", "label").distinct().sort("indexedLabel").show()
// And the same for every categorical feature column
showCategories(indexedData, fields, categoricalFieldIndexes)
// COMMAND ----------
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.mllib.evaluation.MulticlassMetrics
// Overall accuracy of the tree on the held-out test split.
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println(s"Test error = ${1.0 - accuracy}\n")

// Pair up (actual, predicted) indices to build the confusion matrix.
val labelPredictionPairs = predictions
  .select("indexedLabel", "prediction")
  .rdd
  .map(row => (row.getDouble(0), row.getDouble(1)))
val metrics = new MulticlassMetrics(labelPredictionPairs)
println(s"Confusion matrix:\n ${metrics.confusionMatrix}\n")
// The fitted tree sits just before the label converter in the pipeline stages.
val treeModel = model.stages(stageArray.length - 2).asInstanceOf[DecisionTreeClassificationModel]

// Print out the tree with actual column names for features.
// Feature-vector positions follow the assembler's input order: categorical first,
// then continuous. The trailing space in the needle keeps "feature 1 " from
// matching "feature 13 ".
val featureFieldIndexes = categoricalFieldIndexes ++ continuousFieldIndexes
var treeModelString = featureFieldIndexes.indices.foldLeft(treeModel.toDebugString) { (tree, i) =>
  tree.replace("feature " + i + " ", fields(featureFieldIndexes(i)) + " ")
}
println(s"Learned classification tree model:\n $treeModelString")
// COMMAND ----------
// Print the feature-vector position -> column-name mapping used by the tree above.
featureFieldIndexes.indices.foreach { i =>
  println(s"feature " + i + " -> " + fields(featureFieldIndexes(i)))
}
// COMMAND ----------
// Visualize the fitted decision tree (Databricks renders tree models graphically).
display(treeModel)

// COMMAND ----------

// Spot-check: all 25-year-olds in the test split.
display(testData.filter('age === 25))

// COMMAND ----------

// Verify the column names/types produced by formatData.
testData.printSchema
// COMMAND ----------
import org.apache.spark.ml.linalg.Vector

// UDF extracting element i from an ML Vector column (rawPrediction / probability).
// Qualified as functions.udf for consistency with the rest of the file, which uses
// the functions object explicitly rather than relying on notebook auto-imports.
val vectorElem = functions.udf((x: Vector, i: Int) => x(i))

// Expand the two-class vector columns into scalar columns for display/sorting:
// index 0 and 1 correspond to the two indexed label values.
val predictionsExpanded = predictions
  .withColumn("rawPrediction0", vectorElem('rawPrediction, functions.lit(0)))
  .withColumn("rawPrediction1", vectorElem('rawPrediction, functions.lit(1)))
  .withColumn("score0", vectorElem('probability, functions.lit(0)))
  .withColumn("score1", vectorElem('probability, functions.lit(1)))
// COMMAND ----------
// All expanded predictions, ordered by age ascending.
display(predictionsExpanded.orderBy($"age".asc))
// COMMAND ----------
// A single hand-built record (a 50-year-old married professional) for a
// spot-check prediction. Column names come from the shared `fields` Seq,
// which lists the same 14 names in the same order as the raw data.
val record = Seq(
  (50, "Private", 220931, "Bachelors", 13, "Married-civ-spouse", "Prof-specialty",
    "Not-in-family", "White", "Male", 10, 0, 43, "United-States")
).toDF(fields: _*)
// COMMAND ----------
// Score the single record through the full pipeline and expand its vector
// outputs into scalar columns, mirroring predictionsExpanded above.
val singlePrediction = {
  val scored = model.transform(record)
  scored
    .withColumn("rawPrediction0", vectorElem('rawPrediction, functions.lit(0)))
    .withColumn("rawPrediction1", vectorElem('rawPrediction, functions.lit(1)))
    .withColumn("score0", vectorElem('probability, functions.lit(0)))
    .withColumn("score1", vectorElem('probability, functions.lit(1)))
}
// COMMAND ----------
// Show the spot-check prediction for the hand-built record.
display(singlePrediction)

// COMMAND ----------

// Age distribution of the training data (row count per distinct age).
display(trainingData.groupBy('age).count.orderBy('age.asc))