From 63951b9d464baa567a606ca1d3f84a6919abd2c3 Mon Sep 17 00:00:00 2001 From: George Montanez <38017604+george-microsoft@users.noreply.github.com> Date: Thu, 31 May 2018 12:09:53 -0700 Subject: [PATCH] Changes to use evaluator metrics names in PipelineSweeperSupportedMetrics. Made the private const strings in two classes public. (#276) --- .../EntryPoints/InputBuilder.cs | 199 +++++++++--------- .../Evaluators/ClusteringEvaluator.cs | 6 +- .../Evaluators/RankerEvaluator.cs | 6 +- 3 files changed, 106 insertions(+), 105 deletions(-) diff --git a/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs b/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs index 3ec23a01bf5..939d49893b7 100644 --- a/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs +++ b/src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs @@ -9,6 +9,7 @@ using System.Reflection; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Runtime.Data; using Newtonsoft.Json.Linq; namespace Microsoft.ML.Runtime.EntryPoints.JsonUtils @@ -425,81 +426,81 @@ private static object ParseJsonValue(IExceptionContext ectx, Type type, Attribut { switch (dt) { - case TlcModule.DataKind.Bool: - return value.Value(); - case TlcModule.DataKind.String: - return value.Value(); - case TlcModule.DataKind.Char: - return value.Value(); - case TlcModule.DataKind.Enum: - if (!Enum.IsDefined(type, value.Value())) - throw ectx.Except($"Requested value '{value.Value()}' is not a member of the Enum type '{type.Name}'"); - return Enum.Parse(type, value.Value()); - case TlcModule.DataKind.Float: - if (type == typeof(double)) - return value.Value(); - else if (type == typeof(float)) - return value.Value(); - else - { + case TlcModule.DataKind.Bool: + return value.Value(); + case TlcModule.DataKind.String: + return value.Value(); + case TlcModule.DataKind.Char: + return value.Value(); + case TlcModule.DataKind.Enum: + if (!Enum.IsDefined(type, value.Value())) + throw ectx.Except($"Requested value 
'{value.Value()}' is not a member of the Enum type '{type.Name}'"); + return Enum.Parse(type, value.Value()); + case TlcModule.DataKind.Float: + if (type == typeof(double)) + return value.Value(); + else if (type == typeof(float)) + return value.Value(); + else + { + ectx.Assert(false); + throw ectx.ExceptNotSupp(); + } + case TlcModule.DataKind.Array: + var ja = value as JArray; + ectx.Check(ja != null, "Expected array value"); + Func makeArray = MakeArray; + return Utils.MarshalInvoke(makeArray, type.GetElementType(), ectx, ja, attributes, catalog); + case TlcModule.DataKind.Int: + if (type == typeof(long)) + return value.Value(); + if (type == typeof(int)) + return value.Value(); ectx.Assert(false); throw ectx.ExceptNotSupp(); - } - case TlcModule.DataKind.Array: - var ja = value as JArray; - ectx.Check(ja != null, "Expected array value"); - Func makeArray = MakeArray; - return Utils.MarshalInvoke(makeArray, type.GetElementType(), ectx, ja, attributes, catalog); - case TlcModule.DataKind.Int: - if (type == typeof(long)) - return value.Value(); - if (type == typeof(int)) - return value.Value(); - ectx.Assert(false); - throw ectx.ExceptNotSupp(); - case TlcModule.DataKind.UInt: - if (type == typeof(ulong)) - return value.Value(); - if (type == typeof(uint)) - return value.Value(); - ectx.Assert(false); - throw ectx.ExceptNotSupp(); - case TlcModule.DataKind.Dictionary: - ectx.Check(value is JObject, "Expected object value"); - Func makeDict = MakeDictionary; - return Utils.MarshalInvoke(makeDict, type.GetGenericArguments()[1], ectx, (JObject)value, attributes, catalog); - case TlcModule.DataKind.Component: - var jo = value as JObject; - ectx.Check(jo != null, "Expected object value"); - // REVIEW: consider accepting strings alone. 
- var jName = jo[FieldNames.Name]; - ectx.Check(jName != null, "Field '" + FieldNames.Name + "' is required for component."); - ectx.Check(jName is JValue, "Expected '" + FieldNames.Name + "' field to be a string."); - var name = jName.Value(); - ectx.Check(jo[FieldNames.Settings] == null || jo[FieldNames.Settings] is JObject, - "Expected '" + FieldNames.Settings + "' field to be an object"); - return GetComponentJson(ectx, type, name, jo[FieldNames.Settings] as JObject, catalog); - default: - var settings = value as JObject; - ectx.Check(settings != null, "Expected object value"); - var inputBuilder = new InputBuilder(ectx, type, catalog); - - if (inputBuilder._fields.Length == 0) - throw ectx.Except($"Unsupported input type: {dt}"); - - if (settings != null) - { - foreach (var pair in settings) + case TlcModule.DataKind.UInt: + if (type == typeof(ulong)) + return value.Value(); + if (type == typeof(uint)) + return value.Value(); + ectx.Assert(false); + throw ectx.ExceptNotSupp(); + case TlcModule.DataKind.Dictionary: + ectx.Check(value is JObject, "Expected object value"); + Func makeDict = MakeDictionary; + return Utils.MarshalInvoke(makeDict, type.GetGenericArguments()[1], ectx, (JObject)value, attributes, catalog); + case TlcModule.DataKind.Component: + var jo = value as JObject; + ectx.Check(jo != null, "Expected object value"); + // REVIEW: consider accepting strings alone. 
+ var jName = jo[FieldNames.Name]; + ectx.Check(jName != null, "Field '" + FieldNames.Name + "' is required for component."); + ectx.Check(jName is JValue, "Expected '" + FieldNames.Name + "' field to be a string."); + var name = jName.Value(); + ectx.Check(jo[FieldNames.Settings] == null || jo[FieldNames.Settings] is JObject, + "Expected '" + FieldNames.Settings + "' field to be an object"); + return GetComponentJson(ectx, type, name, jo[FieldNames.Settings] as JObject, catalog); + default: + var settings = value as JObject; + ectx.Check(settings != null, "Expected object value"); + var inputBuilder = new InputBuilder(ectx, type, catalog); + + if (inputBuilder._fields.Length == 0) + throw ectx.Except($"Unsupported input type: {dt}"); + + if (settings != null) { - if (!inputBuilder.TrySetValueJson(pair.Key, pair.Value)) - throw ectx.Except($"Unexpected value for component '{type}', field '{pair.Key}': '{pair.Value}'"); + foreach (var pair in settings) + { + if (!inputBuilder.TrySetValueJson(pair.Key, pair.Value)) + throw ectx.Except($"Unexpected value for component '{type}', field '{pair.Key}': '{pair.Value}'"); + } } - } - var missing = inputBuilder.GetMissingValues().ToArray(); - if (missing.Length > 0) - throw ectx.Except($"The following required inputs were not provided for component '{type}': {string.Join(", ", missing)}"); - return inputBuilder.GetInstance(); + var missing = inputBuilder.GetMissingValues().ToArray(); + if (missing.Length > 0) + throw ectx.Except($"The following required inputs were not provided for component '{type}': {string.Join(", ", missing)}"); + return inputBuilder.GetInstance(); } } catch (FormatException ex) @@ -832,35 +833,35 @@ public static class SweepableDiscreteParam public static class PipelineSweeperSupportedMetrics { public new static string ToString() => "SupportedMetric"; - public const string Auc = "AUC"; - public const string AccuracyMicro = "AccuracyMicro"; - public const string AccuracyMacro = "AccuracyMacro"; - public 
const string F1 = "F1"; - public const string AuPrc = "AUPRC"; - public const string TopKAccuracy = "TopKAccuracy"; - public const string L1 = "L1"; - public const string L2 = "L2"; - public const string Rms = "RMS"; - public const string LossFn = "LossFn"; - public const string RSquared = "RSquared"; - public const string LogLoss = "LogLoss"; - public const string LogLossReduction = "LogLossReduction"; - public const string Ndcg = "NDCG"; - public const string Dcg = "DCG"; - public const string PositivePrecision = "PositivePrecision"; - public const string PositiveRecall = "PositiveRecall"; - public const string NegativePrecision = "NegativePrecision"; - public const string NegativeRecall = "NegativeRecall"; - public const string DrAtK = "DrAtK"; - public const string DrAtPFpr = "DrAtPFpr"; - public const string DrAtNumPos = "DrAtNumPos"; - public const string NumAnomalies = "NumAnomalies"; - public const string ThreshAtK = "ThreshAtK"; - public const string ThreshAtP = "ThreshAtP"; - public const string ThreshAtNumPos = "ThreshAtNumPos"; - public const string Nmi = "NMI"; - public const string AvgMinScore = "AvgMinScore"; - public const string Dbi = "DBI"; + public const string Auc = BinaryClassifierEvaluator.Auc; + public const string AccuracyMicro = Data.MultiClassClassifierEvaluator.AccuracyMicro; + public const string AccuracyMacro = MultiClassClassifierEvaluator.AccuracyMacro; + public const string F1 = BinaryClassifierEvaluator.F1; + public const string AuPrc = BinaryClassifierEvaluator.AuPrc; + public const string TopKAccuracy = MultiClassClassifierEvaluator.TopKAccuracy; + public const string L1 = RegressionLossEvaluatorBase.L1; + public const string L2 = RegressionLossEvaluatorBase.L2; + public const string Rms = RegressionLossEvaluatorBase.Rms; + public const string LossFn = RegressionLossEvaluatorBase.Loss; + public const string RSquared = RegressionLossEvaluatorBase.RSquared; + public const string LogLoss = BinaryClassifierEvaluator.LogLoss; + public 
const string LogLossReduction = BinaryClassifierEvaluator.LogLossReduction; + public const string Ndcg = RankerEvaluator.Ndcg; + public const string Dcg = RankerEvaluator.Dcg; + public const string PositivePrecision = BinaryClassifierEvaluator.PosPrecName; + public const string PositiveRecall = BinaryClassifierEvaluator.PosRecallName; + public const string NegativePrecision = BinaryClassifierEvaluator.NegPrecName; + public const string NegativeRecall = BinaryClassifierEvaluator.NegRecallName; + public const string DrAtK = AnomalyDetectionEvaluator.OverallMetrics.DrAtK; + public const string DrAtPFpr = AnomalyDetectionEvaluator.OverallMetrics.DrAtPFpr; + public const string DrAtNumPos = AnomalyDetectionEvaluator.OverallMetrics.DrAtNumPos; + public const string NumAnomalies = AnomalyDetectionEvaluator.OverallMetrics.NumAnomalies; + public const string ThreshAtK = AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK; + public const string ThreshAtP = AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP; + public const string ThreshAtNumPos = AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos; + public const string Nmi = ClusteringEvaluator.Nmi; + public const string AvgMinScore = ClusteringEvaluator.AvgMinScore; + public const string Dbi = ClusteringEvaluator.Dbi; } } } diff --git a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs index 9a8d0d9fd40..907760649f6 100644 --- a/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs @@ -38,9 +38,9 @@ public sealed class Arguments public const string LoadName = "ClusteringEvaluator"; - private const string Nmi = "NMI"; - private const string AvgMinScore = "AvgMinScore"; - private const string Dbi = "DBI"; + public const string Nmi = "NMI"; + public const string AvgMinScore = "AvgMinScore"; + public const string Dbi = "DBI"; private readonly bool _calculateDbi; diff --git 
a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs index a383f835fdc..ae9c2a85943 100644 --- a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs @@ -43,9 +43,9 @@ public sealed class Arguments public const string LoadName = "RankingEvaluator"; - private const string Ndcg = "NDCG"; - private const string Dcg = "DCG"; - private const string MaxDcg = "MaxDCG"; + public const string Ndcg = "NDCG"; + public const string Dcg = "DCG"; + public const string MaxDcg = "MaxDCG"; /// <summary> /// The ranking evaluator outputs a data view by this name, which contains metrics aggregated per group.