Skip to content

Commit

Permalink
add VW classification sample
Browse files Browse the repository at this point in the history
  • Loading branch information
serena-ruan committed May 7, 2021
1 parent 5d670a3 commit aa28f11
Showing 1 changed file with 161 additions and 0 deletions.
161 changes: 161 additions & 0 deletions notebooks/samples/Vowpal Wabbit - Heart Disease Detection.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": 3
},
"orig_nbformat": 2
},
"nbformat": 4,
"nbformat_minor": 2,
"cells": [
{
"source": [
"## Heart Disease Detection with VowalWabbit Classifier"
],
"cell_type": "markdown",
"metadata": {}
},
{
"source": [
"#### Read dataset"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset = spark.read.format(\"csv\")\\\n",
" .option(\"header\", True)\\\n",
" .load(\"wasbs://publicwasb@mmlspark.blob.core.windows.net/heart_disease_prediction_data.csv\")\n",
"# print dataset size\n",
"print(\"records read: \" + str(dataset.count()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# convert features to double type\n",
"from pyspark.sql.functions import col\n",
"from pyspark.sql.types import DoubleType\n",
"for colName in dataset.columns:\n",
" dataset = dataset.withColumn(colName, col(colName).cast(DoubleType()))\n",
"print(\"Schema: \")\n",
"dataset.printSchema()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset.show(10, truncate=False)"
]
},
{
"source": [
"#### Split the dataset into train and test"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train, test = dataset.randomSplit([0.85, 0.15], seed=1)"
]
},
{
"source": [
"#### Use VowalWabbitFeaturizer to convert data features into vector"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mmlspark.vw import VowpalWabbitFeaturizer\n",
"featurizer = VowpalWabbitFeaturizer(inputCols=dataset.columns[:-1], outputCol=\"features\")\n",
"train_data = featurizer.transform(train)[\"target\", \"features\"]\n",
"test_data = featurizer.transform(test)[\"target\", \"features\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_data.groupBy(\"target\").count().show()"
]
},
{
"source": [
"#### Model Training"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mmlspark.vw import VowpalWabbitClassifier\n",
"model = VowpalWabbitClassifier(numPasses=20, labelCol=\"target\", featuresCol=\"features\").fit(train_data)"
]
},
{
"source": [
"#### Model Prediction"
],
"cell_type": "markdown",
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"predictions = model.transform(test_data)\n",
"predictions.limit(10).toPandas()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mmlspark.train import ComputeModelStatistics\n",
"metrics = ComputeModelStatistics(evaluationMetric='classification', labelCol='target', scoredLabelsCol='prediction').transform(predictions)\n",
"display(metrics)"
]
}
]
}

0 comments on commit aa28f11

Please sign in to comment.