Skip to content

Commit

Permalink
Adding support for MurmurHash KeyDataTypes (dotnet#5138)
Browse files Browse the repository at this point in the history
* merging

* removed some outdated comments

* update
  • Loading branch information
Lynx1820 authored May 21, 2020
1 parent c023271 commit d58e8d1
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 17 deletions.
45 changes: 32 additions & 13 deletions src/Microsoft.ML.Data/Transforms/Hashing.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1349,20 +1349,38 @@ private void AddMetaKeyValues(int i, DataViewSchema.Annotations.Builder builder)
private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, string dstVariable)
{
string castOutput;
string isGreaterThanZeroOutput = "";
OnnxNode castNode;
OnnxNode murmurNode;
OnnxNode isZeroNode;

var srcType = _srcTypes[iinfo].GetItemType();
if (srcType is KeyDataViewType)
return false;
if (_parent._columns[iinfo].Combine)
return false;

var opType = "MurmurHash3";
string murmurOutput = ctx.AddIntermediateVariable(_dstTypes[iinfo], "MurmurOutput");

// Numeric input types are limited to those supported by the Onnxruntime MurmurHash operator, which currently only supports
// uints and ints. Thus, ulongs, longs, doubles and floats are not supported.
// Get zero value indices
if (_srcTypes[iinfo] is KeyDataViewType)
{
var optType2 = "Cast";
castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "CastOutput", true);
isZeroNode = ctx.CreateNode(optType2, srcVariable, castOutput, ctx.GetNodeName(optType2), "");
isZeroNode.AddAttribute("to", NumberDataViewType.Int64.RawType);

var zero = ctx.AddInitializer(0);
var isGreaterThanZeroOutputBool = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "isGreaterThanZeroOutputBool");
optType2 = "Greater";
ctx.CreateNode(optType2, new[] { castOutput, zero }, new[] { isGreaterThanZeroOutputBool }, ctx.GetNodeName(optType2), "");

isGreaterThanZeroOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "isGreaterThanZeroOutput");
optType2 = "Cast";
isZeroNode = ctx.CreateNode(optType2, isGreaterThanZeroOutputBool, isGreaterThanZeroOutput, ctx.GetNodeName(optType2), "");
isZeroNode.AddAttribute("to", NumberDataViewType.Int64.RawType);
}

// Since these numeric types are not supported by Onnxruntime, we cast them to UInt32.
if (srcType == NumberDataViewType.UInt16 || srcType == NumberDataViewType.Int16 ||
srcType == NumberDataViewType.SByte || srcType == NumberDataViewType.Byte ||
srcType == BooleanDataViewType.Instance)
Expand All @@ -1372,15 +1390,9 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, stri
castNode.AddAttribute("to", NumberDataViewType.UInt32.RawType);
murmurNode = ctx.CreateNode(opType, castOutput, murmurOutput, ctx.GetNodeName(opType), "com.microsoft");
}
else if (srcType == NumberDataViewType.UInt32 || srcType == NumberDataViewType.Int32 || srcType == NumberDataViewType.UInt64 ||
srcType == NumberDataViewType.Int64 || srcType == NumberDataViewType.Single || srcType == NumberDataViewType.Double || srcType == TextDataViewType.Instance)

{
murmurNode = ctx.CreateNode(opType, srcVariable, murmurOutput, ctx.GetNodeName(opType), "com.microsoft");
}
else
{
return false;
murmurNode = ctx.CreateNode(opType, srcVariable, murmurOutput, ctx.GetNodeName(opType), "com.microsoft");
}

murmurNode.AddAttribute("positive", 1);
Expand Down Expand Up @@ -1417,10 +1429,17 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, stri
string one = ctx.AddInitializer(1);
ctx.CreateNode(opType, new[] { castOutput, one }, new[] { addOutput }, ctx.GetNodeName(opType), "");

string mulOutput = ctx.AddIntermediateVariable(vectorShape, "MulOutput");
if (_srcTypes[iinfo] is KeyDataViewType)
{
opType = "Mul";
ctx.CreateNode(opType, new[] { isGreaterThanZeroOutput, addOutput }, new[] { mulOutput }, ctx.GetNodeName(opType), "");
}

opType = "Cast";
var castNodeFinal = ctx.CreateNode(opType, addOutput, dstVariable, ctx.GetNodeName(opType), "");
var input = (_srcTypes[iinfo] is KeyDataViewType) ? mulOutput: addOutput;
var castNodeFinal = ctx.CreateNode(opType, input, dstVariable, ctx.GetNodeName(opType), "");
castNodeFinal.AddAttribute("to", _dstTypes[iinfo].GetItemType().RawType);

return true;
}

Expand Down
54 changes: 50 additions & 4 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1195,11 +1195,57 @@ public void OneHotHashEncodingOnnxConversionTest()
Done();
}

// Input row type for MurmurHashKeyTest: a single unsigned integer value that the test
// maps to a key and then hashes.
private class HashData
{
// Raw value for the row; the test pipeline refers to it by the column name "Value".
public uint Value { get; set; }
}

// Verifies ONNX export of hashing over a key-typed column: "Value" is mapped to a key,
// hashed into "ValueHashed", and the exported model's "ValueHashed" output is compared
// against the ML.NET result when the ONNX runtime is available.
[Fact]
public void MurmurHashKeyTest()
{
    var context = new MLContext();

    // Three rows, one of which carries the value zero.
    var rows = new[]
    {
        new HashData { Value = 232 },
        new HashData { Value = 42 },
        new HashData { Value = 0 },
    };
    IDataView input = context.Data.LoadFromEnumerable(rows);

    // Pipeline: value -> key, then key -> hash.
    var hashColumns = new[]
    {
        new HashingEstimator.ColumnOptions("ValueHashed", "Value")
    };
    var pipeline = context.Transforms.Conversion.MapValueToKey("Value")
        .Append(context.Transforms.Conversion.Hash(hashColumns));

    var fitted = pipeline.Fit(input);
    var expected = fitted.Transform(input);
    var exported = context.Model.ConvertToOnnxProtobuf(fitted, input);

    // Persist both the binary model and its text dump.
    var modelPath = GetOutputPath("MurmurHashV2.onnx");
    var textPath = GetOutputPath("MurmurHashV2.txt");
    SaveOnnxModel(exported, modelPath, textPath);

    if (IsOnnxRuntimeSupported())
    {
        // Run the saved ONNX model over the same data the ML.NET pipeline was fitted on.
        var graph = exported.Graph;
        string[] graphInputs = graph.Input.Select(info => info.Name).ToArray();
        string[] graphOutputs = graph.Output.Select(info => info.Name).ToArray();

        var onnxPipeline = context.Transforms.ApplyOnnxModel(graphOutputs, graphInputs, modelPath);
        var onnxFitted = onnxPipeline.Fit(input);
        var actual = onnxFitted.Transform(input);

        CompareSelectedColumns<uint>("ValueHashed", "ValueHashed", expected, actual);
    }

    Done();
}

[Theory]
[CombinatorialData]
// Due to lack of Onnxruntime support, long/ulong, double, floats, and OrderedHashing are not supported.
// An InvalidOperationException stating that the onnx pipeline can't be fully converted is thrown
// when users try to convert the items mentioned above.
public void MurmurHashScalarTest(
[CombinatorialValues(DataKind.SByte, DataKind.Int16, DataKind.Int32, DataKind.Int64, DataKind.Byte,
DataKind.UInt16, DataKind.UInt32, DataKind.UInt64, DataKind.Single, DataKind.Double, DataKind.String, DataKind.Boolean)] DataKind type,
Expand Down Expand Up @@ -1252,7 +1298,7 @@ public void MurmurHashScalarTest(

[Theory]
[CombinatorialData]
// Due to lack of Onnxruntime support, long/ulong, double, floats, and OrderedHashing are not supported.
// Due to lack of Onnxruntime support, OrderedHashing is not supported.
// An InvalidOperationException stating that the onnx pipeline can't be fully converted is thrown
// when users try to convert the items mentioned above.
public void MurmurHashVectorTest(
Expand Down

0 comments on commit d58e8d1

Please sign in to comment.