Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up the SchemaDefinition class #2995

Merged
merged 7 commits into from
Mar 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 26 additions & 36 deletions src/Microsoft.ML.Data/Data/SchemaDefinition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -203,15 +203,14 @@ public sealed class SchemaDefinition : List<SchemaDefinition.Column>
/// </summary>
public sealed class Column
{
private readonly Dictionary<string, AnnotationInfo> _annotations;
internal Dictionary<string, AnnotationInfo> Annotations { get { return _annotations; } }
internal Dictionary<string, AnnotationInfo> AnnotationInfos { get; }

/// <summary>
/// The name of the member the column is taken from. The API
/// requires this to not be null, and a valid name of a member of
/// the type for which we are creating a schema.
/// </summary>
public string MemberName { get; set; }
public string MemberName { get; }
/// <summary>
/// The name of the column that's created in the data view. If this
/// is null, the API uses the <see cref="MemberName"/>.
Expand All @@ -223,34 +222,21 @@ public sealed class Column
/// </summary>
public DataViewType ColumnType { get; set; }

/// <summary>
/// Whether the column is a computed type.
/// </summary>
public bool IsComputed { get { return Generator != null; } }

/// <summary>
/// The generator function, if the column is computed.
/// </summary>
public Delegate Generator { get; set; }
internal Delegate Generator { get; set; }

public Type ReturnType => Generator?.GetMethodInfo().GetParameters().LastOrDefault().ParameterType.GetElementType();
internal Type ReturnType => Generator?.GetMethodInfo().GetParameters().LastOrDefault().ParameterType.GetElementType();

public Column(IExceptionContext ectx, string memberName, DataViewType columnType,
string columnName = null, IEnumerable<AnnotationInfo> annotationInfos = null, Delegate generator = null)
internal Column(string memberName, DataViewType columnType,
string columnName = null)
{
ectx.CheckNonEmpty(memberName, nameof(memberName));
Contracts.CheckNonEmpty(memberName, nameof(memberName));
MemberName = memberName;
ColumnName = columnName ?? memberName;
ColumnType = columnType;
Generator = generator;
_annotations = annotationInfos != null ?
annotationInfos.ToDictionary(m => m.Kind, m => m)
: new Dictionary<string, AnnotationInfo>();
}

public Column()
{
_annotations = _annotations ?? new Dictionary<string, AnnotationInfo>();
AnnotationInfos = new Dictionary<string, AnnotationInfo>();
}

/// <summary>
Expand All @@ -262,38 +248,42 @@ public Column()
/// <param name="kind">The string identifier of the annotation.</param>
/// <param name="value">Value of annotation.</param>
/// <param name="annotationType">Type of value.</param>
public void AddAnnotation<T>(string kind, T value, DataViewType annotationType = null)
public void AddAnnotation<T>(string kind, T value, DataViewType annotationType)
{
if (_annotations.ContainsKey(kind))
Contracts.CheckValue(kind, nameof(kind));
Contracts.CheckValue(annotationType, nameof(annotationType));

if (AnnotationInfos.ContainsKey(kind))
throw Contracts.Except("Column already contains an annotation of this kind.");
_annotations[kind] = new AnnotationInfo<T>(kind, value, annotationType);
AnnotationInfos[kind] = new AnnotationInfo<T>(kind, value, annotationType);
}

/// <summary>
/// Remove annotation from the column if it exists.
/// </summary>
/// <param name="kind">The string identifier of the annotation.</param>
public void RemoveAnnotation(string kind)
internal void AddAnnotation(string kind, AnnotationInfo info)
{
if (_annotations.ContainsKey(kind))
_annotations.Remove(kind);
throw Contracts.Except("Column does not contain an annotation of kind: " + kind);
AnnotationInfos[kind] = info;
}

/// <summary>
/// Returns the kind and type of each annotation associated with this column.
/// </summary>
/// <returns>A dictionary with the kind of the annotation as the key, and the
/// annotation type as the associated value.</returns>
public IEnumerable<KeyValuePair<string, DataViewType>> GetAnnotationTypes
public DataViewSchema.Annotations Annotations
{
get
{
return Annotations.Select(x => new KeyValuePair<string, DataViewType>(x.Key, x.Value.AnnotationType));
var builder = new DataViewSchema.Annotations.Builder();
foreach (var kvp in AnnotationInfos)
builder.Add(kvp.Key, kvp.Value.AnnotationType, kvp.Value.GetGetterDelegate());
return builder.ToAnnotations();
}
}
}

private SchemaDefinition()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SchemaDefinition [](start = 16, length = 16)

So, why is this private? I'm thinking about how I'd like to use it. I have my class, I create a new schema definition (but empty), then I populate the mapping. Do I have any other way to create an empty one of these guys?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It appears that not only is there no way to create an empty one, but there is no longer a way to add a new column to it, since the column constructor is now internal.

Are we sure there are no scenarios that need to do this?

{
}

/// <summary>
/// Get or set the column definition by column name.
/// If there's no such column:
Expand Down Expand Up @@ -430,7 +420,7 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc
else
columnType = itemType;

cols.Add(new Column() { MemberName = memberInfo.Name, ColumnName = name, ColumnType = columnType });
cols.Add(new Column(memberInfo.Name, columnType, name));
}
return cols;
}
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ internal static SchemaDefinition GetSchemaDefinition<TRow>(IHostEnvironment env,
foreach (var annotation in annotations.Schema)
{
var info = Utils.MarshalInvoke(GetAnnotationInfo<int>, annotation.Type.RawType, annotation.Name, annotations);
schemaDefinitionCol.Annotations.Add(annotation.Name, info);
schemaDefinitionCol.AddAnnotation(annotation.Name , info);
}
}
}
Expand Down Expand Up @@ -797,7 +797,7 @@ internal static DataViewSchema.DetachedColumn[] GetSchemaColumns(InternalSchemaD
/// <summary>
/// A single instance of annotation information, associated with a column.
/// </summary>
public abstract partial class AnnotationInfo
internal abstract partial class AnnotationInfo
{
/// <summary>
/// The type of the annotation.
Expand Down Expand Up @@ -826,7 +826,7 @@ private protected AnnotationInfo(string kind, DataViewType annotationType)
/// Strongly-typed version of <see cref="AnnotationInfo"/>, that contains the actual value of the annotation.
/// </summary>
/// <typeparam name="T">Type of the annotation value.</typeparam>
public sealed class AnnotationInfo<T> : AnnotationInfo
internal sealed class AnnotationInfo<T> : AnnotationInfo
{
public readonly T Value;

Expand Down
13 changes: 6 additions & 7 deletions src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ public class Column
public readonly DataViewType ColumnType;
public readonly bool IsComputed;
public readonly Delegate Generator;
private readonly Dictionary<string, AnnotationInfo> _annotations;
public Dictionary<string, AnnotationInfo> Annotations { get { return _annotations; } }
public Dictionary<string, AnnotationInfo> Annotations { get; }
public Type ComputedReturnType { get { return ReturnParameterInfo.ParameterType.GetElementType(); } }
public Type FieldOrPropertyType => (MemberInfo is FieldInfo) ? (MemberInfo as FieldInfo).FieldType : (MemberInfo as PropertyInfo).PropertyType;
public Type OutputType => IsComputed ? ComputedReturnType : FieldOrPropertyType;
Expand Down Expand Up @@ -74,7 +73,7 @@ private Column(string columnName, DataViewType columnType, MemberInfo memberInfo
ColumnType = columnType;
IsComputed = generator != null;
Generator = generator;
_annotations = metadataInfos == null ? new Dictionary<string, AnnotationInfo>()
Annotations = metadataInfos == null ? new Dictionary<string, AnnotationInfo>()
: metadataInfos.ToDictionary(entry => entry.Key, entry => entry.Value);

AssertRep();
Expand Down Expand Up @@ -218,7 +217,7 @@ public static InternalSchemaDefinition Create(Type userType, SchemaDefinition us
Type dataItemType;
MemberInfo memberInfo = null;

if (!col.IsComputed)
if (col.Generator == null)
{
memberInfo = userType.GetField(col.MemberName);

Expand Down Expand Up @@ -277,9 +276,9 @@ public static InternalSchemaDefinition Create(Type userType, SchemaDefinition us
colType = col.ColumnType;
}

dstCols[i] = col.IsComputed ?
new Column(colName, colType, col.Generator, col.Annotations)
: new Column(colName, colType, memberInfo, col.Annotations);
dstCols[i] = col.Generator != null ?
new Column(colName, colType, col.Generator, col.AnnotationInfos)
: new Column(colName, colType, memberInfo, col.AnnotationInfos);

}
return new InternalSchemaDefinition(dstCols);
Expand Down
114 changes: 114 additions & 0 deletions test/Microsoft.ML.Functional.Tests/SchemaDefinitionTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.RunTests;
using Microsoft.ML.TestFramework;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.ML.Functional.Tests
{
public class SchemaDefinitionTests : BaseTestClass
{
private MLContext _ml;

/// <summary>
/// Creates the test class, forwarding <paramref name="output"/> to the base class constructor.
/// </summary>
/// <param name="output">The test output helper handed to the base class.</param>
public SchemaDefinitionTests(ITestOutputHelper output) : base(output)
{
}

protected override void Initialize()
{
base.Initialize();

_ml = new MLContext(42);
_ml.AddStandardComponents();
Copy link
Member

@eerhardt eerhardt Mar 18, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need to call AddStandardComponents? That should only be necessary when you are doing things like using the MAML syntax. When you are strictly using the API, it shouldn't be necessary.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this line back because of issue #2996.


In reply to: 266474155 [](ancestors = 266474155)

}

/// <summary>
/// Checks that a <see cref="SchemaDefinition"/> created from <see cref="OutputData"/> can have the
/// type of its <c>Features</c> column overridden per model, so the same output class can be used to
/// create prediction engines for two models fitted with different <c>maximumNumberOfKeys</c> values.
/// </summary>
[Fact]
public void SchemaDefinitionForPredictionEngine()
{
    var fileName = GetDataPath(TestDatasets.adult.trainFilename);
    var source = new MultiFileSource(fileName);
    var loader = _ml.Data.CreateTextLoader(new TextLoader.Options(), source);
    var data = loader.Load(source);

    // Two pipelines that differ only in maximumNumberOfKeys (3 vs. 4).
    var pipeline1 = _ml.Transforms.Categorical.OneHotEncoding("Cat", "Workclass", maximumNumberOfKeys: 3)
        .Append(_ml.Transforms.Concatenate("Features", "Cat", "NumericFeatures"));
    var model1 = pipeline1.Fit(data);

    var pipeline2 = _ml.Transforms.Categorical.OneHotEncoding("Cat", "Workclass", maximumNumberOfKeys: 4)
        .Append(_ml.Transforms.Concatenate("Features", "Cat", "NumericFeatures"));
    var model2 = pipeline2.Fit(data);

    // Each engine gets its own schema definition whose Features type is taken from its model's output schema.
    var outputSchemaDefinition = SchemaDefinition.Create(typeof(OutputData));
    outputSchemaDefinition["Features"].ColumnType = model1.GetOutputSchema(data.Schema)["Features"].Type;
    var engine1 = _ml.Model.CreatePredictionEngine<InputData, OutputData>(model1, outputSchemaDefinition: outputSchemaDefinition);

    outputSchemaDefinition = SchemaDefinition.Create(typeof(OutputData));
    outputSchemaDefinition["Features"].ColumnType = model2.GetOutputSchema(data.Schema)["Features"].Type;
    var engine2 = _ml.Model.CreatePredictionEngine<InputData, OutputData>(model2, outputSchemaDefinition: outputSchemaDefinition);

    // Assert the column type explicitly instead of dereferencing an unchecked 'as' cast, so a wrong
    // type surfaces as an assertion failure rather than a NullReferenceException.
    var prediction = engine1.Predict(new InputData() { Workclass = "Self-emp-not-inc", NumericFeatures = new float[6] });
    var featuresType1 = Assert.IsAssignableFrom<VectorType>(engine1.OutputSchema["Features"].Type);
    Assert.Equal(featuresType1.Size, prediction.Features.Length);
    Assert.True(prediction.Features.All(x => x == 0));

    prediction = engine2.Predict(new InputData() { Workclass = "Self-emp-not-inc", NumericFeatures = new float[6] });
    var featuresType2 = Assert.IsAssignableFrom<VectorType>(engine2.OutputSchema["Features"].Type);
    Assert.Equal(featuresType2.Size, prediction.Features.Length);
    // Every slot must be 0, except that slot 3 is also allowed to be 1.
    Assert.True(prediction.Features.Select((x, i) => (i == 3 && x == 1) || x == 0).All(b => b));
}

/// <summary>
/// Checks that input and output <see cref="SchemaDefinition"/>s can be supplied to a custom mapping
/// whose source and destination are both <see cref="OutputData"/>, with the output <c>Features</c>
/// vector declared twice as long as the input one.
/// </summary>
[Fact]
public void SchemaDefinitionForCustomMapping()
{
    var fileName = GetDataPath(TestDatasets.adult.trainFilename);
    var data = new MultiFileSource(fileName);
    var loader = _ml.Data.CreateTextLoader(new TextLoader.Options(), data);
    var pipeline = _ml.Transforms.Categorical.OneHotEncoding("Categories")
        .Append(_ml.Transforms.Categorical.OneHotEncoding("Workclass"))
        .Append(_ml.Transforms.Concatenate("Features", "NumericFeatures", "Categories", "Workclass"))
        .Append(_ml.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("Features"));
    var model = pipeline.Fit(loader.Load(data));
    var schema = model.GetOutputSchema(loader.GetOutputSchema());

    // Assert the column type once instead of dereferencing unchecked 'as' casts, so a wrong type
    // surfaces as an assertion failure rather than a NullReferenceException.
    var featuresType = Assert.IsAssignableFrom<VectorType>(schema["Features"].Type);

    var inputSchemaDefinition = SchemaDefinition.Create(typeof(OutputData));
    inputSchemaDefinition["Features"].ColumnType = featuresType;
    var outputSchemaDefinition = SchemaDefinition.Create(typeof(OutputData));
    outputSchemaDefinition["Features"].ColumnType = new VectorType(NumberDataViewType.Single, featuresType.Size * 2);

    // The mapping writes each source value followed by its natural logarithm, doubling the vector length.
    var custom = _ml.Transforms.CustomMapping(
        (OutputData src, OutputData dst) =>
        {
            dst.Features = new float[src.Features.Length * 2];
            for (int i = 0; i < src.Features.Length; i++)
            {
                dst.Features[2 * i] = src.Features[i];
                dst.Features[2 * i + 1] = (float)Math.Log(src.Features[i]);
            }
        }, null, inputSchemaDefinition, outputSchemaDefinition);

    // The upcast to ITransformer keeps the appended chain assignable back to 'model'.
    model = model.Append((ITransformer)custom.Fit(model.Transform(loader.Load(data))));
    schema = model.GetOutputSchema(loader.GetOutputSchema());

    // NOTE(review): 168 is the expected doubled size for this dataset (presumably 2 x 84 selected slots) - confirm if the data or pipeline changes.
    Assert.Equal(168, Assert.IsAssignableFrom<VectorType>(schema["Features"].Type).Size);
}

/// <summary>
/// Input row type used by the tests in this class. Every member carries a <c>LoadColumn</c>
/// attribute, and <see cref="NumericFeatures"/> is additionally marked <c>VectorType(6)</c>.
/// </summary>
private sealed class InputData
{
    [LoadColumn(0)] public float Label { get; set; }

    [LoadColumn(1)] public string Workclass { get; set; }

    [LoadColumn(2, 8)] public string[] Categories { get; set; }

    [LoadColumn(9, 14), VectorType(6)] public float[] NumericFeatures { get; set; }
}

/// <summary>
/// Row type with a scalar <see cref="Label"/> and a <see cref="Features"/> array; the tests in this
/// class build their <see cref="SchemaDefinition"/> instances from it.
/// </summary>
private sealed class OutputData
{
public float Label { get; set; }
public float[] Features { get; set; }
}
}
}
24 changes: 12 additions & 12 deletions test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -200,32 +200,32 @@ public void MetadataSupportInDataViewConstruction()

// Create Metadata.
var kindFloat = "Testing float as metadata.";
var valueFloat = 10;
float valueFloat = 10;
var coltypeFloat = NumberDataViewType.Single;
var kindString = "Testing string as metadata.";
var valueString = "Strings have value.";
var coltypeString = TextDataViewType.Instance;
var kindStringArray = "Testing string array as metadata.";
var valueStringArray = "I really have no idea what these features entail.".Split(' ');
var coltypeStringArray = new VectorType(coltypeString, valueStringArray.Length);
var kindFloatArray = "Testing float array as metadata.";
var valueFloatArray = new float[] { 1, 17, 7, 19, 25, 0 };
var coltypeFloatArray = new VectorType(coltypeFloat, valueFloatArray.Length);
var kindVBuffer = "Testing VBuffer as metadata.";
var valueVBuffer = new VBuffer<float>(4, new float[] { 4, 6, 89, 5 });

var metaFloat = new AnnotationInfo<float>(kindFloat, valueFloat, coltypeFloat);
var metaString = new AnnotationInfo<string>(kindString, valueString);
var coltypeVBuffer = new VectorType(coltypeFloat, valueVBuffer.Length);

// Add Metadata.
var labelColumn = autoSchema[0];
var labelColumnWithMetadata = new SchemaDefinition.Column(mlContext, labelColumn.MemberName, labelColumn.ColumnType,
annotationInfos: new AnnotationInfo[] { metaFloat, metaString });
labelColumn.AddAnnotation(kindFloat, valueFloat, coltypeFloat);
labelColumn.AddAnnotation(kindString, valueString, coltypeString);

var featureColumnWithMetadata = autoSchema[1];
featureColumnWithMetadata.AddAnnotation(kindStringArray, valueStringArray);
featureColumnWithMetadata.AddAnnotation(kindFloatArray, valueFloatArray);
featureColumnWithMetadata.AddAnnotation(kindVBuffer, valueVBuffer);
var featureColumn = autoSchema[1];
featureColumn.AddAnnotation(kindStringArray, valueStringArray, coltypeStringArray);
featureColumn.AddAnnotation(kindFloatArray, valueFloatArray, coltypeFloatArray);
featureColumn.AddAnnotation(kindVBuffer, valueVBuffer, coltypeVBuffer);

var mySchema = new SchemaDefinition { labelColumnWithMetadata, featureColumnWithMetadata };
var idv = mlContext.Data.LoadFromEnumerable(data, mySchema);
var idv = mlContext.Data.LoadFromEnumerable(data, autoSchema);

Assert.True(idv.Schema[0].Annotations.Schema.Count == 2);
Assert.True(idv.Schema[0].Annotations.Schema[0].Name == kindFloat);
Expand Down