Skip to content

Commit

Permalink
Set Nullable Auto params to null values (#50)
Browse files Browse the repository at this point in the history
* Added sequential grouping of columns

* reverted the file

* added auto params as null

* change to the update fields method
  • Loading branch information
srsaggam authored Jan 30, 2019
1 parent d254f4e commit 41c663c
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 38 deletions.
25 changes: 14 additions & 11 deletions src/AutoML/TrainerExtensions/SweepableParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ private static IEnumerable<SweepableParam> BuildOnlineLinearArgsParams()

private static IEnumerable<SweepableParam> BuildTreeArgsParams()
{
return new SweepableParam[]
{
return new SweepableParam[]
{
new SweepableLongParam("NumLeaves", 2, 128, isLogScale: true, stepSize: 4),
new SweepableDiscreteParam("MinDocumentsInLeafs", new object[] { 1, 10, 50 }),
new SweepableDiscreteParam("NumTrees", new object[] { 20, 100, 500 }),
new SweepableFloatParam("LearningRates", 0.025f, 0.4f, isLogScale: true),
new SweepableFloatParam("Shrinkage", 0.025f, 4f, isLogScale: true),
};
};
}

private static IEnumerable<SweepableParam> BuildLbfgsArgsParams()
Expand Down Expand Up @@ -123,22 +123,24 @@ public static IEnumerable<SweepableParam> BuildPoissonRegressionParams()
public static IEnumerable<SweepableParam> BuildSdcaParams()
{
return new SweepableParam[] {
new SweepableDiscreteParam("L2Const", new object[] { "<Auto>", 1e-7f, 1e-6f, 1e-5f, 1e-4f, 1e-3f, 1e-2f }),
new SweepableDiscreteParam("L1Threshold", new object[] { "<Auto>", 0f, 0.25f, 0.5f, 0.75f, 1f }),
new SweepableDiscreteParam("L2Const", new object[] { null, 1e-7f, 1e-6f, 1e-5f, 1e-4f, 1e-3f, 1e-2f }),
new SweepableDiscreteParam("L1Threshold", new object[] { null, 0f, 0.25f, 0.5f, 0.75f, 1f }),
new SweepableDiscreteParam("ConvergenceTolerance", new object[] { 0.001f, 0.01f, 0.1f, 0.2f }),
new SweepableDiscreteParam("MaxIterations", new object[] { "<Auto>", 10, 20, 100 }),
new SweepableDiscreteParam("MaxIterations", new object[] { null, 10, 20, 100 }),
new SweepableDiscreteParam("Shuffle", null, isBool: true),
new SweepableDiscreteParam("BiasLearningRate", new object[] { 0.0f, 0.01f, 0.1f, 1f })
};
}

public static IEnumerable<SweepableParam> BuildOrdinaryLeastSquaresParams() {
public static IEnumerable<SweepableParam> BuildOrdinaryLeastSquaresParams()
{
return new SweepableParam[] {
new SweepableDiscreteParam("L2Weight", new object[] { 1e-6f, 0.1f, 1f })
};
}

public static IEnumerable<SweepableParam> BuildSgdParams() {
public static IEnumerable<SweepableParam> BuildSgdParams()
{
return new SweepableParam[] {
new SweepableDiscreteParam("L2Weight", new object[] { 1e-7f, 5e-7f, 1e-6f, 5e-6f, 1e-5f }),
new SweepableDiscreteParam("ConvergenceTolerance", new object[] { 1e-2f, 1e-3f, 1e-4f, 1e-5f }),
Expand All @@ -147,12 +149,13 @@ public static IEnumerable<SweepableParam> BuildSgdParams() {
};
}

public static IEnumerable<SweepableParam> BuildSymSgdParams() {
public static IEnumerable<SweepableParam> BuildSymSgdParams()
{
return new SweepableParam[] {
new SweepableDiscreteParam("NumberOfIterations", new object[] { 1, 5, 10, 20, 30, 40, 50 }),
new SweepableDiscreteParam("LearningRate", new object[] { "<Auto>", 1e1f, 1e0f, 1e-1f, 1e-2f, 1e-3f }),
new SweepableDiscreteParam("LearningRate", new object[] { null, 1e1f, 1e0f, 1e-1f, 1e-2f, 1e-3f }),
new SweepableDiscreteParam("L2Regularization", new object[] { 0.0f, 1e-5f, 1e-5f, 1e-6f, 1e-7f }),
new SweepableDiscreteParam("UpdateFrequency", new object[] { "<Auto>", 5, 20 })
new SweepableDiscreteParam("UpdateFrequency", new object[] { null, 5, 20 })
};
}
}
Expand Down
25 changes: 5 additions & 20 deletions src/AutoML/TrainerExtensions/TrainerExtensionUtil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public static Action<LightGbmArguments> CreateLightGbmArgsFunc(IEnumerable<Sweep

public static IDictionary<string, object> BuildPipelineNodeProps(TrainerName trainerName, IEnumerable<SweepableParam> sweepParams)
{
if(trainerName == TrainerName.LightGbmBinary || trainerName == TrainerName.LightGbmMulti ||
if (trainerName == TrainerName.LightGbmBinary || trainerName == TrainerName.LightGbmMulti ||
trainerName == TrainerName.LightGbmRegression)
{
return BuildLightGbmPipelineNodeProps(sweepParams);
Expand All @@ -96,7 +96,7 @@ private static IDictionary<string, object> BuildLightGbmPipelineNodeProps(IEnume

var props = parentArgParams.ToDictionary(p => p.Name, p => (object)p.ProcessedValue());
props[LightGbmTreeBoosterPropName] = treeBoosterCustomProp;

return props;
}

Expand Down Expand Up @@ -155,24 +155,9 @@ public static void UpdateFields(object obj, IEnumerable<SweepableParam> sweepPar
{
var optIndex = (int)dp.RawValue;
//Contracts.Assert(0 <= optIndex && optIndex < dp.Options.Length, $"Options index out of range: {optIndex}");
var option = dp.Options[optIndex].ToString().ToLower();

// Handle <Auto> string values in sweep params
if (option == "auto" || option == "<auto>" || option == "< auto >")
{
//Check if nullable type, in which case 'null' is the auto value.
if (Nullable.GetUnderlyingType(fi.FieldType) != null)
fi.SetValue(obj, null);
else if (fi.FieldType.IsEnum)
{
// Check if there is an enum option named Auto
var enumDict = fi.FieldType.GetEnumValues().Cast<int>()
.ToDictionary(v => Enum.GetName(fi.FieldType, v), v => v);
if (enumDict.ContainsKey("Auto"))
fi.SetValue(obj, enumDict["Auto"]);
}
}
else
var option = dp.Options[optIndex];

if (option != null)
SetValue(fi, (IComparable)dp.Options[optIndex], obj, propType);
}
else
Expand Down
36 changes: 29 additions & 7 deletions src/Test/TrainerExtensionsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public void TrainerExtensionInstanceTests()
{
var context = new MLContext();
var trainerNames = Enum.GetValues(typeof(TrainerName)).Cast<TrainerName>();
foreach(var trainerName in trainerNames)
foreach (var trainerName in trainerNames)
{
var extension = TrainerExtensionCatalog.GetTrainerExtension(trainerName);
var instance = extension.CreateInstance(context, null);
Expand All @@ -33,7 +33,7 @@ public void GetTrainersByMaxIterations()
var tasks = new TaskKind[] { TaskKind.BinaryClassification,
TaskKind.MulticlassClassification, TaskKind.Regression };

foreach(var task in tasks)
foreach (var task in tasks)
{
var trainerSet10 = TrainerExtensionCatalog.GetTrainers(task, 10);
var trainerSet50 = TrainerExtensionCatalog.GetTrainers(task, 50);
Expand All @@ -52,7 +52,7 @@ public void GetTrainersByMaxIterations()
public void BuildPipelineNodePropsLightGbm()
{
var sweepParams = SweepableParams.BuildLightGbmParams();
foreach(var sweepParam in sweepParams)
foreach (var sweepParam in sweepParams)
{
sweepParam.RawValue = 1;
}
Expand Down Expand Up @@ -91,7 +91,7 @@ public void BuildPipelineNodePropsLightGbm()
public void BuildPipelineNodePropsSdca()
{
var sweepParams = SweepableParams.BuildSdcaParams();
foreach(var sweepParam in sweepParams)
foreach (var sweepParam in sweepParams)
{
sweepParam.RawValue = 1;
}
Expand All @@ -108,7 +108,29 @@ public void BuildPipelineNodePropsSdca()
}";
Util.AssertObjectMatchesJson(expectedJson, sdcaBinaryProps);
}


[TestMethod]
public void BuildPipelineNodePropsSdcaWithNullValues()
{
var sweepParams = SweepableParams.BuildSdcaParams();
foreach (var sweepParam in sweepParams)
{
sweepParam.RawValue = 0;
}

var sdcaBinaryProps = TrainerExtensionUtil.BuildPipelineNodeProps(TrainerName.SdcaBinary, sweepParams);
var expectedJson = @"
{
""L2Const"": null,
""L1Threshold"": null,
""ConvergenceTolerance"": 0.001,
""MaxIterations"": null,
""Shuffle"": false,
""BiasLearningRate"": 0.0
}";
Util.AssertObjectMatchesJson(expectedJson, sdcaBinaryProps);
}

[TestMethod]
public void BuildParameterSetLightGbm()
{
Expand All @@ -129,7 +151,7 @@ public void BuildParameterSetLightGbm()
var multiParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.LightGbmMulti, props);
var regressionParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.LightGbmRegression, props);

foreach(var paramSet in new ParameterSet[] { binaryParams, multiParams, regressionParams })
foreach (var paramSet in new ParameterSet[] { binaryParams, multiParams, regressionParams })
{
Assert.AreEqual(4, paramSet.Count);
Assert.AreEqual("1", paramSet["NumBoostRound"].ValueText);
Expand All @@ -148,7 +170,7 @@ public void BuildParameterSetSdca()
};

var sdcaParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.SdcaBinary, props);

Assert.AreEqual(1, sdcaParams.Count);
Assert.AreEqual("1", sdcaParams["LearningRate"].ValueText);
}
Expand Down

0 comments on commit 41c663c

Please sign in to comment.