From d58e8d11bd5c2dfac70674f3a2bfa90536c8b3d4 Mon Sep 17 00:00:00 2001 From: Keren Fuentes Date: Thu, 21 May 2020 11:34:09 -0700 Subject: [PATCH] Adding support for MurmurHash KeyDataTypes (#5138) * merging * removed some outdated comments * update --- src/Microsoft.ML.Data/Transforms/Hashing.cs | 45 +++++++++++----- test/Microsoft.ML.Tests/OnnxConversionTest.cs | 54 +++++++++++++++++-- 2 files changed, 82 insertions(+), 17 deletions(-) diff --git a/src/Microsoft.ML.Data/Transforms/Hashing.cs b/src/Microsoft.ML.Data/Transforms/Hashing.cs index 9ec6cd9bee..59c4db4881 100644 --- a/src/Microsoft.ML.Data/Transforms/Hashing.cs +++ b/src/Microsoft.ML.Data/Transforms/Hashing.cs @@ -1349,20 +1349,38 @@ private void AddMetaKeyValues(int i, DataViewSchema.Annotations.Builder builder) private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, string dstVariable) { string castOutput; + string isGreaterThanZeroOutput = ""; OnnxNode castNode; OnnxNode murmurNode; + OnnxNode isZeroNode; var srcType = _srcTypes[iinfo].GetItemType(); - if (srcType is KeyDataViewType) - return false; if (_parent._columns[iinfo].Combine) return false; var opType = "MurmurHash3"; string murmurOutput = ctx.AddIntermediateVariable(_dstTypes[iinfo], "MurmurOutput"); - // Numeric input types are limited to those supported by the Onnxruntime MurmurHash operator, which currently only supports - // uints and ints. Thus, ulongs, longs, doubles and floats are not supported. + // Get zero value indeces + if (_srcTypes[iinfo] is KeyDataViewType) + { + var optType2 = "Cast"; + castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "CastOutput", true); + isZeroNode = ctx.CreateNode(optType2, srcVariable, castOutput, ctx.GetNodeName(optType2), ""); + isZeroNode.AddAttribute("to", NumberDataViewType.Int64.RawType); + + var zero = ctx.AddInitializer(0); + var isGreaterThanZeroOutputBool = ctx.AddIntermediateVariable(BooleanDataViewType.Instance, "isGreaterThanZeroOutputBool"); + optType2 = "Greater"; + ctx.CreateNode(optType2, new[] { castOutput, zero }, new[] { isGreaterThanZeroOutputBool }, ctx.GetNodeName(optType2), ""); + + isGreaterThanZeroOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "isGreaterThanZeroOutput"); + optType2 = "Cast"; + isZeroNode = ctx.CreateNode(optType2, isGreaterThanZeroOutputBool, isGreaterThanZeroOutput, ctx.GetNodeName(optType2), ""); + isZeroNode.AddAttribute("to", NumberDataViewType.Int64.RawType); + } + + // Since these numeric types are not supported by Onnxruntime, we cast them to UInt32. if (srcType == NumberDataViewType.UInt16 || srcType == NumberDataViewType.Int16 || srcType == NumberDataViewType.SByte || srcType == NumberDataViewType.Byte || srcType == BooleanDataViewType.Instance) @@ -1372,15 +1390,9 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, stri castNode.AddAttribute("to", NumberDataViewType.UInt32.RawType); murmurNode = ctx.CreateNode(opType, castOutput, murmurOutput, ctx.GetNodeName(opType), "com.microsoft"); } - else if (srcType == NumberDataViewType.UInt32 || srcType == NumberDataViewType.Int32 || srcType == NumberDataViewType.UInt64 || - srcType == NumberDataViewType.Int64 || srcType == NumberDataViewType.Single || srcType == NumberDataViewType.Double || srcType == TextDataViewType.Instance) - - { - murmurNode = ctx.CreateNode(opType, srcVariable, murmurOutput, ctx.GetNodeName(opType), "com.microsoft"); - } else { - return false; + murmurNode = ctx.CreateNode(opType, srcVariable, murmurOutput, ctx.GetNodeName(opType), "com.microsoft"); } murmurNode.AddAttribute("positive", 1); @@ -1417,10 +1429,17 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, stri string one = ctx.AddInitializer(1); ctx.CreateNode(opType, new[] { castOutput, one }, new[] { addOutput }, ctx.GetNodeName(opType), ""); + string mulOutput = ctx.AddIntermediateVariable(vectorShape, "MulOutput"); + if (_srcTypes[iinfo] is KeyDataViewType) + { + opType = "Mul"; + ctx.CreateNode(opType, new[] { isGreaterThanZeroOutput, addOutput }, new[] { mulOutput }, ctx.GetNodeName(opType), ""); + } + opType = "Cast"; - var castNodeFinal = ctx.CreateNode(opType, addOutput, dstVariable, ctx.GetNodeName(opType), ""); + var input = (_srcTypes[iinfo] is KeyDataViewType) ? mulOutput: addOutput; + var castNodeFinal = ctx.CreateNode(opType, input, dstVariable, ctx.GetNodeName(opType), ""); castNodeFinal.AddAttribute("to", _dstTypes[iinfo].GetItemType().RawType); - return true; } diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 9f2482961f..cfabd9d1a0 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -1195,11 +1195,57 @@ public void OneHotHashEncodingOnnxConversionTest() Done(); } + private class HashData + { + public uint Value { get; set; } + } + + [Fact] + public void MurmurHashKeyTest() + { + var mlContext = new MLContext(); + + var samples = new[] + { + new HashData {Value = 232}, + new HashData {Value = 42}, + new HashData {Value = 0}, + }; + + IDataView data = mlContext.Data.LoadFromEnumerable(samples); + + var hashEstimator = mlContext.Transforms.Conversion.MapValueToKey("Value").Append(mlContext.Transforms.Conversion.Hash(new[] + { + new HashingEstimator.ColumnOptions( + "ValueHashed", + "Value") + })); + var model = hashEstimator.Fit(data); + var transformedData = model.Transform(data); + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, data); + + var onnxFileName = "MurmurHashV2.onnx"; + var onnxTextName = "MurmurHashV2.txt"; + var onnxModelPath = GetOutputPath(onnxFileName); + var onnxTextPath = GetOutputPath(onnxTextName); + + SaveOnnxModel(onnxModel, onnxModelPath, onnxTextPath); + + if (IsOnnxRuntimeSupported()) + { + // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. + string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); + var onnxTransformer = onnxEstimator.Fit(data); + var onnxResult = onnxTransformer.Transform(data); + CompareSelectedColumns("ValueHashed", "ValueHashed", transformedData, onnxResult); + } + Done(); + } + [Theory] [CombinatorialData] - // Due to lack of Onnxruntime support, long/ulong, double, floats, and OrderedHashing are not supported. - // An InvalidOperationException stating that the onnx pipeline can't be fully converted is thrown - // when users try to convert the items mentioned above. public void MurmurHashScalarTest( [CombinatorialValues(DataKind.SByte, DataKind.Int16, DataKind.Int32, DataKind.Int64, DataKind.Byte, DataKind.UInt16, DataKind.UInt32, DataKind.UInt64, DataKind.Single, DataKind.Double, DataKind.String, DataKind.Boolean)] DataKind type, @@ -1252,7 +1298,7 @@ public void MurmurHashScalarTest( [Theory] [CombinatorialData] - // Due to lack of Onnxruntime support, long/ulong, double, floats, and OrderedHashing are not supported. + // Due to lack of Onnxruntime support, OrderedHashing is not supported. // An InvalidOperationException stating that the onnx pipeline can't be fully converted is thrown // when users try to convert the items mentioned above. public void MurmurHashVectorTest(