Skip to content

Commit

Permalink
Changes to use evaluator metrics names in PipelineSweeperSupportedMet…
Browse files Browse the repository at this point in the history
…rics. Made the private const strings in two classes public. (dotnet#276)
  • Loading branch information
george-microsoft authored and eerhardt committed Jul 27, 2018
1 parent 3c875a9 commit 63951b9
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 105 deletions.
199 changes: 100 additions & 99 deletions src/Microsoft.ML.Data/EntryPoints/InputBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Reflection;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Data;
using Newtonsoft.Json.Linq;

namespace Microsoft.ML.Runtime.EntryPoints.JsonUtils
Expand Down Expand Up @@ -425,81 +426,81 @@ private static object ParseJsonValue(IExceptionContext ectx, Type type, Attribut
{
switch (dt)
{
case TlcModule.DataKind.Bool:
return value.Value<bool>();
case TlcModule.DataKind.String:
return value.Value<string>();
case TlcModule.DataKind.Char:
return value.Value<char>();
case TlcModule.DataKind.Enum:
if (!Enum.IsDefined(type, value.Value<string>()))
throw ectx.Except($"Requested value '{value.Value<string>()}' is not a member of the Enum type '{type.Name}'");
return Enum.Parse(type, value.Value<string>());
case TlcModule.DataKind.Float:
if (type == typeof(double))
return value.Value<double>();
else if (type == typeof(float))
return value.Value<float>();
else
{
case TlcModule.DataKind.Bool:
return value.Value<bool>();
case TlcModule.DataKind.String:
return value.Value<string>();
case TlcModule.DataKind.Char:
return value.Value<char>();
case TlcModule.DataKind.Enum:
if (!Enum.IsDefined(type, value.Value<string>()))
throw ectx.Except($"Requested value '{value.Value<string>()}' is not a member of the Enum type '{type.Name}'");
return Enum.Parse(type, value.Value<string>());
case TlcModule.DataKind.Float:
if (type == typeof(double))
return value.Value<double>();
else if (type == typeof(float))
return value.Value<float>();
else
{
ectx.Assert(false);
throw ectx.ExceptNotSupp();
}
case TlcModule.DataKind.Array:
var ja = value as JArray;
ectx.Check(ja != null, "Expected array value");
Func<IExceptionContext, JArray, Attributes, ModuleCatalog, object> makeArray = MakeArray<int>;
return Utils.MarshalInvoke(makeArray, type.GetElementType(), ectx, ja, attributes, catalog);
case TlcModule.DataKind.Int:
if (type == typeof(long))
return value.Value<long>();
if (type == typeof(int))
return value.Value<int>();
ectx.Assert(false);
throw ectx.ExceptNotSupp();
}
case TlcModule.DataKind.Array:
var ja = value as JArray;
ectx.Check(ja != null, "Expected array value");
Func<IExceptionContext, JArray, Attributes, ModuleCatalog, object> makeArray = MakeArray<int>;
return Utils.MarshalInvoke(makeArray, type.GetElementType(), ectx, ja, attributes, catalog);
case TlcModule.DataKind.Int:
if (type == typeof(long))
return value.Value<long>();
if (type == typeof(int))
return value.Value<int>();
ectx.Assert(false);
throw ectx.ExceptNotSupp();
case TlcModule.DataKind.UInt:
if (type == typeof(ulong))
return value.Value<ulong>();
if (type == typeof(uint))
return value.Value<uint>();
ectx.Assert(false);
throw ectx.ExceptNotSupp();
case TlcModule.DataKind.Dictionary:
ectx.Check(value is JObject, "Expected object value");
Func<IExceptionContext, JObject, Attributes, ModuleCatalog, object> makeDict = MakeDictionary<int>;
return Utils.MarshalInvoke(makeDict, type.GetGenericArguments()[1], ectx, (JObject)value, attributes, catalog);
case TlcModule.DataKind.Component:
var jo = value as JObject;
ectx.Check(jo != null, "Expected object value");
// REVIEW: consider accepting strings alone.
var jName = jo[FieldNames.Name];
ectx.Check(jName != null, "Field '" + FieldNames.Name + "' is required for component.");
ectx.Check(jName is JValue, "Expected '" + FieldNames.Name + "' field to be a string.");
var name = jName.Value<string>();
ectx.Check(jo[FieldNames.Settings] == null || jo[FieldNames.Settings] is JObject,
"Expected '" + FieldNames.Settings + "' field to be an object");
return GetComponentJson(ectx, type, name, jo[FieldNames.Settings] as JObject, catalog);
default:
var settings = value as JObject;
ectx.Check(settings != null, "Expected object value");
var inputBuilder = new InputBuilder(ectx, type, catalog);

if (inputBuilder._fields.Length == 0)
throw ectx.Except($"Unsupported input type: {dt}");

if (settings != null)
{
foreach (var pair in settings)
case TlcModule.DataKind.UInt:
if (type == typeof(ulong))
return value.Value<ulong>();
if (type == typeof(uint))
return value.Value<uint>();
ectx.Assert(false);
throw ectx.ExceptNotSupp();
case TlcModule.DataKind.Dictionary:
ectx.Check(value is JObject, "Expected object value");
Func<IExceptionContext, JObject, Attributes, ModuleCatalog, object> makeDict = MakeDictionary<int>;
return Utils.MarshalInvoke(makeDict, type.GetGenericArguments()[1], ectx, (JObject)value, attributes, catalog);
case TlcModule.DataKind.Component:
var jo = value as JObject;
ectx.Check(jo != null, "Expected object value");
// REVIEW: consider accepting strings alone.
var jName = jo[FieldNames.Name];
ectx.Check(jName != null, "Field '" + FieldNames.Name + "' is required for component.");
ectx.Check(jName is JValue, "Expected '" + FieldNames.Name + "' field to be a string.");
var name = jName.Value<string>();
ectx.Check(jo[FieldNames.Settings] == null || jo[FieldNames.Settings] is JObject,
"Expected '" + FieldNames.Settings + "' field to be an object");
return GetComponentJson(ectx, type, name, jo[FieldNames.Settings] as JObject, catalog);
default:
var settings = value as JObject;
ectx.Check(settings != null, "Expected object value");
var inputBuilder = new InputBuilder(ectx, type, catalog);

if (inputBuilder._fields.Length == 0)
throw ectx.Except($"Unsupported input type: {dt}");

if (settings != null)
{
if (!inputBuilder.TrySetValueJson(pair.Key, pair.Value))
throw ectx.Except($"Unexpected value for component '{type}', field '{pair.Key}': '{pair.Value}'");
foreach (var pair in settings)
{
if (!inputBuilder.TrySetValueJson(pair.Key, pair.Value))
throw ectx.Except($"Unexpected value for component '{type}', field '{pair.Key}': '{pair.Value}'");
}
}
}

var missing = inputBuilder.GetMissingValues().ToArray();
if (missing.Length > 0)
throw ectx.Except($"The following required inputs were not provided for component '{type}': {string.Join(", ", missing)}");
return inputBuilder.GetInstance();
var missing = inputBuilder.GetMissingValues().ToArray();
if (missing.Length > 0)
throw ectx.Except($"The following required inputs were not provided for component '{type}': {string.Join(", ", missing)}");
return inputBuilder.GetInstance();
}
}
catch (FormatException ex)
Expand Down Expand Up @@ -832,35 +833,35 @@ public static class SweepableDiscreteParam
public static class PipelineSweeperSupportedMetrics
{
public new static string ToString() => "SupportedMetric";
public const string Auc = "AUC";
public const string AccuracyMicro = "AccuracyMicro";
public const string AccuracyMacro = "AccuracyMacro";
public const string F1 = "F1";
public const string AuPrc = "AUPRC";
public const string TopKAccuracy = "TopKAccuracy";
public const string L1 = "L1";
public const string L2 = "L2";
public const string Rms = "RMS";
public const string LossFn = "LossFn";
public const string RSquared = "RSquared";
public const string LogLoss = "LogLoss";
public const string LogLossReduction = "LogLossReduction";
public const string Ndcg = "NDCG";
public const string Dcg = "DCG";
public const string PositivePrecision = "PositivePrecision";
public const string PositiveRecall = "PositiveRecall";
public const string NegativePrecision = "NegativePrecision";
public const string NegativeRecall = "NegativeRecall";
public const string DrAtK = "DrAtK";
public const string DrAtPFpr = "DrAtPFpr";
public const string DrAtNumPos = "DrAtNumPos";
public const string NumAnomalies = "NumAnomalies";
public const string ThreshAtK = "ThreshAtK";
public const string ThreshAtP = "ThreshAtP";
public const string ThreshAtNumPos = "ThreshAtNumPos";
public const string Nmi = "NMI";
public const string AvgMinScore = "AvgMinScore";
public const string Dbi = "DBI";
public const string Auc = BinaryClassifierEvaluator.Auc;
public const string AccuracyMicro = Data.MultiClassClassifierEvaluator.AccuracyMicro;
public const string AccuracyMacro = MultiClassClassifierEvaluator.AccuracyMacro;
public const string F1 = BinaryClassifierEvaluator.F1;
public const string AuPrc = BinaryClassifierEvaluator.AuPrc;
public const string TopKAccuracy = MultiClassClassifierEvaluator.TopKAccuracy;
public const string L1 = RegressionLossEvaluatorBase<MultiOutputRegressionEvaluator.Aggregator>.L1;
public const string L2 = RegressionLossEvaluatorBase<MultiOutputRegressionEvaluator.Aggregator>.L2;
public const string Rms = RegressionLossEvaluatorBase<MultiOutputRegressionEvaluator.Aggregator>.Rms;
public const string LossFn = RegressionLossEvaluatorBase<MultiOutputRegressionEvaluator.Aggregator>.Loss;
public const string RSquared = RegressionLossEvaluatorBase<MultiOutputRegressionEvaluator.Aggregator>.RSquared;
public const string LogLoss = BinaryClassifierEvaluator.LogLoss;
public const string LogLossReduction = BinaryClassifierEvaluator.LogLossReduction;
public const string Ndcg = RankerEvaluator.Ndcg;
public const string Dcg = RankerEvaluator.Dcg;
public const string PositivePrecision = BinaryClassifierEvaluator.PosPrecName;
public const string PositiveRecall = BinaryClassifierEvaluator.PosRecallName;
public const string NegativePrecision = BinaryClassifierEvaluator.NegPrecName;
public const string NegativeRecall = BinaryClassifierEvaluator.NegRecallName;
public const string DrAtK = AnomalyDetectionEvaluator.OverallMetrics.DrAtK;
public const string DrAtPFpr = AnomalyDetectionEvaluator.OverallMetrics.DrAtPFpr;
public const string DrAtNumPos = AnomalyDetectionEvaluator.OverallMetrics.DrAtNumPos;
public const string NumAnomalies = AnomalyDetectionEvaluator.OverallMetrics.NumAnomalies;
public const string ThreshAtK = AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK;
public const string ThreshAtP = AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP;
public const string ThreshAtNumPos = AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos;
public const string Nmi = ClusteringEvaluator.Nmi;
public const string AvgMinScore = ClusteringEvaluator.AvgMinScore;
public const string Dbi = ClusteringEvaluator.Dbi;
}
}
}
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Data/Evaluators/ClusteringEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ public sealed class Arguments

public const string LoadName = "ClusteringEvaluator";

private const string Nmi = "NMI";
private const string AvgMinScore = "AvgMinScore";
private const string Dbi = "DBI";
public const string Nmi = "NMI";
public const string AvgMinScore = "AvgMinScore";
public const string Dbi = "DBI";

private readonly bool _calculateDbi;

Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ public sealed class Arguments

public const string LoadName = "RankingEvaluator";

private const string Ndcg = "NDCG";
private const string Dcg = "DCG";
private const string MaxDcg = "MaxDCG";
public const string Ndcg = "NDCG";
public const string Dcg = "DCG";
public const string MaxDcg = "MaxDCG";

/// <summary>
/// The ranking evaluator outputs a data view by this name, which contains metrics aggregated per group.
Expand Down

0 comments on commit 63951b9

Please sign in to comment.