Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Combine multiple tree ensemble models into a single tree ensemble #364

Merged
merged 3 commits into from
Jun 19, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Microsoft.ML.Core/Prediction/ITrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ public interface IModelCombiner<TModel, TPredictor>
TPredictor CombineModels(IEnumerable<TModel> models);
}

public delegate void SignatureModelCombiner(PredictionKind kind);

/// <summary>
/// Weakly typed interface for a trainer "session" that produces a predictor.
/// </summary>
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.ML.FastTree/Microsoft.ML.FastTree.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
<Compile Include="TreeEnsemble\Ensemble.cs" />
<Compile Include="TreeEnsemble\QuantileRegressionTree.cs" />
<Compile Include="TreeEnsemble\RegressionTree.cs" />
<Compile Include="TreeEnsemble\TreeEnsembleCombiner.cs" />
<Compile Include="Training\Applications\GradientWrappers.cs" />
<Compile Include="Training\Applications\ObjectiveFunction.cs" />
<Compile Include="Training\BaggingProvider.cs" />
Expand Down
32 changes: 27 additions & 5 deletions src/Microsoft.ML.FastTree/TreeEnsemble/RegressionTree.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,22 +105,24 @@ public RegressionTree(byte[] buffer, ref int position)
LteChild = buffer.ToIntArray(ref position);
GtChild = buffer.ToIntArray(ref position);
SplitFeatures = buffer.ToIntArray(ref position);
int[] categoricalNodeIndices = buffer.ToIntArray(ref position);
CategoricalSplit = GetCategoricalSplitFromIndices(categoricalNodeIndices);
if (categoricalNodeIndices?.Length > 0)
byte[] categoricalSplitAsBytes = buffer.ToByteArray(ref position);
CategoricalSplit = categoricalSplitAsBytes.Select(b => b > 0).ToArray();
if (CategoricalSplit.Any(b => b))
Copy link
Member

@codemzs codemzs Jun 15, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for fixing this. We should add a test for this function. I believe this function is not used when saving the tree model to disk and reading it back in TrainTest or CV, hence it was not caught during testing. #Resolved

{
CategoricalSplitFeatures = new int[NumNodes][];
CategoricalSplitFeatureRanges = new int[NumNodes][];
foreach (var index in categoricalNodeIndices)
for (int index = 0; index < NumNodes; index++)
{
Contracts.Assert(CategoricalSplit[index]);
if (!CategoricalSplit[index])
continue;

CategoricalSplitFeatures[index] = buffer.ToIntArray(ref position);
CategoricalSplitFeatureRanges[index] = buffer.ToIntArray(ref position, 2);
}
}

Thresholds = buffer.ToUIntArray(ref position);
RawThresholds = buffer.ToFloatArray(ref position);
_splitGain = buffer.ToDoubleArray(ref position);
_gainPValue = buffer.ToDoubleArray(ref position);
_previousLeafValue = buffer.ToDoubleArray(ref position);
Expand All @@ -144,6 +146,23 @@ private bool[] GetCategoricalSplitFromIndices(int[] indices)
return categoricalSplit;
}

private bool[] GetCategoricalSplitFromBytes(byte[] indices)
{
bool[] categoricalSplit = new bool[NumNodes];
if (indices == null)
return categoricalSplit;

Contracts.Assert(indices.Length <= NumNodes);

foreach (int index in indices)
{
Contracts.Assert(index >= 0 && index < NumNodes);
categoricalSplit[index] = true;
}

return categoricalSplit;
}

/// <summary>
/// Create a Regression Tree object from raw tree contents.
/// </summary>
Expand Down Expand Up @@ -500,6 +519,7 @@ public virtual int SizeInBytes()
NumNodes * sizeof(int) +
CategoricalSplit.Length * sizeof(bool) +
Thresholds.SizeInBytes() +
RawThresholds.SizeInBytes() +
_splitGain.SizeInBytes() +
_gainPValue.SizeInBytes() +
_previousLeafValue.SizeInBytes() +
Expand All @@ -514,6 +534,7 @@ public virtual void ToByteArray(byte[] buffer, ref int position)
LteChild.ToByteArray(buffer, ref position);
GtChild.ToByteArray(buffer, ref position);
SplitFeatures.ToByteArray(buffer, ref position);
CategoricalSplit.Length.ToByteArray(buffer, ref position);
foreach (var split in CategoricalSplit)
Convert.ToByte(split).ToByteArray(buffer, ref position);

Expand All @@ -530,6 +551,7 @@ public virtual void ToByteArray(byte[] buffer, ref int position)
}

Thresholds.ToByteArray(buffer, ref position);
RawThresholds.ToByteArray(buffer, ref position);
_splitGain.ToByteArray(buffer, ref position);
_gainPValue.ToByteArray(buffer, ref position);
_previousLeafValue.ToByteArray(buffer, ref position);
Expand Down
116 changes: 116 additions & 0 deletions src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
//------------------------------------------------------------------------------
// <copyright company="Microsoft Corporation">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//------------------------------------------------------------------------------
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Jun 15, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you need this one
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information. #Resolved


using System.Collections.Generic;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.FastTree.Internal;
using Microsoft.ML.Runtime.Internal.Calibration;

[assembly: LoadableClass(typeof(TreeEnsembleCombiner), null, typeof(SignatureModelCombiner), "Fast Tree Model Combiner", "FastTreeCombiner")]

namespace Microsoft.ML.Runtime.FastTree.Internal
{
public sealed class TreeEnsembleCombiner : IModelCombiner<IPredictorProducing<float>, IPredictorProducing<float>>
{
private readonly IHost _host;
private readonly PredictionKind _kind;

public TreeEnsembleCombiner(IHostEnvironment env, PredictionKind kind)
{
_host = env.Register("TreeEnsembleCombiner");
switch (kind)
{
case PredictionKind.BinaryClassification:
case PredictionKind.Regression:
case PredictionKind.Ranking:
_kind = kind;
break;
default:
throw _host.ExceptUserArg(nameof(kind), "Tree ensembles can be either binary classifiers, regressors or rankers");
Copy link
Contributor

@TomFinley TomFinley Jun 18, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

binary classifiers, regressors or rankers [](start = 90, length = 41)

Would it be better for this to be an interpolated $ style string with {nameof(PredictionKind.BinaryClassificiation)}, and so on? Not sure. #Closed

}
}

public IPredictorProducing<float> CombineModels(IEnumerable<IPredictorProducing<float>> models)
{
_host.CheckValue(models, nameof(models));

var ensemble = new Ensemble();
int modelCount = 0;
int featureCount = -1;
bool binaryClassifier = false;
foreach (var model in models)
{
modelCount++;

var predictor = model;
_host.CheckValue(predictor, nameof(models), "One of the models is null");

var calibrated = predictor as CalibratedPredictorBase;
double paramA = 1;
if (calibrated != null)
{
_host.Check(calibrated.Calibrator is PlattCalibrator,
"Combining FastTree models can only be done when the models are calibrated with Platt calibrator");
predictor = calibrated.SubPredictor;
paramA = -(calibrated.Calibrator as PlattCalibrator).ParamA;
}
var tree = predictor as FastTreePredictionWrapper;
if (tree == null)
throw _host.Except("Model is not a tree ensemble");
foreach (var t in tree.TrainedEnsemble.Trees)
{
var bytes = new byte[t.SizeInBytes()];
int position = -1;
t.ToByteArray(bytes, ref position);
position = -1;
var tNew = new RegressionTree(bytes, ref position);
if (paramA != 1)
{
for (int i = 0; i < tNew.NumLeaves; i++)
tNew.SetOutput(i, tNew.LeafValues[i] * paramA);
}
ensemble.AddTree(tNew);
}

if (modelCount == 1)
{
binaryClassifier = calibrated != null;
featureCount = tree.InputType.ValueCount;
}
else
{
_host.Check((calibrated != null) == binaryClassifier, "Ensemble contains both calibrated and uncalibrated models");
_host.Check(featureCount == tree.InputType.ValueCount, "Found models with different number of features");
}
}

var scale = 1 / (double)modelCount;

foreach (var t in ensemble.Trees)
{
for (int i = 0; i < t.NumLeaves; i++)
t.SetOutput(i, t.LeafValues[i] * scale);
}

switch (_kind)
{
case PredictionKind.BinaryClassification:
if (!binaryClassifier)
return new FastTreeBinaryPredictor(_host, ensemble, featureCount, null);

var cali = new PlattCalibrator(_host, -1, 0);
return new FeatureWeightsCalibratedPredictor(_host, new FastTreeBinaryPredictor(_host, ensemble, featureCount, null), cali);
case PredictionKind.Regression:
return new FastTreeRegressionPredictor(_host, ensemble, featureCount, null);
case PredictionKind.Ranking:
return new FastTreeRankingPredictor(_host, ensemble, featureCount, null);
default:
_host.Assert(false);
throw _host.ExceptNotSupp("PredictionKind can only be binary classification, regression or ranking");
}
Copy link
Contributor

@TomFinley TomFinley Jun 18, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it is impossible to reach this state anyway due to the check in the constructor, perhaps this ought to just be a throw without the message. #Resolved

}
}
}
Loading