Skip to content

Commit

Permalink
Creation of components through MLContext and cleanup (text transform) (
Browse files Browse the repository at this point in the history
…#2394)

* text transform

* review comments

* review comments

* review comment on options
  • Loading branch information
abgoswam authored Feb 7, 2019
1 parent 26755ec commit 834e471
Show file tree
Hide file tree
Showing 9 changed files with 143 additions and 101 deletions.
12 changes: 6 additions & 6 deletions docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ public static void TextTransform()

// Another pipeline, that customizes the advanced settings of the FeaturizeText transformer.
string customizedColumnName = "CustomizedTextFeatures";
var customized_pipeline = ml.Transforms.Text.FeaturizeText(customizedColumnName, "SentimentText", s =>
{
s.KeepPunctuations = false;
s.KeepNumbers = false;
s.OutputTokens = true;
s.TextLanguage = TextFeaturizingEstimator.Language.English; // supports English, French, German, Dutch, Italian, Spanish, Japanese
var customized_pipeline = ml.Transforms.Text.FeaturizeText(customizedColumnName, new List<string> { "SentimentText" },
new TextFeaturizingEstimator.Options {
KeepPunctuations = false,
KeepNumbers = false,
OutputTokens = true,
TextLanguage = TextFeaturizingEstimator.Language.English, // supports English, French, German, Dutch, Italian, Spanish, Japanese
});

// The transformed data for both pipelines.
Expand Down
16 changes: 8 additions & 8 deletions src/Microsoft.ML.StaticPipe/TransformsStatic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1509,20 +1509,20 @@ internal sealed class OutPipelineColumn : Vector<float>
{
public readonly Scalar<string>[] Inputs;

public OutPipelineColumn(IEnumerable<Scalar<string>> inputs, Action<Settings> advancedSettings)
: base(new Reconciler(advancedSettings), inputs.ToArray())
public OutPipelineColumn(IEnumerable<Scalar<string>> inputs, Options options)
: base(new Reconciler(options), inputs.ToArray())
{
Inputs = inputs.ToArray();
}
}

private sealed class Reconciler : EstimatorReconciler
{
private readonly Action<Settings> _settings;
private readonly Options _settings;

public Reconciler(Action<Settings> advancedSettings)
public Reconciler(Options options)
{
_settings = advancedSettings;
_settings = options;
}

public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
Expand All @@ -1543,14 +1543,14 @@ public override IEstimator<ITransformer> Reconcile(IHostEnvironment env,
/// </summary>
/// <param name="input">Input data.</param>
/// <param name="otherInputs">Additional data.</param>
/// <param name="advancedSettings">Delegate which allows you to set transformation settings.</param>
/// <param name="options">Advanced transform settings.</param>
/// <returns></returns>
public static Vector<float> FeaturizeText(this Scalar<string> input, Scalar<string>[] otherInputs = null, Action<TextFeaturizingEstimator.Settings> advancedSettings = null)
public static Vector<float> FeaturizeText(this Scalar<string> input, Scalar<string>[] otherInputs = null, TextFeaturizingEstimator.Options options = null)
{
Contracts.CheckValue(input, nameof(input));
Contracts.CheckValueOrNull(otherInputs);
otherInputs = otherInputs ?? new Scalar<string>[0];
return new OutPipelineColumn(new[] { input }.Concat(otherInputs), advancedSettings);
return new OutPipelineColumn(new[] { input }.Concat(otherInputs), options);
}
}

Expand Down
26 changes: 12 additions & 14 deletions src/Microsoft.ML.Transforms/Text/TextCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,34 +20,32 @@ public static class TextCatalog
/// <param name="catalog">The text-related transform's catalog.</param>
/// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
/// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
/// <param name="advancedSettings">Advanced transform settings</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[FeaturizeText](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs)]
/// ]]>
/// </format>
/// </example>
public static TextFeaturizingEstimator FeaturizeText(this TransformsCatalog.TextTransforms catalog,
string outputColumnName,
string inputColumnName = null,
Action<TextFeaturizingEstimator.Settings> advancedSettings = null)
string inputColumnName = null)
=> new TextFeaturizingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(),
outputColumnName, inputColumnName, advancedSettings);
outputColumnName, inputColumnName);

/// <summary>
/// Transform several text columns into featurized float array that represents counts of ngrams and char-grams.
/// </summary>
/// <param name="catalog">The text-related transform's catalog.</param>
/// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnNames"/>.</param>
/// <param name="inputColumnNames">Name of the columns to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
/// <param name="advancedSettings">Advanced transform settings</param>
/// <param name="options">Advanced options to the algorithm.</param>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[FeaturizeText](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs)]
/// ]]>
/// </format>
/// </example>
public static TextFeaturizingEstimator FeaturizeText(this TransformsCatalog.TextTransforms catalog,
string outputColumnName,
IEnumerable<string> inputColumnNames,
Action<TextFeaturizingEstimator.Settings> advancedSettings = null)
TextFeaturizingEstimator.Options options)
=> new TextFeaturizingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(),
outputColumnName, inputColumnNames, advancedSettings);
outputColumnName, inputColumnNames, options);

/// <summary>
/// Tokenize incoming text in <paramref name="inputColumnName"/> and output the tokens as <paramref name="outputColumnName"/>.
Expand Down
119 changes: 79 additions & 40 deletions src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -120,26 +120,59 @@ internal sealed class Arguments : TransformInputBase
public TextNormKind VectorNormalizer = TextNormKind.L2;
}

public sealed class Settings
/// <summary>
/// Advanced options for the <see cref="TextFeaturizingEstimator"/>.
/// </summary>
public sealed class Options
{
#pragma warning disable MSML_NoInstanceInitializers // No initializers on instance fields or properties
/// <summary>
/// Dataset language.
/// </summary>
public Language TextLanguage { get; set; } = DefaultLanguage;
/// <summary>
/// Casing used for the text.
/// </summary>
public CaseNormalizationMode TextCase { get; set; } = CaseNormalizationMode.Lower;
/// <summary>
/// Whether to keep diacritical marks or remove them.
/// </summary>
public bool KeepDiacritics { get; set; } = false;
/// <summary>
/// Whether to keep punctuation marks or remove them.
/// </summary>
public bool KeepPunctuations { get; set; } = true;
/// <summary>
/// Whether to keep numbers or remove them.
/// </summary>
public bool KeepNumbers { get; set; } = true;
/// <summary>
/// Whether to output the transformed text tokens as an additional column.
/// </summary>
public bool OutputTokens { get; set; } = false;
/// <summary>
/// Vector Normalizer to use.
/// </summary>
public TextNormKind VectorNormalizer { get; set; } = TextNormKind.L2;
/// <summary>
/// Whether to use stop remover or not.
/// </summary>
public bool UseStopRemover { get; set; } = false;
/// <summary>
/// Whether to use char extractor or not.
/// </summary>
public bool UseCharExtractor { get; set; } = true;
/// <summary>
/// Whether to use word extractor or not.
/// </summary>
public bool UseWordExtractor { get; set; } = true;
#pragma warning restore MSML_NoInstanceInitializers // No initializers on instance fields or properties
}

public readonly string OutputColumn;
internal readonly string OutputColumn;
private readonly string[] _inputColumns;
public IReadOnlyCollection<string> InputColumns => _inputColumns.AsReadOnly();
public Settings AdvancedSettings { get; }
internal IReadOnlyCollection<string> InputColumns => _inputColumns.AsReadOnly();
internal Options OptionalSettings { get; }

// These parameters are hardcoded for now.
// REVIEW: expose them once sub-transforms are estimators.
Expand Down Expand Up @@ -232,18 +265,18 @@ public bool NeedInitialSourceColumnConcatTransform
public TransformApplierParams(TextFeaturizingEstimator parent)
{
var host = parent._host;
host.Check(Enum.IsDefined(typeof(Language), parent.AdvancedSettings.TextLanguage));
host.Check(Enum.IsDefined(typeof(CaseNormalizationMode), parent.AdvancedSettings.TextCase));
host.Check(Enum.IsDefined(typeof(Language), parent.OptionalSettings.TextLanguage));
host.Check(Enum.IsDefined(typeof(CaseNormalizationMode), parent.OptionalSettings.TextCase));
WordExtractorFactory = parent._wordFeatureExtractor?.CreateComponent(host, parent._dictionary);
CharExtractorFactory = parent._charFeatureExtractor?.CreateComponent(host, parent._dictionary);
VectorNormalizer = parent.AdvancedSettings.VectorNormalizer;
Language = parent.AdvancedSettings.TextLanguage;
UsePredefinedStopWordRemover = parent.AdvancedSettings.UseStopRemover;
TextCase = parent.AdvancedSettings.TextCase;
KeepDiacritics = parent.AdvancedSettings.KeepDiacritics;
KeepPunctuations = parent.AdvancedSettings.KeepPunctuations;
KeepNumbers = parent.AdvancedSettings.KeepNumbers;
OutputTextTokens = parent.AdvancedSettings.OutputTokens;
VectorNormalizer = parent.OptionalSettings.VectorNormalizer;
Language = parent.OptionalSettings.TextLanguage;
UsePredefinedStopWordRemover = parent.OptionalSettings.UseStopRemover;
TextCase = parent.OptionalSettings.TextCase;
KeepDiacritics = parent.OptionalSettings.KeepDiacritics;
KeepPunctuations = parent.OptionalSettings.KeepPunctuations;
KeepNumbers = parent.OptionalSettings.KeepNumbers;
OutputTextTokens = parent.OptionalSettings.OutputTokens;
Dictionary = parent._dictionary;
}
}
Expand All @@ -254,40 +287,42 @@ public TransformApplierParams(TextFeaturizingEstimator parent)
internal const string UserName = "Text Transform";
internal const string LoaderSignature = "Text";

public const Language DefaultLanguage = Language.English;
internal const Language DefaultLanguage = Language.English;

private const string TransformedTextColFormat = "{0}_TransformedText";

public TextFeaturizingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null,
Action<Settings> advancedSettings = null)
: this(env, outputColumnName, new[] { inputColumnName ?? outputColumnName }, advancedSettings)
internal TextFeaturizingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null)
: this(env, outputColumnName, new[] { inputColumnName ?? outputColumnName })
{
}

public TextFeaturizingEstimator(IHostEnvironment env, string name, IEnumerable<string> source,
Action<Settings> advancedSettings = null)
internal TextFeaturizingEstimator(IHostEnvironment env, string name, IEnumerable<string> source, Options options = null)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(TextFeaturizingEstimator));
_host.CheckValue(source, nameof(source));
_host.CheckParam(source.Any(), nameof(source));
_host.CheckParam(!source.Any(string.IsNullOrWhiteSpace), nameof(source));
_host.CheckNonEmpty(name, nameof(name));
_host.CheckValueOrNull(advancedSettings);
_host.CheckValueOrNull(options);

_inputColumns = source.ToArray();
OutputColumn = name;

AdvancedSettings = new Settings();
advancedSettings?.Invoke(AdvancedSettings);
OptionalSettings = new Options();
if (options != null)
OptionalSettings = options;

_dictionary = null;
if (AdvancedSettings.UseWordExtractor)
if (OptionalSettings.UseWordExtractor)
_wordFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments();
if (AdvancedSettings.UseCharExtractor)
if (OptionalSettings.UseCharExtractor)
_charFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments() { NgramLength = 3, AllLengths = false };
}

/// <summary>
/// Trains and returns a <see cref="Transformer"/>.
/// </summary>
public ITransformer Fit(IDataView input)
{
var h = _host;
Expand Down Expand Up @@ -463,14 +498,18 @@ public ITransformer Fit(IDataView input)
return new Transformer(_host, input, view);
}

public static ITransformer Create(IHostEnvironment env, ModelLoadContext ctx)
private static ITransformer Create(IHostEnvironment env, ModelLoadContext ctx)
=> new Transformer(env, ctx);

private static string GenerateColumnName(Schema schema, string srcName, string xfTag)
{
return schema.GetTempColumnName(string.Format("{0}_{1}", srcName, xfTag));
}

/// <summary>
/// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
/// Used for schema propagation and verification in a pipeline.
/// </summary>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));
Expand All @@ -485,12 +524,12 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)

var metadata = new List<SchemaShape.Column>(2);
metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false));
if (AdvancedSettings.VectorNormalizer != TextNormKind.None)
if (OptionalSettings.VectorNormalizer != TextNormKind.None)
metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false));

result[OutputColumn] = new SchemaShape.Column(OutputColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false,
new SchemaShape(metadata));
if (AdvancedSettings.OutputTokens)
if (OptionalSettings.OutputTokens)
{
string name = string.Format(TransformedTextColFormat, OutputColumn);
result[name] = new SchemaShape.Column(name, SchemaShape.Column.VectorKind.VariableVector, TextType.Instance, false);
Expand All @@ -502,18 +541,18 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
// Factory method for SignatureDataTransform.
internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView data)
{
Action<Settings> settings = s =>
var settings = new Options
{
s.TextLanguage = args.Language;
s.TextCase = args.TextCase;
s.KeepDiacritics = args.KeepDiacritics;
s.KeepPunctuations = args.KeepPunctuations;
s.KeepNumbers = args.KeepNumbers;
s.OutputTokens = args.OutputTokens;
s.VectorNormalizer = args.VectorNormalizer;
s.UseStopRemover = args.UsePredefinedStopWordRemover;
s.UseWordExtractor = args.WordFeatureExtractor != null;
s.UseCharExtractor = args.CharFeatureExtractor != null;
TextLanguage = args.Language,
TextCase = args.TextCase,
KeepDiacritics = args.KeepDiacritics,
KeepPunctuations = args.KeepPunctuations,
KeepNumbers = args.KeepNumbers,
OutputTokens = args.OutputTokens,
VectorNormalizer = args.VectorNormalizer,
UseStopRemover = args.UsePredefinedStopWordRemover,
UseWordExtractor = args.WordFeatureExtractor != null,
UseCharExtractor = args.CharFeatureExtractor != null,
};

var estimator = new TextFeaturizingEstimator(env, args.Columns.Name, args.Columns.Source ?? new[] { args.Columns.Name }, settings);
Expand All @@ -530,7 +569,7 @@ private sealed class Transformer : ITransformer, ICanSaveModel
private readonly IHost _host;
private readonly IDataView _xf;

public Transformer(IHostEnvironment env, IDataView input, IDataView view)
internal Transformer(IHostEnvironment env, IDataView input, IDataView view)
{
_host = env.Register(nameof(Transformer));
_xf = ApplyTransformUtils.ApplyAllTransformsToData(_host, view, new EmptyDataView(_host, input.Schema), input);
Expand Down
10 changes: 5 additions & 5 deletions test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ public void SetupSentimentPipeline()

string _sentimentDataPath = BaseTestClass.GetDataPath("wikipedia-detox-250-line-data.tsv");

var env = new MLContext(seed: 1, conc: 1);
var reader = new TextLoader(env, columns: new[]
var mlContext = new MLContext(seed: 1, conc: 1);
var reader = new TextLoader(mlContext, columns: new[]
{
new TextLoader.Column("Label", DataKind.BL, 0),
new TextLoader.Column("SentimentText", DataKind.Text, 1)
Expand All @@ -83,13 +83,13 @@ public void SetupSentimentPipeline()

IDataView data = reader.Read(_sentimentDataPath);

var pipeline = new TextFeaturizingEstimator(env, "Features", "SentimentText")
.Append(env.BinaryClassification.Trainers.StochasticDualCoordinateAscent(
var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText")
.Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent(
new SdcaBinaryTrainer.Options {NumThreads = 1, ConvergenceTolerance = 1e-2f, }));

var model = pipeline.Fit(data);

_sentimentModel = model.CreatePredictionEngine<SentimentData, SentimentPrediction>(env);
_sentimentModel = model.CreatePredictionEngine<SentimentData, SentimentPrediction>(mlContext);
}

[GlobalSetup(Target = nameof(MakeBreastCancerPredictions))]
Expand Down
Loading

0 comments on commit 834e471

Please sign in to comment.