Skip to content

Commit

Permalink
LightGBM pipeline serialization fix (dotnet#251)
Browse files Browse the repository at this point in the history
  • Loading branch information
daholste authored Mar 3, 2019
1 parent 8fd2aa8 commit b3fd4dc
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
20 changes: 14 additions & 6 deletions src/Microsoft.ML.Auto/TrainerExtensions/TrainerExtensionUtil.cs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ private static IDictionary<string, object> BuildLightGbmPipelineNodeProps(IEnume
string labelColumn, string weightColumn)
{
Dictionary<string, object> props = null;
if (sweepParams == null)
if (sweepParams == null || !sweepParams.Any())
{
props = new Dictionary<string, object>();
}
Expand Down Expand Up @@ -185,11 +185,19 @@ public static ColumnInformation BuildColumnInfo(IDictionary<string, object> prop

private static ParameterSet BuildLightGbmParameterSet(IDictionary<string, object> props)
{
var parentProps = props.Where(p => p.Key != LightGbmTreeBoosterPropName);
var treeProps = ((CustomProperty)props[LightGbmTreeBoosterPropName]).Properties;
var allProps = parentProps.Union(treeProps);
var paramVals = allProps.Select(p => new StringParameterValue(p.Key, p.Value.ToString()));
return new ParameterSet(paramVals);
IEnumerable<IParameterValue> parameters;
if (props == null || !props.Any())
{
parameters = new List<IParameterValue>();
}
else
{
var parentProps = props.Where(p => p.Key != LightGbmTreeBoosterPropName);
var treeProps = ((CustomProperty)props[LightGbmTreeBoosterPropName]).Properties;
var allProps = parentProps.Union(treeProps);
parameters = allProps.Select(p => new StringParameterValue(p.Key, p.Value.ToString()));
}
return new ParameterSet(parameters);
}

private static void SetValue(FieldInfo fi, IComparable value, object obj, Type propertyType)
Expand Down
22 changes: 22 additions & 0 deletions src/Test/TrainerExtensionsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,28 @@ public void BuildSdcaPipelineNode()
Util.AssertObjectMatchesJson(expectedJson, pipelineNode);
}

[TestMethod]
public void BuildLightGbmPipelineNodeDefaultParams()
{
var pipelineNode = new LightGbmBinaryExtension().CreatePipelineNode(
new List<SweepableParam>(),
new ColumnInformation());
var expectedJson = @"{
""Name"": ""LightGbmBinary"",
""NodeType"": ""Trainer"",
""InColumns"": [
""Features""
],
""OutColumns"": [
""Score""
],
""Properties"": {
""LabelColumn"": ""Label""
}
}";
Util.AssertObjectMatchesJson(expectedJson, pipelineNode);
}

[TestMethod]
public void BuildPipelineNodeWithCustomColumns()
{
Expand Down

0 comments on commit b3fd4dc

Please sign in to comment.