diff --git a/src/Microsoft.ML.Auto/API/RunResult.cs b/src/Microsoft.ML.Auto/API/RunResult.cs index e729d6f2fc8..c20f4219205 100644 --- a/src/Microsoft.ML.Auto/API/RunResult.cs +++ b/src/Microsoft.ML.Auto/API/RunResult.cs @@ -14,6 +14,7 @@ public sealed class RunResult public Exception Exception { get; private set; } public string TrainerName { get; private set; } public int RuntimeInSeconds { get; private set; } + public IEstimator Estimator { get; private set; } internal Pipeline Pipeline { get; private set; } internal int PipelineInferenceTimeInSeconds { get; private set; } @@ -21,6 +22,7 @@ public sealed class RunResult internal RunResult( ITransformer model, T metrics, + IEstimator estimator, Pipeline pipeline, Exception exception, int runtimeInSeconds, @@ -29,6 +31,7 @@ internal RunResult( Model = model; ValidationMetrics = metrics; Pipeline = pipeline; + Estimator = estimator; Exception = exception; RuntimeInSeconds = runtimeInSeconds; PipelineInferenceTimeInSeconds = pipelineInferenceTimeInSeconds; diff --git a/src/Microsoft.ML.Auto/Experiment/Experiment.cs b/src/Microsoft.ML.Auto/Experiment/Experiment.cs index 4e2b48c4a6a..b4b830b13ad 100644 --- a/src/Microsoft.ML.Auto/Experiment/Experiment.cs +++ b/src/Microsoft.ML.Auto/Experiment/Experiment.cs @@ -97,8 +97,9 @@ public List> Execute() // evaluate pipeline runResult = ProcessPipeline(pipeline); - if (preprocessorTransform != null) + if (_preFeaturizers != null) { + runResult.Estimator = _preFeaturizers.Append(runResult.Estimator); runResult.Model = preprocessorTransform.Append(runResult.Model); } @@ -108,7 +109,7 @@ public List> Execute() catch (Exception ex) { WriteDebugLog(DebugStream.Exception, $"{pipeline?.Trainer} Crashed {ex}"); - runResult = new SuggestedPipelineResult(null, null, pipeline, -1, ex); + runResult = new SuggestedPipelineResult(null, null, null, pipeline, -1, ex); } var iterationResult = runResult.ToIterationResult(); @@ -149,19 +150,22 @@ private SuggestedPipelineResult ProcessPipeline(SuggestedPipeline pipeline) WriteDebugLog(DebugStream.RunResult, $"Processing pipeline {commandLineStr}."); + var pipelineEstimator = pipeline.ToEstimator(); + SuggestedPipelineResult runResult; + try { - var pipelineModel = pipeline.Fit(_trainData); + var pipelineModel = pipelineEstimator.Fit(_trainData); var scoredValidationData = pipelineModel.Transform(_validationData); var metrics = GetEvaluatedMetrics(scoredValidationData); var score = _metricsAgent.GetScore(metrics); - runResult = new SuggestedPipelineResult(metrics, pipelineModel, pipeline, score, null); + runResult = new SuggestedPipelineResult(metrics, pipelineEstimator, pipelineModel, pipeline, score, null); } catch(Exception ex) { WriteDebugLog(DebugStream.Exception, $"{pipeline.Trainer} Crashed {ex}"); - runResult = new SuggestedPipelineResult(null, null, pipeline, 0, ex); + runResult = new SuggestedPipelineResult(null, pipelineEstimator, null, pipeline, 0, ex); } // save pipeline run diff --git a/src/Microsoft.ML.Auto/Experiment/SuggestedPipeline.cs b/src/Microsoft.ML.Auto/Experiment/SuggestedPipeline.cs index 9738ee13f86..14a622b8e99 100644 --- a/src/Microsoft.ML.Auto/Experiment/SuggestedPipeline.cs +++ b/src/Microsoft.ML.Auto/Experiment/SuggestedPipeline.cs @@ -113,12 +113,6 @@ public IEstimator ToEstimator() return pipeline; } - public ITransformer Fit(IDataView trainData) - { - var estimator = ToEstimator(); - return estimator.Fit(trainData); - } - private void AddNormalizationTransforms() { // get learner diff --git a/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineResult.cs b/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineResult.cs index 787d4e32ff3..8a5473a4583 100644 --- a/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineResult.cs +++ b/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineResult.cs @@ -33,23 +33,26 @@ public IRunResult ToRunResult(bool isMetricMaximizing) internal class SuggestedPipelineResult : SuggestedPipelineResult { public readonly T EvaluatedMetrics; + public IEstimator Estimator { get; set; } public ITransformer Model { get; set; } public Exception Exception { get; set; } public int RuntimeInSeconds { get; set; } public int PipelineInferenceTimeInSeconds { get; set; } - public SuggestedPipelineResult(T evaluatedMetrics, ITransformer model, SuggestedPipeline pipeline, double score, Exception exception) + public SuggestedPipelineResult(T evaluatedMetrics, IEstimator estimator, + ITransformer model, SuggestedPipeline pipeline, double score, Exception exception) : base(pipeline, score, exception == null) { EvaluatedMetrics = evaluatedMetrics; + Estimator = estimator; Model = model; Exception = exception; } public RunResult ToIterationResult() { - return new RunResult(Model, EvaluatedMetrics, Pipeline.ToPipeline(), Exception, RuntimeInSeconds, PipelineInferenceTimeInSeconds); + return new RunResult(Model, EvaluatedMetrics, Estimator, Pipeline.ToPipeline(), Exception, RuntimeInSeconds, PipelineInferenceTimeInSeconds); } } } diff --git a/src/Test/AutoFitTests.cs b/src/Test/AutoFitTests.cs index b2188d41be9..82f29f48e05 100644 --- a/src/Test/AutoFitTests.cs +++ b/src/Test/AutoFitTests.cs @@ -20,11 +20,12 @@ public void AutoFitBinaryTest() var trainData = textLoader.Load(dataPath); var validationData = context.Data.TakeRows(trainData, 100); trainData = context.Data.SkipRows(trainData, 100); - var result = context.Auto() + var results = context.Auto() .CreateBinaryClassificationExperiment(0) .Execute(trainData, validationData, new ColumnInformation() { LabelColumn = DatasetUtil.UciAdultLabel }); - - Assert.IsTrue(result.Max(i => i.ValidationMetrics.Accuracy) > 0.80); + var best = results.Best(); + Assert.IsTrue(best.ValidationMetrics.Accuracy > 0.80); + Assert.IsNotNull(best.Estimator); } [TestMethod] diff --git a/src/Test/RunResultTests.cs b/src/Test/RunResultTests.cs index 38aa123a478..166df42d72f 100644 --- a/src/Test/RunResultTests.cs +++ b/src/Test/RunResultTests.cs @@ -20,10 +20,10 @@ public void FindBestResultWithSomeNullMetrics() var runResults = new List>() { - new RunResult(null, null, null, null, 0, 0), - new RunResult(null, metrics1, null, null, 0, 0), - new RunResult(null, metrics2, null, null, 0, 0), - new RunResult(null, metrics3, null, null, 0, 0), + new RunResult(null, null, null, null, null, 0, 0), + new RunResult(null, metrics1, null, null, null, 0, 0), + new RunResult(null, metrics2, null, null, null, 0, 0), + new RunResult(null, metrics3, null, null, null, 0, 0), }; var metricsAgent = new RegressionMetricsAgent(RegressionMetric.RSquared); @@ -36,7 +36,7 @@ public void FindBestResultWithAllNullMetrics() { var runResults = new List>() { - new RunResult(null, null, null, null, 0, 0), + new RunResult(null, null, null, null, null, 0, 0), }; var metricsAgent = new RegressionMetricsAgent(RegressionMetric.RSquared);