Skip to content

Commit

Permalink
Add estimator to public API iteration result (dotnet#248)
Browse files: browse the repository at this point in the history
  • Loading branch information
daholste authored Mar 3, 2019
1 parent 3f77e59 commit 8fd2aa8
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 21 deletions.
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Auto/API/RunResult.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ public sealed class RunResult<T>
public Exception Exception { get; private set; }
public string TrainerName { get; private set; }
public int RuntimeInSeconds { get; private set; }
public IEstimator<ITransformer> Estimator { get; private set; }

internal Pipeline Pipeline { get; private set; }
internal int PipelineInferenceTimeInSeconds { get; private set; }

internal RunResult(
ITransformer model,
T metrics,
IEstimator<ITransformer> estimator,
Pipeline pipeline,
Exception exception,
int runtimeInSeconds,
Expand All @@ -29,6 +31,7 @@ internal RunResult(
Model = model;
ValidationMetrics = metrics;
Pipeline = pipeline;
Estimator = estimator;
Exception = exception;
RuntimeInSeconds = runtimeInSeconds;
PipelineInferenceTimeInSeconds = pipelineInferenceTimeInSeconds;
Expand Down
14 changes: 9 additions & 5 deletions src/Microsoft.ML.Auto/Experiment/Experiment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,9 @@ public List<RunResult<T>> Execute()
// evaluate pipeline
runResult = ProcessPipeline(pipeline);

if (preprocessorTransform != null)
if (_preFeaturizers != null)
{
runResult.Estimator = _preFeaturizers.Append(runResult.Estimator);
runResult.Model = preprocessorTransform.Append(runResult.Model);
}

Expand All @@ -108,7 +109,7 @@ public List<RunResult<T>> Execute()
catch (Exception ex)
{
WriteDebugLog(DebugStream.Exception, $"{pipeline?.Trainer} Crashed {ex}");
runResult = new SuggestedPipelineResult<T>(null, null, pipeline, -1, ex);
runResult = new SuggestedPipelineResult<T>(null, null, null, pipeline, -1, ex);
}

var iterationResult = runResult.ToIterationResult();
Expand Down Expand Up @@ -149,19 +150,22 @@ private SuggestedPipelineResult<T> ProcessPipeline(SuggestedPipeline pipeline)

WriteDebugLog(DebugStream.RunResult, $"Processing pipeline {commandLineStr}.");

var pipelineEstimator = pipeline.ToEstimator();

SuggestedPipelineResult<T> runResult;

try
{
var pipelineModel = pipeline.Fit(_trainData);
var pipelineModel = pipelineEstimator.Fit(_trainData);
var scoredValidationData = pipelineModel.Transform(_validationData);
var metrics = GetEvaluatedMetrics(scoredValidationData);
var score = _metricsAgent.GetScore(metrics);
runResult = new SuggestedPipelineResult<T>(metrics, pipelineModel, pipeline, score, null);
runResult = new SuggestedPipelineResult<T>(metrics, pipelineEstimator, pipelineModel, pipeline, score, null);
}
catch(Exception ex)
{
WriteDebugLog(DebugStream.Exception, $"{pipeline.Trainer} Crashed {ex}");
runResult = new SuggestedPipelineResult<T>(null, null, pipeline, 0, ex);
runResult = new SuggestedPipelineResult<T>(null, pipelineEstimator, null, pipeline, 0, ex);
}

// save pipeline run
Expand Down
6 changes: 0 additions & 6 deletions src/Microsoft.ML.Auto/Experiment/SuggestedPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,6 @@ public IEstimator<ITransformer> ToEstimator()
return pipeline;
}

public ITransformer Fit(IDataView trainData)
{
var estimator = ToEstimator();
return estimator.Fit(trainData);
}

private void AddNormalizationTransforms()
{
// get learner
Expand Down
7 changes: 5 additions & 2 deletions src/Microsoft.ML.Auto/Experiment/SuggestedPipelineResult.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,23 +33,26 @@ public IRunResult ToRunResult(bool isMetricMaximizing)
internal class SuggestedPipelineResult<T> : SuggestedPipelineResult
{
public readonly T EvaluatedMetrics;
public IEstimator<ITransformer> Estimator { get; set; }
public ITransformer Model { get; set; }
public Exception Exception { get; set; }

public int RuntimeInSeconds { get; set; }
public int PipelineInferenceTimeInSeconds { get; set; }

public SuggestedPipelineResult(T evaluatedMetrics, ITransformer model, SuggestedPipeline pipeline, double score, Exception exception)
public SuggestedPipelineResult(T evaluatedMetrics, IEstimator<ITransformer> estimator,
ITransformer model, SuggestedPipeline pipeline, double score, Exception exception)
: base(pipeline, score, exception == null)
{
EvaluatedMetrics = evaluatedMetrics;
Estimator = estimator;
Model = model;
Exception = exception;
}

public RunResult<T> ToIterationResult()
{
return new RunResult<T>(Model, EvaluatedMetrics, Pipeline.ToPipeline(), Exception, RuntimeInSeconds, PipelineInferenceTimeInSeconds);
return new RunResult<T>(Model, EvaluatedMetrics, Estimator, Pipeline.ToPipeline(), Exception, RuntimeInSeconds, PipelineInferenceTimeInSeconds);
}
}
}
7 changes: 4 additions & 3 deletions src/Test/AutoFitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ public void AutoFitBinaryTest()
var trainData = textLoader.Load(dataPath);
var validationData = context.Data.TakeRows(trainData, 100);
trainData = context.Data.SkipRows(trainData, 100);
var result = context.Auto()
var results = context.Auto()
.CreateBinaryClassificationExperiment(0)
.Execute(trainData, validationData, new ColumnInformation() { LabelColumn = DatasetUtil.UciAdultLabel });

Assert.IsTrue(result.Max(i => i.ValidationMetrics.Accuracy) > 0.80);
var best = results.Best();
Assert.IsTrue(best.ValidationMetrics.Accuracy > 0.80);
Assert.IsNotNull(best.Estimator);
}

[TestMethod]
Expand Down
10 changes: 5 additions & 5 deletions src/Test/RunResultTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ public void FindBestResultWithSomeNullMetrics()

var runResults = new List<RunResult<RegressionMetrics>>()
{
new RunResult<RegressionMetrics>(null, null, null, null, 0, 0),
new RunResult<RegressionMetrics>(null, metrics1, null, null, 0, 0),
new RunResult<RegressionMetrics>(null, metrics2, null, null, 0, 0),
new RunResult<RegressionMetrics>(null, metrics3, null, null, 0, 0),
new RunResult<RegressionMetrics>(null, null, null, null, null, 0, 0),
new RunResult<RegressionMetrics>(null, metrics1, null, null, null, 0, 0),
new RunResult<RegressionMetrics>(null, metrics2, null, null, null, 0, 0),
new RunResult<RegressionMetrics>(null, metrics3, null, null, null, 0, 0),
};

var metricsAgent = new RegressionMetricsAgent(RegressionMetric.RSquared);
Expand All @@ -36,7 +36,7 @@ public void FindBestResultWithAllNullMetrics()
{
var runResults = new List<RunResult<RegressionMetrics>>()
{
new RunResult<RegressionMetrics>(null, null, null, null, 0, 0),
new RunResult<RegressionMetrics>(null, null, null, null, null, 0, 0),
};

var metricsAgent = new RegressionMetricsAgent(RegressionMetric.RSquared);
Expand Down

0 comments on commit 8fd2aa8

Please sign in to comment.