Skip to content

Commit

Permalink
* Make extension points abstract classes.
Browse files Browse the repository at this point in the history
* Improve ML.NET version handling to actually rely on the assembly vs. a constant string we have to change.
* Tighten checks of arguments to the context implementation.
  • Loading branch information
TomFinley committed Jul 3, 2018
1 parent 0bcdbc7 commit 4851898
Show file tree
Hide file tree
Showing 27 changed files with 130 additions and 113 deletions.
8 changes: 4 additions & 4 deletions src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public interface ITransformCanSaveOnnx: ICanSaveOnnx, IDataTransform
/// Save as ONNX.
/// </summary>
/// <param name="ctx">The ONNX program being built</param>
void SaveAsOnnx(IOnnxContext ctx);
void SaveAsOnnx(OnnxContext ctx);
}

/// <summary>
Expand All @@ -52,7 +52,7 @@ public interface IBindableCanSaveOnnx : ICanSaveOnnx, ISchemaBindableMapper
/// the outputs produced by this bindable mapper. This is the array that holds
/// those names, so that implementors of this method know what to produce in
/// <paramref name="ctx"/>.</param>
bool SaveAsOnnx(IOnnxContext ctx, RoleMappedSchema schema, string[] outputNames);
bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames);
}

/// <summary>
Expand All @@ -61,7 +61,7 @@ public interface IBindableCanSaveOnnx : ICanSaveOnnx, ISchemaBindableMapper
/// </summary>
public interface ISingleCanSaveOnnx : ICanSaveOnnx
{
bool SaveAsOnnx(IOnnxContext ctx, string[] outputNames, string featureColumn);
bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn);
}

/// <summary>
Expand All @@ -70,6 +70,6 @@ public interface ISingleCanSaveOnnx : ICanSaveOnnx
/// </summary>
public interface IDistCanSaveOnnx : ISingleCanSaveOnnx, IValueMapperDist
{
new bool SaveAsOnnx(IOnnxContext ctx, string[] outputNames, string featureColumn);
new bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn);
}
}
34 changes: 0 additions & 34 deletions src/Microsoft.ML.Data/Model/Onnx/IOnnxNode.cs

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,22 @@ namespace Microsoft.ML.Runtime.Model.Onnx
///
///
/// </summary>
public interface IOnnxContext
public abstract class OnnxContext
{
/// <summary>
/// Generates a unique name for the node based on a prefix.
/// </summary>
/// <param name="prefix">The prefix for the node</param>
/// <returns>A name that has not yet been returned from this function, starting with <paramref name="prefix"/></returns>
string GetNodeName(string prefix);
public abstract string GetNodeName(string prefix);

/// <summary>
/// Looks up whether a given data view column has a mapping in the ONNX context. Once confirmed, callers can
/// safely call <see cref="GetVariableName(string)"/>.
/// </summary>
/// <param name="colName">The data view column name</param>
/// <returns>Whether the column is mapped in this context</returns>
bool ContainsColumn(string colName);
public abstract bool ContainsColumn(string colName);

/// <summary>
/// Stops tracking a column.
Expand All @@ -40,7 +40,7 @@ public interface IOnnxContext
/// <param name="removeVariable">Remove associated ONNX variable. This is useful in the event where an output
/// variable is created through <see cref="AddIntermediateVariable(ColumnType, string, bool)"/>before realizing
/// the transform cannot actually save as ONNX.</param>
void RemoveColumn(string colName, bool removeVariable = false);
public abstract void RemoveColumn(string colName, bool removeVariable = false);

/// <summary>
/// Removes an ONNX variable. If removeColumn is true then it also removes the tracking for the <see
Expand All @@ -49,7 +49,7 @@ public interface IOnnxContext
/// <param name="variableName">ONNX variable to remove. Note that this is an ONNX variable name, not an <see
/// cref="IDataView"/> column name</param>
/// <param name="removeColumn">IDataView column to stop tracking</param>
void RemoveVariable(string variableName, bool removeColumn);
public abstract void RemoveVariable(string variableName, bool removeColumn);

/// <summary>
/// ONNX variables are referred to by name. At each stage of a ML.NET pipeline, the corresponding
Expand All @@ -61,7 +61,7 @@ public interface IOnnxContext
/// </summary>
/// <param name="colName">The data view column name</param>
/// <returns>The ONNX variable name corresponding to that data view column</returns>
string GetVariableName(string colName);
public abstract string GetVariableName(string colName);

/// <summary>
/// Establishes a new mapping from an data view column in the context, if necessary generates a unique name, and
Expand All @@ -72,7 +72,7 @@ public interface IOnnxContext
/// <param name="skip">Whether we should skip the process of establishing the mapping from data view column to
/// ONNX variable name.</param>
/// <returns>The returned value is the name of the variable corresponding </returns>
string AddIntermediateVariable(ColumnType type, string colName, bool skip = false);
public abstract string AddIntermediateVariable(ColumnType type, string colName, bool skip = false);

/// <summary>
/// Creates an ONNX node
Expand All @@ -84,25 +84,25 @@ public interface IOnnxContext
/// <param name="name">The name of the operator, which ought to be something returned from <see cref="GetNodeName(string)"/></param>
/// <param name="domain">The domain of the ONNX operator, if non-default</param>
/// <returns>A node added to the in-progress ONNX graph, that attributes can be set on</returns>
IOnnxNode CreateNode(string opType, IEnumerable<string> inputs,
public abstract OnnxNode CreateNode(string opType, IEnumerable<string> inputs,
IEnumerable<string> outputs, string name, string domain = null);
}

public static class OnnxContextExtensions
{
/// <summary>
/// Convenience alternative to <see cref="IOnnxContext.CreateNode(string, IEnumerable{string}, IEnumerable{string}, string, string)"/>
/// Convenience alternative to <see cref="OnnxContext.CreateNode(string, IEnumerable{string}, IEnumerable{string}, string, string)"/>
/// for the case where there is exactly one input and output.
/// </summary>
/// <param name="ctx">The ONNX save context</param>
/// <param name="opType">The name of the ONNX operator to apply</param>
/// <param name="input">The name of the variable as input</param>
/// <param name="output">The name of the variable as output,
/// which ought to have been something returned from <see cref="IOnnxContext.AddIntermediateVariable(ColumnType, string, bool)"/></param>
/// <param name="name">The name of the operator, which ought to be something returned from <see cref="IOnnxContext.GetNodeName(string)"/></param>
/// which ought to have been something returned from <see cref="OnnxContext.AddIntermediateVariable(ColumnType, string, bool)"/></param>
/// <param name="name">The name of the operator, which ought to be something returned from <see cref="OnnxContext.GetNodeName(string)"/></param>
/// <param name="domain">The domain of the ONNX operator, if non-default</param>
/// <returns>A node added to the in-progress ONNX graph, that attributes can be set on</returns>
public static IOnnxNode CreateNode(this IOnnxContext ctx,
public static OnnxNode CreateNode(this OnnxContext ctx,
string opType, string input, string output, string name, string domain = null)
=> ctx.CreateNode(opType, new[] { input }, new[] { output }, name, domain);
}
Expand Down
32 changes: 32 additions & 0 deletions src/Microsoft.ML.Data/Model/Onnx/OnnxNode.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using Microsoft.ML.Runtime.Data;

namespace Microsoft.ML.Runtime.Model.Onnx
{
/// <summary>
/// An abstraction for an ONNX node as created by
/// <see cref="OnnxContext.CreateNode(string, IEnumerable{string}, IEnumerable{string}, string, string)"/>.
/// That method creates a with inputs and outputs, but this object can modify the node further
/// by adding attributes (in ONNX parlance, attributes are more or less constant parameterizations).
/// </summary>
public abstract class OnnxNode
{
public abstract void AddAttribute(string argName, double value);
public abstract void AddAttribute(string argName, long value);
public abstract void AddAttribute(string argName, DvText value);
public abstract void AddAttribute(string argName, string value);
public abstract void AddAttribute(string argName, bool value);

public abstract void AddAttribute(string argName, IEnumerable<double> value);
public abstract void AddAttribute(string argName, IEnumerable<float> value);
public abstract void AddAttribute(string argName, IEnumerable<long> value);
public abstract void AddAttribute(string argName, IEnumerable<DvText> value);
public abstract void AddAttribute(string argName, string[] value);
public abstract void AddAttribute(string argName, IEnumerable<string> value);
public abstract void AddAttribute(string argName, IEnumerable<bool> value);
}
}
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Data/Prediction/Calibrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ public void SaveAsPfa(BoundPfaContext ctx, JToken input,
probToken = ctx.DeclareVar(prob, probExpression);
}

public bool SaveAsOnnx(IOnnxContext ctx, string[] outputNames, string featureColumnName)
public bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumnName)
{
Host.CheckValue(ctx, nameof(ctx));
Host.CheckValue(outputNames, nameof(outputNames));
Expand Down Expand Up @@ -658,7 +658,7 @@ public void SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, string[] out
ctx.Hide(outputs);
}

public bool SaveAsOnnx(IOnnxContext ctx, RoleMappedSchema schema, string[] outputs)
public bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputs)
{
Host.CheckValue(ctx, nameof(ctx));
Host.CheckParam(Utils.Size(outputs) == 2, nameof(outputs), "Expected this to have two outputs");
Expand Down Expand Up @@ -1429,7 +1429,7 @@ public JToken SaveAsPfa(BoundPfaContext ctx, JToken input)
PfaUtils.Call("+", -ParamB, PfaUtils.Call("*", -ParamA, input)));
}

public bool SaveAsOnnx(IOnnxContext ctx, string[] scoreProbablityColumnNames, string featureColumnName)
public bool SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColumnNames, string featureColumnName)
{
_host.CheckValue(ctx, nameof(ctx));
_host.CheckValue(scoreProbablityColumnNames, nameof(scoreProbablityColumnNames));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ protected override void SaveCore(ModelSaveContext ctx)
ctx.Writer.Write(_threshold);
}

public override void SaveAsOnnx(IOnnxContext ctx)
public override void SaveAsOnnx(OnnxContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
Host.Assert(Bindable is IBindableCanSaveOnnx);
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Scorers/GenericScorer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ public void SaveAsPfa(BoundPfaContext ctx)
pfaBindable.SaveAsPfa(ctx, schema, outColNames);
}

public void SaveAsOnnx(IOnnxContext ctx)
public void SaveAsOnnx(OnnxContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
Host.Assert(Bindable is IBindableCanSaveOnnx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ public void SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, string[] out
((IBindableCanSavePfa)_bindable).SaveAsPfa(ctx, schema, outputNames);
}

public bool SaveAsOnnx(IOnnxContext ctx, RoleMappedSchema schema, string[] outputNames)
public bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames)
{
Contracts.CheckValue(ctx, nameof(ctx));
Contracts.CheckValue(schema, nameof(schema));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Scorers/PredictedLabelScorerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ public void SaveAsPfa(BoundPfaContext ctx)

protected abstract JToken PredictedLabelPfa(string[] mapperOutputs);

public virtual void SaveAsOnnx(IOnnxContext ctx)
public virtual void SaveAsOnnx(OnnxContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
Host.Assert(Bindable is IBindableCanSaveOnnx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public virtual void SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, stri
ctx.Hide(outputNames);
}

public virtual bool SaveAsOnnx(IOnnxContext ctx, RoleMappedSchema schema, string[] outputNames) => false;
public virtual bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames) => false;

public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
{
Expand Down Expand Up @@ -289,7 +289,7 @@ public override void SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, str
ctx.DeclareVar(outputNames[0], scoreToken);
}

public override bool SaveAsOnnx(IOnnxContext ctx, RoleMappedSchema schema, string[] outputNames)
public override bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames)
{
Contracts.CheckValue(ctx, nameof(ctx));
Contracts.CheckValue(schema, nameof(schema));
Expand Down Expand Up @@ -403,7 +403,7 @@ public override void SaveAsPfa(BoundPfaContext ctx, RoleMappedSchema schema, str
Contracts.Assert(ctx.TokenOrNullForName(outputNames[1]) == probToken.ToString());
}

public override bool SaveAsOnnx(IOnnxContext ctx, RoleMappedSchema schema, string[] outputNames)
public override bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames)
{
Contracts.CheckValue(ctx, nameof(ctx));
Contracts.CheckValue(schema, nameof(schema));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Transforms/ConcatTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ public void SaveAsPfa(BoundPfaContext ctx)
ctx.DeclareVar(toDeclare.ToArray());
}

public void SaveAsOnnx(IOnnxContext ctx)
public void SaveAsOnnx(OnnxContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
Host.Assert(CanSaveOnnx);
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ protected override JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo
PfaUtils.Call("cast.fanoutDouble", -1, 0, keyCount, false), PfaUtils.FuncRef("u." + funcName));
}

protected override bool SaveAsOnnxCore(IOnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
{
string opType = "OneHotEncoder";
var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ private AffineColumnFunction(IHost host)

public abstract JToken PfaInfo(BoundPfaContext ctx, JToken srcToken);
public bool CanSaveOnnx => true;
public abstract bool OnnxInfo(IOnnxContext ctx, IOnnxNode nodeProtoWrapper, int featureCount);
public abstract bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount);

public abstract Delegate GetGetter(IRow input, int icol);

Expand Down Expand Up @@ -548,7 +548,7 @@ public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken)

public bool CanSaveOnnx => false;

public bool OnnxInfo(IOnnxContext ctx, IOnnxNode nodeProtoWrapper, int featureCount)
public bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount)
=> throw Host.ExceptNotSupp();

public abstract Delegate GetGetter(IRow input, int icol);
Expand Down Expand Up @@ -673,7 +673,7 @@ public JToken PfaInfo(BoundPfaContext ctx, JToken srcToken)

public bool CanSaveOnnx => false;

public bool OnnxInfo(IOnnxContext ctx, IOnnxNode nodeProtoWrapper, int featureCount)
public bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount)
=> throw Host.ExceptNotSupp();

public abstract Delegate GetGetter(IRow input, int icol);
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ public override void Save(ModelSaveContext ctx)
public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken)
=> PfaUtils.Call("*", PfaUtils.Call("-", srcToken, Offset), Scale);

public override bool OnnxInfo(IOnnxContext ctx, IOnnxNode nodeProtoWrapper, int featureCount)
public override bool OnnxInfo(OnnxContext ctx, OnnxNode nodeProtoWrapper, int featureCount)
{
nodeProtoWrapper.AddAttribute("offset", Enumerable.Repeat(Offset, featureCount));
nodeProtoWrapper.AddAttribute("scale", Enumerable.Repeat(Scale, featureCount));
Expand Down Expand Up @@ -648,7 +648,7 @@ public override JToken PfaInfo(BoundPfaContext ctx, JToken srcToken)
return PfaUtils.Call("a.zipmap", srcToken, scaleCell, PfaUtils.FuncRef(ctx.Pfa.EnsureMul(itemType)));
}

public override bool OnnxInfo(IOnnxContext ctx, IOnnxNode node, int featureCount)
public override bool OnnxInfo(OnnxContext ctx, OnnxNode node, int featureCount)
{

if (Offset != null)
Expand Down
Loading

0 comments on commit 4851898

Please sign in to comment.