From 41c663cd14247d44022f40cf2dce5977dbab282d Mon Sep 17 00:00:00 2001 From: srsaggam <41802116+srsaggam@users.noreply.github.com> Date: Wed, 30 Jan 2019 13:57:13 -0800 Subject: [PATCH] Set Nullable Auto params to null values (#50) * Added sequential grouping of columns * reverted the file * added auto params as null * change to the update fields method --- .../TrainerExtensions/SweepableParams.cs | 25 +++++++------ .../TrainerExtensions/TrainerExtensionUtil.cs | 25 +++---------- src/Test/TrainerExtensionsTests.cs | 36 +++++++++++++++---- 3 files changed, 48 insertions(+), 38 deletions(-) diff --git a/src/AutoML/TrainerExtensions/SweepableParams.cs b/src/AutoML/TrainerExtensions/SweepableParams.cs index c2daeabd7b..9890127327 100644 --- a/src/AutoML/TrainerExtensions/SweepableParams.cs +++ b/src/AutoML/TrainerExtensions/SweepableParams.cs @@ -31,14 +31,14 @@ private static IEnumerable BuildOnlineLinearArgsParams() private static IEnumerable BuildTreeArgsParams() { - return new SweepableParam[] - { + return new SweepableParam[] + { new SweepableLongParam("NumLeaves", 2, 128, isLogScale: true, stepSize: 4), new SweepableDiscreteParam("MinDocumentsInLeafs", new object[] { 1, 10, 50 }), new SweepableDiscreteParam("NumTrees", new object[] { 20, 100, 500 }), new SweepableFloatParam("LearningRates", 0.025f, 0.4f, isLogScale: true), new SweepableFloatParam("Shrinkage", 0.025f, 4f, isLogScale: true), - }; + }; } private static IEnumerable BuildLbfgsArgsParams() @@ -123,22 +123,24 @@ public static IEnumerable BuildPoissonRegressionParams() public static IEnumerable BuildSdcaParams() { return new SweepableParam[] { - new SweepableDiscreteParam("L2Const", new object[] { "", 1e-7f, 1e-6f, 1e-5f, 1e-4f, 1e-3f, 1e-2f }), - new SweepableDiscreteParam("L1Threshold", new object[] { "", 0f, 0.25f, 0.5f, 0.75f, 1f }), + new SweepableDiscreteParam("L2Const", new object[] { null, 1e-7f, 1e-6f, 1e-5f, 1e-4f, 1e-3f, 1e-2f }), + new SweepableDiscreteParam("L1Threshold", new object[] { null, 0f, 0.25f, 0.5f, 0.75f, 1f }), new SweepableDiscreteParam("ConvergenceTolerance", new object[] { 0.001f, 0.01f, 0.1f, 0.2f }), - new SweepableDiscreteParam("MaxIterations", new object[] { "", 10, 20, 100 }), + new SweepableDiscreteParam("MaxIterations", new object[] { null, 10, 20, 100 }), new SweepableDiscreteParam("Shuffle", null, isBool: true), new SweepableDiscreteParam("BiasLearningRate", new object[] { 0.0f, 0.01f, 0.1f, 1f }) }; } - public static IEnumerable BuildOrdinaryLeastSquaresParams() { + public static IEnumerable BuildOrdinaryLeastSquaresParams() + { return new SweepableParam[] { new SweepableDiscreteParam("L2Weight", new object[] { 1e-6f, 0.1f, 1f }) }; } - public static IEnumerable BuildSgdParams() { + public static IEnumerable BuildSgdParams() + { return new SweepableParam[] { new SweepableDiscreteParam("L2Weight", new object[] { 1e-7f, 5e-7f, 1e-6f, 5e-6f, 1e-5f }), new SweepableDiscreteParam("ConvergenceTolerance", new object[] { 1e-2f, 1e-3f, 1e-4f, 1e-5f }), @@ -147,12 +149,13 @@ public static IEnumerable BuildSgdParams() { }; } - public static IEnumerable BuildSymSgdParams() { + public static IEnumerable BuildSymSgdParams() + { return new SweepableParam[] { new SweepableDiscreteParam("NumberOfIterations", new object[] { 1, 5, 10, 20, 30, 40, 50 }), - new SweepableDiscreteParam("LearningRate", new object[] { "", 1e1f, 1e0f, 1e-1f, 1e-2f, 1e-3f }), + new SweepableDiscreteParam("LearningRate", new object[] { null, 1e1f, 1e0f, 1e-1f, 1e-2f, 1e-3f }), new SweepableDiscreteParam("L2Regularization", new object[] { 0.0f, 1e-5f, 1e-5f, 1e-6f, 1e-7f }), - new SweepableDiscreteParam("UpdateFrequency", new object[] { "", 5, 20 }) + new SweepableDiscreteParam("UpdateFrequency", new object[] { null, 5, 20 }) }; } } diff --git a/src/AutoML/TrainerExtensions/TrainerExtensionUtil.cs b/src/AutoML/TrainerExtensions/TrainerExtensionUtil.cs index 09066bfd44..e34be3d329 100644 --- a/src/AutoML/TrainerExtensions/TrainerExtensionUtil.cs +++ b/src/AutoML/TrainerExtensions/TrainerExtensionUtil.cs @@ -77,7 +77,7 @@ public static Action CreateLightGbmArgsFunc(IEnumerable BuildPipelineNodeProps(TrainerName trainerName, IEnumerable sweepParams) { - if(trainerName == TrainerName.LightGbmBinary || trainerName == TrainerName.LightGbmMulti || + if (trainerName == TrainerName.LightGbmBinary || trainerName == TrainerName.LightGbmMulti || trainerName == TrainerName.LightGbmRegression) { return BuildLightGbmPipelineNodeProps(sweepParams); @@ -96,7 +96,7 @@ private static IDictionary BuildLightGbmPipelineNodeProps(IEnume var props = parentArgParams.ToDictionary(p => p.Name, p => (object)p.ProcessedValue()); props[LightGbmTreeBoosterPropName] = treeBoosterCustomProp; - + return props; } @@ -155,24 +155,9 @@ public static void UpdateFields(object obj, IEnumerable sweepPar { var optIndex = (int)dp.RawValue; //Contracts.Assert(0 <= optIndex && optIndex < dp.Options.Length, $"Options index out of range: {optIndex}"); - var option = dp.Options[optIndex].ToString().ToLower(); - - // Handle string values in sweep params - if (option == "auto" || option == "" || option == "< auto >") - { - //Check if nullable type, in which case 'null' is the auto value. - if (Nullable.GetUnderlyingType(fi.FieldType) != null) - fi.SetValue(obj, null); - else if (fi.FieldType.IsEnum) - { - // Check if there is an enum option named Auto - var enumDict = fi.FieldType.GetEnumValues().Cast() - .ToDictionary(v => Enum.GetName(fi.FieldType, v), v => v); - if (enumDict.ContainsKey("Auto")) - fi.SetValue(obj, enumDict["Auto"]); - } - } - else + var option = dp.Options[optIndex]; + + if (option != null) SetValue(fi, (IComparable)dp.Options[optIndex], obj, propType); } else diff --git a/src/Test/TrainerExtensionsTests.cs b/src/Test/TrainerExtensionsTests.cs index e00d075289..34f788e3f1 100644 --- a/src/Test/TrainerExtensionsTests.cs +++ b/src/Test/TrainerExtensionsTests.cs @@ -17,7 +17,7 @@ public void TrainerExtensionInstanceTests() { var context = new MLContext(); var trainerNames = Enum.GetValues(typeof(TrainerName)).Cast(); - foreach(var trainerName in trainerNames) + foreach (var trainerName in trainerNames) { var extension = TrainerExtensionCatalog.GetTrainerExtension(trainerName); var instance = extension.CreateInstance(context, null); @@ -33,7 +33,7 @@ public void GetTrainersByMaxIterations() var tasks = new TaskKind[] { TaskKind.BinaryClassification, TaskKind.MulticlassClassification, TaskKind.Regression }; - foreach(var task in tasks) + foreach (var task in tasks) { var trainerSet10 = TrainerExtensionCatalog.GetTrainers(task, 10); var trainerSet50 = TrainerExtensionCatalog.GetTrainers(task, 50); @@ -52,7 +52,7 @@ public void GetTrainersByMaxIterations() public void BuildPipelineNodePropsLightGbm() { var sweepParams = SweepableParams.BuildLightGbmParams(); - foreach(var sweepParam in sweepParams) + foreach (var sweepParam in sweepParams) { sweepParam.RawValue = 1; } @@ -91,7 +91,7 @@ public void BuildPipelineNodePropsLightGbm() public void BuildPipelineNodePropsSdca() { var sweepParams = SweepableParams.BuildSdcaParams(); - foreach(var sweepParam in sweepParams) + foreach (var sweepParam in sweepParams) { sweepParam.RawValue = 1; } @@ -108,7 +108,29 @@ public void BuildPipelineNodePropsSdca() }"; Util.AssertObjectMatchesJson(expectedJson, sdcaBinaryProps); } - + + [TestMethod] + public void BuildPipelineNodePropsSdcaWithNullValues() + { + var sweepParams = SweepableParams.BuildSdcaParams(); + foreach (var sweepParam in sweepParams) + { + sweepParam.RawValue = 0; + } + + var sdcaBinaryProps = TrainerExtensionUtil.BuildPipelineNodeProps(TrainerName.SdcaBinary, sweepParams); + var expectedJson = @" +{ + ""L2Const"": null, + ""L1Threshold"": null, + ""ConvergenceTolerance"": 0.001, + ""MaxIterations"": null, + ""Shuffle"": false, + ""BiasLearningRate"": 0.0 +}"; + Util.AssertObjectMatchesJson(expectedJson, sdcaBinaryProps); + } + [TestMethod] public void BuildParameterSetLightGbm() { @@ -129,7 +151,7 @@ public void BuildParameterSetLightGbm() var multiParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.LightGbmMulti, props); var regressionParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.LightGbmRegression, props); - foreach(var paramSet in new ParameterSet[] { binaryParams, multiParams, regressionParams }) + foreach (var paramSet in new ParameterSet[] { binaryParams, multiParams, regressionParams }) { Assert.AreEqual(4, paramSet.Count); Assert.AreEqual("1", paramSet["NumBoostRound"].ValueText); @@ -148,7 +170,7 @@ public void BuildParameterSetSdca() }; var sdcaParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.SdcaBinary, props); - + Assert.AreEqual(1, sdcaParams.Count); Assert.AreEqual("1", sdcaParams["LearningRate"].ValueText); }