Skip to content

Commit

Permalink
Rebasing and fixing to reflect changes in master.
Browse files Browse the repository at this point in the history
  • Loading branch information
Rogan Carr committed Mar 12, 2019
1 parent 437dcad commit 4d4e65b
Showing 1 changed file with 8 additions and 20 deletions.
28 changes: 8 additions & 20 deletions test/Microsoft.ML.Functional.Tests/Training.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ public void CompareTrainerEvaluations()
// Get the dataset.
var data = mlContext.Data.LoadFromTextFile<TweetSentiment>(GetDataPath(TestDatasets.Sentiment.trainFilename),
separatorChar: TestDatasets.Sentiment.fileSeparator,
hasHeader: TestDatasets.Sentiment.fileHasHeader,
hasHeader: TestDatasets.Sentiment.fileHasHeader,
allowQuoting: TestDatasets.Sentiment.allowQuoting);
var trainTestSplit = mlContext.BinaryClassification.TrainTestSplit(data);
var trainTestSplit = mlContext.Data.TrainTestSplit(data);
var trainData = trainTestSplit.TrainSet;
var testData = trainTestSplit.TestSet;

Expand Down Expand Up @@ -266,6 +266,7 @@ public void ContinueTrainingLogisticRegressionMulticlass()

// Create a training pipeline.
var featurizationPipeline = mlContext.Transforms.Concatenate("Features", Iris.Features)
.Append(mlContext.Transforms.Conversion.MapValueToKey("Label"))
.AppendCacheCheckpoint(mlContext);

var trainer = mlContext.MulticlassClassification.Trainers.LogisticRegression(
Expand Down Expand Up @@ -467,8 +468,7 @@ public void MetacomponentsFunctionAsExpectedOva()
var binaryClassificationPipeline = mlContext.Transforms.Concatenate("Features", Iris.Features)
.AppendCacheCheckpoint(mlContext)
.Append(mlContext.Transforms.Conversion.MapValueToKey("Label"))
.Append(mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryclassificationTrainer))
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
.Append(mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryclassificationTrainer));

// Fit the binary classification pipeline.
var binaryClassificationModel = binaryClassificationPipeline.Fit(data);
Expand Down Expand Up @@ -503,40 +503,28 @@ public void MetacomponentsFunctionAsExpectedOva()
// Create a model training an OVA trainer with a ranking trainer.
var rankingTrainer = mlContext.Ranking.Trainers.FastTree(
new FastTreeRankingTrainer.Options { NumberOfTrees = 2, NumberOfThreads = 1, });
// Todo #2920: Make this fail somehow.
var rankingPipeline = mlContext.Transforms.Concatenate("Features", Iris.Features)
.AppendCacheCheckpoint(mlContext)
.Append(mlContext.Transforms.Conversion.MapValueToKey("Label"))
.Append(mlContext.MulticlassClassification.Trainers.OneVersusAll(rankingTrainer))
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));

// Fit the invalid pipeline.
// Todo #2920: Make this fail somehow.
var rankingModel = rankingPipeline.Fit(data);

// Transform the data
var rankingPredictions = rankingModel.Transform(data);

// Evaluate the model.
var rankingMetrics = mlContext.MulticlassClassification.Evaluate(rankingPredictions);
Assert.Throws<ArgumentOutOfRangeException>(() => rankingPipeline.Fit(data));

// Create a model training an OVA trainer with a regressor.
var regressionTrainer = mlContext.Regression.Trainers.PoissonRegression(
new PoissonRegression.Options { NumberOfIterations = 10, NumberOfThreads = 1, });
// Todo #2920: Make this fail somehow.
var regressionPipeline = mlContext.Transforms.Concatenate("Features", Iris.Features)
.AppendCacheCheckpoint(mlContext)
.Append(mlContext.Transforms.Conversion.MapValueToKey("Label"))
.Append(mlContext.MulticlassClassification.Trainers.OneVersusAll(regressionTrainer))
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));

// Fit the invalid pipeline.
// Todo #2920: Make this fail somehow.
var regressionModel = regressionPipeline.Fit(data);

// Transform the data
var regressionPredictions = regressionModel.Transform(data);

// Evaluate the model.
var regressionMetrics = mlContext.MulticlassClassification.Evaluate(regressionPredictions);
Assert.Throws<ArgumentOutOfRangeException>(() => regressionPipeline.Fit(data));
}
}
}

0 comments on commit 4d4e65b

Please sign in to comment.