diff --git a/src/Microsoft.ML.AutoML/Experiment/Experiment.cs b/src/Microsoft.ML.AutoML/Experiment/Experiment.cs
index c845dca14f..0d78a0cf63 100644
--- a/src/Microsoft.ML.AutoML/Experiment/Experiment.cs
+++ b/src/Microsoft.ML.AutoML/Experiment/Experiment.cs
@@ -63,7 +63,7 @@ public IList<TRunDetail> Execute()
                 // get next pipeline
                 var getPipelineStopwatch = Stopwatch.StartNew();
                 var pipeline = PipelineSuggester.GetNextInferredPipeline(_context, _history, _datasetColumnInfo, _task,
-                    _optimizingMetricInfo.IsMaximizing, _experimentSettings.CacheBeforeTrainer, _trainerAllowList);
+                    _optimizingMetricInfo.IsMaximizing, _experimentSettings.CacheBeforeTrainer, _logger, _trainerAllowList);
 
                 var pipelineInferenceTimeInSeconds = getPipelineStopwatch.Elapsed.TotalSeconds;
 
diff --git a/src/Microsoft.ML.AutoML/PipelineSuggesters/PipelineSuggester.cs b/src/Microsoft.ML.AutoML/PipelineSuggesters/PipelineSuggester.cs
index 8a420cbb5d..711c7bd3ef 100644
--- a/src/Microsoft.ML.AutoML/PipelineSuggesters/PipelineSuggester.cs
+++ b/src/Microsoft.ML.AutoML/PipelineSuggesters/PipelineSuggester.cs
@@ -6,6 +6,8 @@
 using System.Collections.Generic;
 using System.Linq;
 using Microsoft.ML.Data;
+using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Runtime;
 
 namespace Microsoft.ML.AutoML
 {
@@ -17,10 +19,11 @@ public static Pipeline GetNextPipeline(MLContext context,
             IEnumerable<PipelineScore> history,
             DatasetColumnInfo[] columns,
             TaskKind task,
+            IChannel logger,
             bool isMaximizingMetric = true)
         {
             var inferredHistory = history.Select(r => SuggestedPipelineRunDetail.FromPipelineRunResult(context, r));
-            var nextInferredPipeline = GetNextInferredPipeline(context, inferredHistory, columns, task, isMaximizingMetric, CacheBeforeTrainer.Auto);
+            var nextInferredPipeline = GetNextInferredPipeline(context, inferredHistory, columns, task, isMaximizingMetric, CacheBeforeTrainer.Auto, logger);
             return nextInferredPipeline?.ToPipeline();
         }
 
@@ -30,6 +33,7 @@ public static SuggestedPipeline GetNextInferredPipeline(MLContext context,
             TaskKind task,
             bool isMaximizingMetric,
             CacheBeforeTrainer cacheBeforeTrainer,
+            IChannel logger,
             IEnumerable<TrainerName> trainerAllowList = null)
         {
             var availableTrainers = RecipeInference.AllowedTrainers(context, task,
@@ -64,7 +68,7 @@ public static SuggestedPipeline GetNextInferredPipeline(MLContext context,
             do
             {
                 // sample new hyperparameters for the learner
-                if (!SampleHyperparameters(context, newTrainer, history, isMaximizingMetric))
+                if (!SampleHyperparameters(context, newTrainer, history, isMaximizingMetric, logger))
                 {
                     // if unable to sample new hyperparameters for the learner
                     // (ie SMAC returned 0 suggestions), break
@@ -188,30 +192,42 @@ private static IValueGenerator[] ConvertToValueGenerators(IEnumerable<SweepableParam> sweepParams)
-        private static bool SampleHyperparameters(MLContext context, SuggestedTrainer trainer, IEnumerable<SuggestedPipelineRunDetail> history, bool isMaximizingMetric)
+        private static bool SampleHyperparameters(MLContext context, SuggestedTrainer trainer,
+            IEnumerable<SuggestedPipelineRunDetail> history, bool isMaximizingMetric, IChannel logger)
         {
-            var sps = ConvertToValueGenerators(trainer.SweepParams);
-            var sweeper = new SmacSweeper(context,
-                new SmacSweeper.Arguments
+            try
+            {
+                var sps = ConvertToValueGenerators(trainer.SweepParams);
+                var sweeper = new SmacSweeper(context,
+                    new SmacSweeper.Arguments
+                    {
+                        SweptParameters = sps
+                    });
+
+                IEnumerable<SuggestedPipelineRunDetail> historyToUse = history
+                    .Where(r => r.RunSucceeded && r.Pipeline.Trainer.TrainerName == trainer.TrainerName &&
+                                r.Pipeline.Trainer.HyperParamSet != null &&
+                                r.Pipeline.Trainer.HyperParamSet.Any() &&
+                                FloatUtils.IsFinite(r.Score));
+
+                // get new set of hyperparameter values
+                var proposedParamSet = sweeper.ProposeSweeps(1, historyToUse.Select(h => h.ToRunResult(isMaximizingMetric))).FirstOrDefault();
+                if (!proposedParamSet.Any())
                 {
-                    SweptParameters = sps
-                });
+                    return false;
+                }
 
-            IEnumerable<SuggestedPipelineRunDetail> historyToUse = history
-                .Where(r => r.RunSucceeded && r.Pipeline.Trainer.TrainerName == trainer.TrainerName && r.Pipeline.Trainer.HyperParamSet != null && r.Pipeline.Trainer.HyperParamSet.Any());
+                // associate proposed parameter set with trainer, so that smart hyperparameter
+                // sweepers (like KDO) can map them back.
+                trainer.SetHyperparamValues(proposedParamSet);
 
-            // get new set of hyperparameter values
-            var proposedParamSet = sweeper.ProposeSweeps(1, historyToUse.Select(h => h.ToRunResult(isMaximizingMetric))).First();
-            if (!proposedParamSet.Any())
+                return true;
+            }
+            catch (Exception ex)
             {
-                return false;
+                logger.Error($"SampleHyperparameters failed with exception: {ex}");
+                throw;
             }
-
-            // associate proposed parameter set with trainer, so that smart hyperparameter
-            // sweepers (like KDO) can map them back.
-            trainer.SetHyperparamValues(proposedParamSet);
-
-            return true;
         }
     }
 }
\ No newline at end of file
diff --git a/test/Microsoft.ML.AutoML.Tests/GetNextPipelineTests.cs b/test/Microsoft.ML.AutoML.Tests/GetNextPipelineTests.cs
index fb0a65bf5a..cec77327f1 100644
--- a/test/Microsoft.ML.AutoML.Tests/GetNextPipelineTests.cs
+++ b/test/Microsoft.ML.AutoML.Tests/GetNextPipelineTests.cs
@@ -9,6 +9,7 @@
 using Newtonsoft.Json;
 using Microsoft.ML.TestFramework;
 using Xunit.Abstractions;
+using Microsoft.ML.Runtime;
 
 namespace Microsoft.ML.AutoML.Test
 {
@@ -27,7 +28,8 @@ public void GetNextPipeline()
             var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(context, uciAdult, new ColumnInformation() { LabelColumnName = DatasetUtil.UciAdultLabel });
 
             // get next pipeline
-            var pipeline = PipelineSuggester.GetNextPipeline(context, new List<PipelineScore>(), columns, TaskKind.BinaryClassification);
+            var pipeline = PipelineSuggester.GetNextPipeline(context, new List<PipelineScore>(), columns,
+                TaskKind.BinaryClassification, ((IChannelProvider)context).Start("AutoMLTest"));
 
             // serialize & deserialize pipeline
             var serialized = JsonConvert.SerializeObject(pipeline);
@@ -57,7 +59,7 @@ public void GetNextPipelineMock()
             for (var i = 0; i < maxIterations; i++)
             {
                 // Get next pipeline
-                var pipeline = PipelineSuggester.GetNextPipeline(context, history, columns, task);
+                var pipeline = PipelineSuggester.GetNextPipeline(context, history, columns, task, ((IChannelProvider)context).Start("AutoMLTest"));
                 if (pipeline == null)
                 {
                     break;
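
Usage sketch (not part of the patch): the new IChannel parameter is threaded from Experiment.Execute() through PipelineSuggester.GetNextInferredPipeline() into SampleHyperparameters(), where sweeper failures are now logged via logger.Error(...) before being rethrown. The C# snippet below illustrates how a caller could obtain a channel and invoke the updated GetNextPipeline overload. The channel name "AutoMLTest" and the cast of MLContext to IChannelProvider are taken from the updated tests; the wrapper class and method are hypothetical, and the snippet assumes internal access to the Microsoft.ML.AutoML types (as the test project has via InternalsVisibleTo).

    using System.Collections.Generic;
    using Microsoft.ML;
    using Microsoft.ML.AutoML;
    using Microsoft.ML.Runtime;

    internal static class PipelineSuggesterUsageSketch
    {
        // Hypothetical helper mirroring the call pattern in GetNextPipelineTests.cs.
        // 'columns' is assumed to come from DatasetColumnInfoUtil.GetDatasetColumnInfo(...).
        public static Pipeline SuggestFirstPipeline(MLContext context, DatasetColumnInfo[] columns, TaskKind task)
        {
            // Open a logging channel on the MLContext, as the updated tests do;
            // errors thrown inside SampleHyperparameters are reported through it.
            IChannel logger = ((IChannelProvider)context).Start("AutoMLTest");

            // With an empty run history the suggester returns its first default pipeline.
            return PipelineSuggester.GetNextPipeline(context, new List<PipelineScore>(), columns, task, logger);
        }
    }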