diff --git a/src/Microsoft.ML.Data/DataLoadSave/PartitionedPathParser.cs b/src/Microsoft.ML.Data/DataLoadSave/PartitionedPathParser.cs index ca3aa075ab..70d8f898ab 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/PartitionedPathParser.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/PartitionedPathParser.cs @@ -76,7 +76,7 @@ public class Arguments : IPartitionedPathParserFactory { [Argument(ArgumentType.Multiple, HelpText = "Column definitions used to override the Partitioned Path Parser. Expected with the format name:type:numeric-source, e.g. col=MyFeature:R4:1", ShortName = "col", SortOrder = 1)] - public Microsoft.ML.Runtime.Data.PartitionedFileLoader.Column[] Columns; + public PartitionedFileLoader.Column[] Columns; [Argument(ArgumentType.AtMostOnce, HelpText = "Data type of each column.")] public DataKind Type = DataKind.Text; diff --git a/src/Microsoft.ML.PipelineInference/TransformInference.cs b/src/Microsoft.ML.PipelineInference/TransformInference.cs index 988b56eedf..6390139030 100644 --- a/src/Microsoft.ML.PipelineInference/TransformInference.cs +++ b/src/Microsoft.ML.PipelineInference/TransformInference.cs @@ -6,7 +6,6 @@ using System.Collections.Generic; using System.Linq; using System.Text; -using Microsoft.ML; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; @@ -712,7 +711,7 @@ public override IEnumerable Apply(IntermediateColumn[] colum { Name = columnNameQuoted.ToString(), Source = columnNameQuoted.ToString(), - ResultType = ML.Transforms.DataKind.R4 + ResultType = ML.Data.DataKind.R4 }); } @@ -721,7 +720,7 @@ public override IEnumerable Apply(IntermediateColumn[] colum ch.Info("Suggested conversion to numeric for boolean features."); var args = new SubComponent("Convert", new[] { $"{columnArgument}type=R4" }); - var epInput = new ML.Transforms.ColumnTypeConverter { Column = epColumns.ToArray(), ResultType = ML.Transforms.DataKind.R4 }; + var epInput = new ML.Transforms.ColumnTypeConverter { Column = epColumns.ToArray(), ResultType = ML.Data.DataKind.R4 }; ColumnRoutingStructure.AnnotatedName[] columnsSource = epColumns.Select(c => new ColumnRoutingStructure.AnnotatedName { IsNumeric = false, Name = c.Name }).ToArray(); ColumnRoutingStructure.AnnotatedName[] columnsDest = diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index c103a96475..6009c5a1ae 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -1461,57 +1461,84 @@ public sealed class Output namespace Data { - public sealed partial class TextLoaderArguments + public enum DataKind : byte { - /// - /// Use separate parsing threads? - /// - public bool UseThreads { get; set; } = true; + I1 = 1, + U1 = 2, + I2 = 3, + U2 = 4, + I4 = 5, + U4 = 6, + I8 = 7, + U8 = 8, + R4 = 9, + Num = 9, + R8 = 10, + TX = 11, + Text = 11, + TXT = 11, + BL = 12, + Bool = 12, + TimeSpan = 13, + TS = 13, + DT = 14, + DateTime = 14, + DZ = 15, + DateTimeZone = 15, + UG = 16, + U16 = 16 + } + public sealed partial class TextLoaderRange + { /// - /// File containing a header with feature names. If specified, header defined in the data file (header+) is ignored. + /// First index in the range /// - public string HeaderFile { get; set; } + public int Min { get; set; } /// - /// Maximum number of rows to produce + /// Last index in the range /// - public long? MaxRows { get; set; } + public int? Max { get; set; } /// - /// Whether the input may include quoted values, which can contain separator characters, colons, and distinguish empty values from missing values. When true, consecutive separators denote a missing value and an empty value is denoted by "". When false, consecutive separators denote an empty value. + /// This range extends to the end of the line, but should be a fixed number of items /// - public bool AllowQuoting { get; set; } = true; + public bool AutoEnd { get; set; } = false; /// - /// Whether the input may include sparse representations + /// This range extends to the end of the line, which can vary from line to line /// - public bool AllowSparse { get; set; } = true; + public bool VariableEnd { get; set; } = false; /// - /// Number of source columns in the text data. Default is that sparse rows contain their size information. + /// This range includes only other indices not specified /// - public int? InputSize { get; set; } + public bool AllOther { get; set; } = false; /// - /// Source column separator. + /// Force scalar columns to be treated as vectors of length one /// - public char[] Separator { get; set; } = { '\t' }; + public bool ForceVector { get; set; } = false; + } + + public sealed partial class KeyRange + { /// - /// Column groups. Each group is specified as name:type:numeric-ranges, eg, col=Features:R4:1-17,26,35-40 + /// First index in the range /// - public TextLoaderColumn[] Column { get; set; } + public ulong Min { get; set; } = 0; /// - /// Remove trailing whitespace from lines + /// Last index in the range /// - public bool TrimWhitespace { get; set; } = false; + public ulong? Max { get; set; } /// - /// Data file has header with feature names. Header is read only if options 'hs' and 'hf' are not specified. + /// Whether the key is contiguous /// - public bool HasHeader { get; set; } = false; + public bool Contiguous { get; set; } = true; } @@ -1539,56 +1566,57 @@ public sealed partial class TextLoaderColumn } - public sealed partial class TextLoaderRange + public sealed partial class TextLoaderArguments { /// - /// First index in the range + /// Use separate parsing threads? /// - public int Min { get; set; } + public bool UseThreads { get; set; } = true; /// - /// Last index in the range + /// File containing a header with feature names. If specified, header defined in the data file (header+) is ignored. /// - public int? Max { get; set; } + public string HeaderFile { get; set; } /// - /// This range extends to the end of the line, but should be a fixed number of items + /// Maximum number of rows to produce /// - public bool AutoEnd { get; set; } = false; + public long? MaxRows { get; set; } /// - /// This range extends to the end of the line, which can vary from line to line + /// Whether the input may include quoted values, which can contain separator characters, colons, and distinguish empty values from missing values. When true, consecutive separators denote a missing value and an empty value is denoted by "". When false, consecutive separators denote an empty value. /// - public bool VariableEnd { get; set; } = false; + public bool AllowQuoting { get; set; } = true; /// - /// This range includes only other indices not specified + /// Whether the input may include sparse representations /// - public bool AllOther { get; set; } = false; + public bool AllowSparse { get; set; } = true; /// - /// Force scalar columns to be treated as vectors of length one + /// Number of source columns in the text data. Default is that sparse rows contain their size information. /// - public bool ForceVector { get; set; } = false; + public int? InputSize { get; set; } - } + /// + /// Source column separator. + /// + public char[] Separator { get; set; } = { '\t' }; - public sealed partial class KeyRange - { /// - /// First index in the range + /// Column groups. Each group is specified as name:type:numeric-ranges, eg, col=Features:R4:1-17,26,35-40 /// - public ulong Min { get; set; } = 0; + public TextLoaderColumn[] Column { get; set; } /// - /// Last index in the range + /// Remove trailing whitespace from lines /// - public ulong? Max { get; set; } + public bool TrimWhitespace { get; set; } = false; /// - /// Whether the key is contiguous + /// Data file has header with feature names. Header is read only if options 'hs' and 'hf' are not specified. /// - public bool Contiguous { get; set; } = true; + public bool HasHeader { get; set; } = false; } @@ -1640,7 +1668,7 @@ public TextLoaderPipelineStep (Output output) /// /// Arguments /// - public Microsoft.ML.Data.TextLoaderArguments Arguments { get; set; } = new Microsoft.ML.Data.TextLoaderArguments(); + public TextLoaderArguments Arguments { get; set; } = new TextLoaderArguments(); public sealed class Output @@ -1906,12 +1934,12 @@ public sealed partial class BinaryCrossValidator /// /// The training subgraph inputs /// - public Microsoft.ML.Models.CrossValidationBinaryMacroSubGraphInput Inputs { get; set; } = new Microsoft.ML.Models.CrossValidationBinaryMacroSubGraphInput(); + public CrossValidationBinaryMacroSubGraphInput Inputs { get; set; } = new CrossValidationBinaryMacroSubGraphInput(); /// /// The training subgraph outputs /// - public Microsoft.ML.Models.CrossValidationBinaryMacroSubGraphOutput Outputs { get; set; } = new Microsoft.ML.Models.CrossValidationBinaryMacroSubGraphOutput(); + public CrossValidationBinaryMacroSubGraphOutput Outputs { get; set; } = new CrossValidationBinaryMacroSubGraphOutput(); /// /// Column to use for stratification @@ -2178,7 +2206,7 @@ public sealed partial class CrossValidationResultsCombiner /// /// Specifies the trainer kind, which determines the evaluator to be used. /// - public Microsoft.ML.Models.MacroUtilsTrainerKinds Kind { get; set; } = Microsoft.ML.Models.MacroUtilsTrainerKinds.SignatureBinaryClassifierTrainer; + public MacroUtilsTrainerKinds Kind { get; set; } = MacroUtilsTrainerKinds.SignatureBinaryClassifierTrainer; public sealed class Output @@ -2258,12 +2286,12 @@ public sealed partial class CrossValidator /// /// The training subgraph inputs /// - public Microsoft.ML.Models.CrossValidationMacroSubGraphInput Inputs { get; set; } = new Microsoft.ML.Models.CrossValidationMacroSubGraphInput(); + public CrossValidationMacroSubGraphInput Inputs { get; set; } = new CrossValidationMacroSubGraphInput(); /// /// The training subgraph outputs /// - public Microsoft.ML.Models.CrossValidationMacroSubGraphOutput Outputs { get; set; } = new Microsoft.ML.Models.CrossValidationMacroSubGraphOutput(); + public CrossValidationMacroSubGraphOutput Outputs { get; set; } = new CrossValidationMacroSubGraphOutput(); /// /// Column to use for stratification @@ -2278,7 +2306,7 @@ public sealed partial class CrossValidator /// /// Specifies the trainer kind, which determines the evaluator to be used. /// - public Microsoft.ML.Models.MacroUtilsTrainerKinds Kind { get; set; } = Microsoft.ML.Models.MacroUtilsTrainerKinds.SignatureBinaryClassifierTrainer; + public MacroUtilsTrainerKinds Kind { get; set; } = MacroUtilsTrainerKinds.SignatureBinaryClassifierTrainer; /// /// Column to use for labels @@ -2687,7 +2715,7 @@ public sealed partial class OneVersusAll : Microsoft.ML.Runtime.EntryPoints.Comm /// /// The training subgraph output. /// - public Microsoft.ML.Models.OneVersusAllMacroSubGraphOutput OutputForSubGraph { get; set; } = new Microsoft.ML.Models.OneVersusAllMacroSubGraphOutput(); + public OneVersusAllMacroSubGraphOutput OutputForSubGraph { get; set; } = new OneVersusAllMacroSubGraphOutput(); /// /// Use probabilities in OVA combiner @@ -2717,12 +2745,12 @@ public sealed partial class OneVersusAll : Microsoft.ML.Runtime.EntryPoints.Comm /// /// Normalize option for the feature column /// - public Microsoft.ML.Models.NormalizeOption NormalizeFeatures { get; set; } = Microsoft.ML.Models.NormalizeOption.Auto; + public NormalizeOption NormalizeFeatures { get; set; } = NormalizeOption.Auto; /// /// Whether learner should cache input training data /// - public Microsoft.ML.Models.CachingOptions Caching { get; set; } = Microsoft.ML.Models.CachingOptions.Auto; + public CachingOptions Caching { get; set; } = CachingOptions.Auto; public sealed class Output @@ -2862,12 +2890,12 @@ public sealed partial class OvaModelCombiner : Microsoft.ML.Runtime.EntryPoints. /// /// Normalize option for the feature column /// - public Microsoft.ML.Models.NormalizeOption NormalizeFeatures { get; set; } = Microsoft.ML.Models.NormalizeOption.Auto; + public NormalizeOption NormalizeFeatures { get; set; } = NormalizeOption.Auto; /// /// Whether learner should cache input training data /// - public Microsoft.ML.Models.CachingOptions Caching { get; set; } = Microsoft.ML.Models.CachingOptions.Auto; + public CachingOptions Caching { get; set; } = CachingOptions.Auto; public sealed class Output @@ -3421,12 +3449,12 @@ public sealed partial class TrainTestBinaryEvaluator /// /// The training subgraph inputs /// - public Microsoft.ML.Models.TrainTestBinaryMacroSubGraphInput Inputs { get; set; } = new Microsoft.ML.Models.TrainTestBinaryMacroSubGraphInput(); + public TrainTestBinaryMacroSubGraphInput Inputs { get; set; } = new TrainTestBinaryMacroSubGraphInput(); /// /// The training subgraph outputs /// - public Microsoft.ML.Models.TrainTestBinaryMacroSubGraphOutput Outputs { get; set; } = new Microsoft.ML.Models.TrainTestBinaryMacroSubGraphOutput(); + public TrainTestBinaryMacroSubGraphOutput Outputs { get; set; } = new TrainTestBinaryMacroSubGraphOutput(); public sealed class Output @@ -3516,17 +3544,17 @@ public sealed partial class TrainTestEvaluator /// /// The training subgraph inputs /// - public Microsoft.ML.Models.TrainTestMacroSubGraphInput Inputs { get; set; } = new Microsoft.ML.Models.TrainTestMacroSubGraphInput(); + public TrainTestMacroSubGraphInput Inputs { get; set; } = new TrainTestMacroSubGraphInput(); /// /// The training subgraph outputs /// - public Microsoft.ML.Models.TrainTestMacroSubGraphOutput Outputs { get; set; } = new Microsoft.ML.Models.TrainTestMacroSubGraphOutput(); + public TrainTestMacroSubGraphOutput Outputs { get; set; } = new TrainTestMacroSubGraphOutput(); /// /// Specifies the trainer kind, which determines the evaluator to be used. /// - public Microsoft.ML.Models.MacroUtilsTrainerKinds Kind { get; set; } = Microsoft.ML.Models.MacroUtilsTrainerKinds.SignatureBinaryClassifierTrainer; + public MacroUtilsTrainerKinds Kind { get; set; } = MacroUtilsTrainerKinds.SignatureBinaryClassifierTrainer; /// /// Identifies which pipeline was run for this train test. @@ -3888,7 +3916,7 @@ public sealed partial class FastForestBinaryClassifier : Microsoft.ML.Runtime.En /// /// Bundle low population bins. Bundle.None(0): no bundling, Bundle.AggregateLowPopulation(1): Bundle low population, Bundle.Adjacent(2): Neighbor low population bundle. /// - public Microsoft.ML.Trainers.Bundle Bundling { get; set; } = Microsoft.ML.Trainers.Bundle.None; + public Bundle Bundling { get; set; } = Bundle.None; /// /// Maximum number of distinct values (bins) per feature @@ -4170,7 +4198,7 @@ public sealed partial class FastForestRegressor : Microsoft.ML.Runtime.EntryPoin /// /// Bundle low population bins. Bundle.None(0): no bundling, Bundle.AggregateLowPopulation(1): Bundle low population, Bundle.Adjacent(2): Neighbor low population bundle. /// - public Microsoft.ML.Trainers.Bundle Bundling { get; set; } = Microsoft.ML.Trainers.Bundle.None; + public Bundle Bundling { get; set; } = Bundle.None; /// /// Maximum number of distinct values (bins) per feature @@ -4403,7 +4431,7 @@ public sealed partial class FastTreeBinaryClassifier : Microsoft.ML.Runtime.Entr /// /// Optimization algorithm to be used (GradientDescent, AcceleratedGradientDescent) /// - public Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType OptimizationAlgorithm { get; set; } = Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType.GradientDescent; + public BoostedTreeArgsOptimizationAlgorithmType OptimizationAlgorithm { get; set; } = BoostedTreeArgsOptimizationAlgorithmType.GradientDescent; /// /// Early stopping rule. (Validation set (/valid) is required.) @@ -4568,7 +4596,7 @@ public sealed partial class FastTreeBinaryClassifier : Microsoft.ML.Runtime.Entr /// /// Bundle low population bins. Bundle.None(0): no bundling, Bundle.AggregateLowPopulation(1): Bundle low population, Bundle.Adjacent(2): Neighbor low population bundle. /// - public Microsoft.ML.Trainers.Bundle Bundling { get; set; } = Microsoft.ML.Trainers.Bundle.None; + public Bundle Bundling { get; set; } = Bundle.None; /// /// Maximum number of distinct values (bins) per feature @@ -4829,7 +4857,7 @@ public sealed partial class FastTreeRanker : Microsoft.ML.Runtime.EntryPoints.Co /// /// Optimization algorithm to be used (GradientDescent, AcceleratedGradientDescent) /// - public Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType OptimizationAlgorithm { get; set; } = Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType.GradientDescent; + public BoostedTreeArgsOptimizationAlgorithmType OptimizationAlgorithm { get; set; } = BoostedTreeArgsOptimizationAlgorithmType.GradientDescent; /// /// Early stopping rule. (Validation set (/valid) is required.) @@ -4994,7 +5022,7 @@ public sealed partial class FastTreeRanker : Microsoft.ML.Runtime.EntryPoints.Co /// /// Bundle low population bins. Bundle.None(0): no bundling, Bundle.AggregateLowPopulation(1): Bundle low population, Bundle.Adjacent(2): Neighbor low population bundle. /// - public Microsoft.ML.Trainers.Bundle Bundling { get; set; } = Microsoft.ML.Trainers.Bundle.None; + public Bundle Bundling { get; set; } = Bundle.None; /// /// Maximum number of distinct values (bins) per feature @@ -5215,7 +5243,7 @@ public sealed partial class FastTreeRegressor : Microsoft.ML.Runtime.EntryPoints /// /// Optimization algorithm to be used (GradientDescent, AcceleratedGradientDescent) /// - public Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType OptimizationAlgorithm { get; set; } = Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType.GradientDescent; + public BoostedTreeArgsOptimizationAlgorithmType OptimizationAlgorithm { get; set; } = BoostedTreeArgsOptimizationAlgorithmType.GradientDescent; /// /// Early stopping rule. (Validation set (/valid) is required.) @@ -5380,7 +5408,7 @@ public sealed partial class FastTreeRegressor : Microsoft.ML.Runtime.EntryPoints /// /// Bundle low population bins. Bundle.None(0): no bundling, Bundle.AggregateLowPopulation(1): Bundle low population, Bundle.Adjacent(2): Neighbor low population bundle. /// - public Microsoft.ML.Trainers.Bundle Bundling { get; set; } = Microsoft.ML.Trainers.Bundle.None; + public Bundle Bundling { get; set; } = Bundle.None; /// /// Maximum number of distinct values (bins) per feature @@ -5606,7 +5634,7 @@ public sealed partial class FastTreeTweedieRegressor : Microsoft.ML.Runtime.Entr /// /// Optimization algorithm to be used (GradientDescent, AcceleratedGradientDescent) /// - public Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType OptimizationAlgorithm { get; set; } = Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType.GradientDescent; + public BoostedTreeArgsOptimizationAlgorithmType OptimizationAlgorithm { get; set; } = BoostedTreeArgsOptimizationAlgorithmType.GradientDescent; /// /// Early stopping rule. (Validation set (/valid) is required.) @@ -5771,7 +5799,7 @@ public sealed partial class FastTreeTweedieRegressor : Microsoft.ML.Runtime.Entr /// /// Bundle low population bins. Bundle.None(0): no bundling, Bundle.AggregateLowPopulation(1): Bundle low population, Bundle.Adjacent(2): Neighbor low population bundle. /// - public Microsoft.ML.Trainers.Bundle Bundling { get; set; } = Microsoft.ML.Trainers.Bundle.None; + public Bundle Bundling { get; set; } = Bundle.None; /// /// Maximum number of distinct values (bins) per feature @@ -6283,7 +6311,7 @@ public sealed partial class KMeansPlusPlusClusterer : Microsoft.ML.Runtime.Entry /// /// Cluster initialization algorithm /// - public Microsoft.ML.Trainers.KMeansPlusPlusTrainerInitAlgorithm InitAlgorithm { get; set; } = Microsoft.ML.Trainers.KMeansPlusPlusTrainerInitAlgorithm.KMeansParallel; + public KMeansPlusPlusTrainerInitAlgorithm InitAlgorithm { get; set; } = KMeansPlusPlusTrainerInitAlgorithm.KMeansParallel; /// /// Tolerance parameter for trainer convergence. Lower = slower, more accurate @@ -8029,7 +8057,7 @@ public void AddColumn(string name, string source) /// /// New column definition(s) (optional form: name:src) /// - public Microsoft.ML.Transforms.NormalizeTransformBinColumn[] Column { get; set; } + public NormalizeTransformBinColumn[] Column { get; set; } /// /// Max number of bins, power of 2 recommended @@ -8132,7 +8160,7 @@ public sealed partial class CategoricalHashTransformColumn : OneToOneColumn /// Output kind: Bag (multi-set vector), Ind (indicator vector), or Key (index) /// - public Microsoft.ML.Transforms.CategoricalTransformOutputKind? OutputKind { get; set; } + public CategoricalTransformOutputKind? OutputKind { get; set; } /// /// Name of the new column @@ -8196,7 +8224,7 @@ public void AddColumn(string name, string source) /// /// New column definition(s) (optional form: name:hashBits:src) /// - public Microsoft.ML.Transforms.CategoricalHashTransformColumn[] Column { get; set; } + public CategoricalHashTransformColumn[] Column { get; set; } /// /// Number of bits to hash into. Must be between 1 and 30, inclusive. @@ -8221,7 +8249,7 @@ public void AddColumn(string name, string source) /// /// Output kind: Bag (multi-set vector), Ind (indicator vector), or Key (index) /// - public Microsoft.ML.Transforms.CategoricalTransformOutputKind OutputKind { get; set; } = Microsoft.ML.Transforms.CategoricalTransformOutputKind.Bag; + public CategoricalTransformOutputKind OutputKind { get; set; } = CategoricalTransformOutputKind.Bag; /// /// Input dataset @@ -8287,7 +8315,7 @@ public sealed partial class CategoricalTransformColumn : OneToOneColumn /// Output kind: Bag (multi-set vector), Ind (indicator vector), Key (index), or Binary encoded indicator vector /// - public Microsoft.ML.Transforms.CategoricalTransformOutputKind? OutputKind { get; set; } + public CategoricalTransformOutputKind? OutputKind { get; set; } /// /// Maximum number of terms to keep when auto-training @@ -8302,7 +8330,7 @@ public sealed partial class CategoricalTransformColumn : OneToOneColumn /// How items should be ordered when vectorized. By default, they will be in the order encountered. If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a'). /// - public Microsoft.ML.Transforms.TermTransformSortOrder? Sort { get; set; } + public TermTransformSortOrder? Sort { get; set; } /// /// Whether key value metadata should be text, regardless of the actual input type @@ -8371,12 +8399,12 @@ public void AddColumn(string name, string source) /// /// New column definition(s) (optional form: name:src) /// - public Microsoft.ML.Transforms.CategoricalTransformColumn[] Column { get; set; } + public CategoricalTransformColumn[] Column { get; set; } /// /// Output kind: Bag (multi-set vector), Ind (indicator vector), or Key (index) /// - public Microsoft.ML.Transforms.CategoricalTransformOutputKind OutputKind { get; set; } = Microsoft.ML.Transforms.CategoricalTransformOutputKind.Ind; + public CategoricalTransformOutputKind OutputKind { get; set; } = CategoricalTransformOutputKind.Ind; /// /// Maximum number of terms to keep per column when auto-training @@ -8391,7 +8419,7 @@ public void AddColumn(string name, string source) /// /// How items should be ordered when vectorized. By default, they will be in the order encountered. If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a'). /// - public Microsoft.ML.Transforms.TermTransformSortOrder Sort { get; set; } = Microsoft.ML.Transforms.TermTransformSortOrder.Occurrence; + public TermTransformSortOrder Sort { get; set; } = TermTransformSortOrder.Occurrence; /// /// Whether key value metadata should be text, regardless of the actual input type @@ -8515,7 +8543,7 @@ public void AddColumn(string name, string source) /// /// New column definition(s) (optional form: name:src) /// - public Microsoft.ML.Transforms.CharTokenizeTransformColumn[] Column { get; set; } + public CharTokenizeTransformColumn[] Column { get; set; } /// /// Whether to mark the beginning/end of each row/slot with start of text character (0x02)/end of text character (0x03) @@ -8615,7 +8643,7 @@ public void AddColumn(string name, params string[] source) /// /// New column definition(s) (optional form: name:srcs) /// - public Microsoft.ML.Transforms.ConcatTransformColumn[] Column { get; set; } + public ConcatTransformColumn[] Column { get; set; } /// /// Input dataset @@ -8734,7 +8762,7 @@ public void AddColumn(string name, string source) /// /// New column definition(s) (optional form: name:src) /// - public Microsoft.ML.Transforms.CopyColumnsTransformColumn[] Column { get; set; } + public CopyColumnsTransformColumn[] Column { get; set; } /// /// Input dataset @@ -8918,41 +8946,13 @@ public ColumnSelectorPipelineStep(Output output) namespace Transforms { - public enum DataKind : byte - { - I1 = 1, - U1 = 2, - I2 = 3, - U2 = 4, - I4 = 5, - U4 = 6, - I8 = 7, - U8 = 8, - R4 = 9, - Num = 9, - R8 = 10, - TX = 11, - Text = 11, - TXT = 11, - BL = 12, - Bool = 12, - TimeSpan = 13, - TS = 13, - DT = 14, - DateTime = 14, - DZ = 15, - DateTimeZone = 15, - UG = 16, - U16 = 16 - } - public sealed partial class ConvertTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// The result type /// - public Microsoft.ML.Transforms.DataKind? ResultType { get; set; } + public Microsoft.ML.Data.DataKind? ResultType { get; set; } /// /// For a key column, this defines the range of values @@ -9021,12 +9021,12 @@ public void AddColumn(string name, string source) /// /// New column definition(s) (optional form: name:type:src) /// - public Microsoft.ML.Transforms.ConvertTransformColumn[] Column { get; set; } + public ConvertTransformColumn[] Column { get; set; } /// /// The result type /// - public Microsoft.ML.Transforms.DataKind? ResultType { get; set; } + public Microsoft.ML.Data.DataKind? ResultType { get; set; } /// /// For a key column, this defines the range of values @@ -9230,7 +9230,7 @@ public void AddColumn(string name, string source) /// /// New column definition(s) (optional form: name:src) /// - public Microsoft.ML.Transforms.NormalizeTransformAffineColumn[] Column { get; set; } + public NormalizeTransformAffineColumn[] Column { get; set; } /// /// Whether to map zero to zero, preserving sparsity @@ -9311,7 +9311,7 @@ public sealed partial class DataCache : Microsoft.ML.Runtime.EntryPoints.CommonI /// /// Caching strategy /// - public Microsoft.ML.Transforms.CacheCachingType Caching { get; set; } = Microsoft.ML.Transforms.CacheCachingType.Memory; + public CacheCachingType Caching { get; set; } = CacheCachingType.Memory; /// /// Input dataset @@ -9454,7 +9454,7 @@ public sealed partial class TermTransformColumn : OneToOneColumn /// How items should be ordered when vectorized. By default, they will be in the order encountered. If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a'). /// - public Microsoft.ML.Transforms.TermTransformSortOrder? Sort { get; set; } + public TermTransformSortOrder? Sort { get; set; } /// /// Whether key value metadata should be text, regardless of the actual input type @@ -9523,7 +9523,7 @@ public void AddColumn(string name, string source) /// /// New column definition(s) (optional form: name:src) /// - public Microsoft.ML.Transforms.TermTransformColumn[] Column { get; set; } + public TermTransformColumn[] Column { get; set; } /// /// Maximum number of terms to keep per column when auto-training @@ -9538,7 +9538,7 @@ public void AddColumn(string name, string source) /// /// How items should be ordered when vectorized. By default, they will be in the order encountered. If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a'). /// - public Microsoft.ML.Transforms.TermTransformSortOrder Sort { get; set; } = Microsoft.ML.Transforms.TermTransformSortOrder.Occurrence; + public TermTransformSortOrder Sort { get; set; } = TermTransformSortOrder.Occurrence; /// /// Whether key value metadata should be text, regardless of the actual input type @@ -9892,7 +9892,7 @@ public void AddColumn(string name, string source) /// /// New column definition(s) (optional form: name:src) /// - public Microsoft.ML.Transforms.LpNormNormalizerTransformGcnColumn[] Column { get; set; } + public LpNormNormalizerTransformGcnColumn[] Column { get; set; } /// /// Subtract mean from each value before normalizing @@ -10051,7 +10051,7 @@ public void AddColumn(string name, string source) /// /// New column definition(s) (optional form: name:src) /// - public Microsoft.ML.Transforms.HashJoinTransformColumn[] Column { get; set; } + public HashJoinTransformColumn[] Column { get; set; } /// /// Whether the values need to be combined for a single hash @@ -10190,7 +10190,7 @@ public void AddColumn(string name, string source) /// /// New column definition(s) (optional form: name:src) /// - public Microsoft.ML.Transforms.KeyToValueTransformColumn[] Column { get; set; } + public KeyToValueTransformColumn[] Column { get; set; } /// /// Input dataset @@ -10384,7 +10384,7 @@ public void AddColumn(string name, string source) /// /// New column definition(s) (optional form: name:src) /// - public Microsoft.ML.Transforms.LabelIndicatorTransformColumn[] Column { get; set; } + public LabelIndicatorTransformColumn[] Column { get; set; } /// /// Label of the positive class. @@ -10583,7 +10583,7 @@ public void AddColumn(string name, string source) /// /// New column definition(s) (optional form: name:src) /// - public Microsoft.ML.Transforms.NormalizeTransformLogNormalColumn[] Column { get; set; } + public NormalizeTransformLogNormalColumn[] Column { get; set; } /// /// Max number of examples used to train the normalizer @@ -10656,7 +10656,7 @@ public sealed partial class LpNormNormalizerTransformColumn : OneToOneColumn /// The norm to use to normalize each sample /// - public Microsoft.ML.Transforms.LpNormNormalizerTransformNormalizerKind? NormKind { get; set; } + public LpNormNormalizerTransformNormalizerKind? NormKind { get; set; } /// /// Subtract mean from each value before normalizing @@ -10725,12 +10725,12 @@ public void AddColumn(string name, string source) /// /// New column definition(s) (optional form: name:src) /// - public Microsoft.ML.Transforms.LpNormNormalizerTransformColumn[] Column { get; set; } + public LpNormNormalizerTransformColumn[] Column { get; set; } /// /// The norm to use to normalize each sample /// - public Microsoft.ML.Transforms.LpNormNormalizerTransformNormalizerKind NormKind { get; set; } = Microsoft.ML.Transforms.LpNormNormalizerTransformNormalizerKind.L2Norm; + public LpNormNormalizerTransformNormalizerKind NormKind { get; set; } = LpNormNormalizerTransformNormalizerKind.L2Norm; /// /// Subtract mean from each value before normalizing @@ -10877,7 +10877,7 @@ public void AddColumn(string name, string source) /// /// New column definition(s) (optional form: name:src) /// - public Microsoft.ML.Transforms.NormalizeTransformAffineColumn[] Column { get; set; } + public NormalizeTransformAffineColumn[] Column { get; set; } /// /// Whether to map zero to zero, preserving sparsity @@ -10992,7 +10992,7 @@ public void AddColumn(string name, string source) /// /// New column definition(s) (optional form: name:src) /// - public Microsoft.ML.Transforms.NormalizeTransformAffineColumn[] Column { get; set; } + public NormalizeTransformAffineColumn[] Column { get; set; } /// /// Whether to map zero to zero, preserving sparsity @@ -11074,7 +11074,7 @@ public sealed partial class NAHandleTransformColumn : OneToOneColumn /// The replacement method to utilize /// - public Microsoft.ML.Transforms.NAHandleTransformReplacementKind? Kind { get; set; } + public NAHandleTransformReplacementKind? Kind { get; set; } /// /// Whether to impute values by slot @@ -11148,12 +11148,12 @@ public void AddColumn(string name, string source) /// /// New column definition(s) (optional form: name:rep:src) /// - public Microsoft.ML.Transforms.NAHandleTransformColumn[] Column { get; set; } + public NAHandleTransformColumn[] Column { get; set; } /// /// The replacement method to utilize /// - public Microsoft.ML.Transforms.NAHandleTransformReplacementKind ReplaceWith { get; set; } = Microsoft.ML.Transforms.NAHandleTransformReplacementKind.Def; + public NAHandleTransformReplacementKind ReplaceWith { get; set; } = NAHandleTransformReplacementKind.Def; /// /// Whether to impute values by slot @@ -11282,7 +11282,7 @@ public void AddColumn(string name, string source) /// /// New column definition(s) (optional form: name:src) /// - public Microsoft.ML.Transforms.NAIndicatorTransformColumn[] Column { get; set; } + public NAIndicatorTransformColumn[] Column { get; set; } /// /// Input dataset @@ -11401,7 +11401,7 @@ public void AddColumn(string name, string source) /// /// Columns to drop the NAs for /// - public Microsoft.ML.Transforms.NADropTransformColumn[] Column { get; set; } + public NADropTransformColumn[] Column { get; set; } /// /// Input dataset @@ -11551,7 +11551,7 @@ public sealed partial class NAReplaceTransformColumn : OneToOneColumn /// The replacement method to utilize /// - public Microsoft.ML.Transforms.NAReplaceTransformReplacementKind? Kind { get; set; } + public NAReplaceTransformReplacementKind? Kind { get; set; } /// /// Whether to impute values by slot @@ -11620,12 +11620,12 @@ public void AddColumn(string name, string source) /// /// New column definition(s) (optional form: name:rep:src) /// - public Microsoft.ML.Transforms.NAReplaceTransformColumn[] Column { get; set; } + public NAReplaceTransformColumn[] Column { get; set; } /// /// The replacement method to utilize /// - public Microsoft.ML.Transforms.NAReplaceTransformReplacementKind ReplacementKind { get; set; } = Microsoft.ML.Transforms.NAReplaceTransformReplacementKind.Def; + public NAReplaceTransformReplacementKind ReplacementKind { get; set; } = NAReplaceTransformReplacementKind.Def; /// /// Whether to impute values by slot @@ -11744,7 +11744,7 @@ public sealed partial class NgramTransformColumn : OneToOneColumn /// Statistical measure used to evaluate how important a word is to a document in a corpus /// - public Microsoft.ML.Transforms.NgramTransformWeightingCriteria? Weighting { get; set; } + public NgramTransformWeightingCriteria? Weighting { get; set; } /// /// Name of the new column @@ -11808,7 +11808,7 @@ public void AddColumn(string name, string source) /// /// New column definition(s) (optional form: name:src) /// - public Microsoft.ML.Transforms.NgramTransformColumn[] Column { get; set; } + public NgramTransformColumn[] Column { get; set; } /// /// Maximum ngram length @@ -11833,7 +11833,7 @@ public void AddColumn(string name, string source) /// /// The weighting criteria /// - public Microsoft.ML.Transforms.NgramTransformWeightingCriteria Weighting { get; set; } = Microsoft.ML.Transforms.NgramTransformWeightingCriteria.Tf; + public NgramTransformWeightingCriteria Weighting { get; set; } = NgramTransformWeightingCriteria.Tf; /// /// Input dataset @@ -12102,7 +12102,7 @@ public void AddColumn(string name, string source) /// /// New column definition(s) (optional form: name:src) /// - public Microsoft.ML.Transforms.PcaTransformColumn[] Column { get; set; } + public PcaTransformColumn[] Column { get; set; } /// /// The name of the weight column @@ -12276,7 +12276,7 @@ public sealed partial class RandomNumberGenerator : Microsoft.ML.Runtime.EntryPo /// /// New column definition(s) (optional form: name:seed) /// - public Microsoft.ML.Transforms.GenerateNumberTransformColumn[] Column { get; set; } + public GenerateNumberTransformColumn[] Column { get; set; } /// /// Use an auto-incremented integer starting at zero instead of a random number @@ -12750,7 +12750,7 @@ public sealed partial class Segregator : Microsoft.ML.Runtime.EntryPoints.Common /// /// Specifies how to unroll multiple pivot columns of different size. /// - public Microsoft.ML.Transforms.UngroupTransformUngroupMode Mode { get; set; } = Microsoft.ML.Transforms.UngroupTransformUngroupMode.Inner; + public UngroupTransformUngroupMode Mode { get; set; } = UngroupTransformUngroupMode.Inner; /// /// Input dataset @@ -12935,7 +12935,7 @@ public void AddColumn(string name, string source) /// /// New column definition(s) (optional form: name:src) /// - public Microsoft.ML.Transforms.NormalizeTransformBinColumn[] Column { get; set; } + public NormalizeTransformBinColumn[] Column { get; set; } /// /// Max number of bins, power of 2 recommended @@ -13055,7 +13055,7 @@ public sealed partial class TermLoaderArguments /// /// How items should be ordered when vectorized. By default, they will be in the order encountered. If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a'). /// - public Microsoft.ML.Transforms.TermTransformSortOrder Sort { get; set; } = Microsoft.ML.Transforms.TermTransformSortOrder.Occurrence; + public TermTransformSortOrder Sort { get; set; } = TermTransformSortOrder.Occurrence; /// /// Drop unknown terms instead of mapping them to NA term. @@ -13088,12 +13088,12 @@ public void AddColumn(string name, params string[] source) /// /// New column definition (optional form: name:srcs). /// - public Microsoft.ML.Transforms.TextTransformColumn Column { get; set; } + public TextTransformColumn Column { get; set; } /// /// Dataset language or 'AutoDetect' to detect language per row. /// - public Microsoft.ML.Transforms.TextTransformLanguage Language { get; set; } = Microsoft.ML.Transforms.TextTransformLanguage.English; + public TextTransformLanguage Language { get; set; } = TextTransformLanguage.English; /// /// Stopwords remover. @@ -13104,7 +13104,7 @@ public void AddColumn(string name, params string[] source) /// /// Casing text using the rules of the invariant culture. /// - public Microsoft.ML.Transforms.TextNormalizerTransformCaseNormalizationMode TextCase { get; set; } = Microsoft.ML.Transforms.TextNormalizerTransformCaseNormalizationMode.Lower; + public TextNormalizerTransformCaseNormalizationMode TextCase { get; set; } = TextNormalizerTransformCaseNormalizationMode.Lower; /// /// Whether to keep diacritical marks or remove them. @@ -13129,7 +13129,7 @@ public void AddColumn(string name, params string[] source) /// /// A dictionary of whitelisted terms. /// - public Microsoft.ML.Transforms.TermLoaderArguments Dictionary { get; set; } + public TermLoaderArguments Dictionary { get; set; } /// /// Ngram feature extractor to use for words (WordBag/WordHashBag). @@ -13146,7 +13146,7 @@ public void AddColumn(string name, params string[] source) /// /// Normalize vectors (rows) individually by rescaling them to unit norm. /// - public Microsoft.ML.Transforms.TextTransformTextNormKind VectorNormalizer { get; set; } = Microsoft.ML.Transforms.TextTransformTextNormKind.L2; + public TextTransformTextNormKind VectorNormalizer { get; set; } = TextTransformTextNormKind.L2; /// /// Input dataset @@ -13251,7 +13251,7 @@ public void AddColumn(string name, string source) /// /// New column definition(s) (optional form: name:src) /// - public Microsoft.ML.Transforms.TermTransformColumn[] Column { get; set; } + public TermTransformColumn[] Column { get; set; } /// /// Maximum number of terms to keep per column when auto-training @@ -13266,7 +13266,7 @@ public void AddColumn(string name, string source) /// /// How items should be ordered when vectorized. By default, they will be in the order encountered. If by value items are sorted according to their default comparison, e.g., text sorting will be case sensitive (e.g., 'A' then 'Z' then 'a'). /// - public Microsoft.ML.Transforms.TermTransformSortOrder Sort { get; set; } = Microsoft.ML.Transforms.TermTransformSortOrder.Occurrence; + public TermTransformSortOrder Sort { get; set; } = TermTransformSortOrder.Occurrence; /// /// Whether key value metadata should be text, regardless of the actual input type @@ -13544,7 +13544,7 @@ public void AddColumn(string name, string source) /// /// New column definition(s) /// - public Microsoft.ML.Transforms.DelimitedTokenizeTransformColumn[] Column { get; set; } + public DelimitedTokenizeTransformColumn[] Column { get; set; } /// /// Comma separated set of term separator(s). Commonly: 'space', 'comma', 'semicolon' or other single character. @@ -13699,7 +13699,7 @@ public sealed class AutoMlStateAutoMlStateBase : AutoMlStateBase /// /// Supported metric for evaluator. /// - public Microsoft.ML.Runtime.AutoInferenceAutoMlMlStateArgumentsMetrics Metric { get; set; } = Microsoft.ML.Runtime.AutoInferenceAutoMlMlStateArgumentsMetrics.Auc; + public AutoInferenceAutoMlMlStateArgumentsMetrics Metric { get; set; } = AutoInferenceAutoMlMlStateArgumentsMetrics.Auc; /// /// AutoML engine (pipeline optimizer) that generates next candidates. @@ -13710,7 +13710,7 @@ public sealed class AutoMlStateAutoMlStateBase : AutoMlStateBase /// /// Kind of trainer for task, such as binary classification trainer, multiclass trainer, etc. /// - public Microsoft.ML.Models.MacroUtilsTrainerKinds TrainerKind { get; set; } = Microsoft.ML.Models.MacroUtilsTrainerKinds.SignatureBinaryClassifierTrainer; + public Microsoft.ML.Models.MacroUtilsTrainerKinds TrainerKind { get; set; } = Microsoft.ML.Models.MacroUtilsTrainerKinds.SignatureBinaryClassifierTrainer; /// /// Arguments for creating terminator, which determines when to stop search. @@ -13730,9 +13730,6 @@ public abstract class CalibratorTrainer : ComponentKind {} - /// - /// - /// public sealed class FixedPlattCalibratorCalibratorTrainer : CalibratorTrainer { /// @@ -13750,9 +13747,6 @@ public sealed class FixedPlattCalibratorCalibratorTrainer : CalibratorTrainer - /// - /// - /// public sealed class NaiveCalibratorCalibratorTrainer : CalibratorTrainer { internal override string ComponentName => "NaiveCalibrator"; @@ -13760,9 +13754,6 @@ public sealed class NaiveCalibratorCalibratorTrainer : CalibratorTrainer - /// - /// - /// public sealed class PavCalibratorCalibratorTrainer : CalibratorTrainer { internal override string ComponentName => "PavCalibrator"; @@ -13966,7 +13957,7 @@ public sealed class FastTreeBinaryClassificationFastTreeTrainer : FastTreeTraine /// /// Optimization algorithm to be used (GradientDescent, AcceleratedGradientDescent) /// - public Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType OptimizationAlgorithm { get; set; } = Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType.GradientDescent; + public Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType OptimizationAlgorithm { get; set; } = Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType.GradientDescent; /// /// Early stopping rule. (Validation set (/valid) is required.) @@ -14131,7 +14122,7 @@ public sealed class FastTreeBinaryClassificationFastTreeTrainer : FastTreeTraine /// /// Bundle low population bins. Bundle.None(0): no bundling, Bundle.AggregateLowPopulation(1): Bundle low population, Bundle.Adjacent(2): Neighbor low population bundle. /// - public Microsoft.ML.Trainers.Bundle Bundling { get; set; } = Microsoft.ML.Trainers.Bundle.None; + public Microsoft.ML.Trainers.Bundle Bundling { get; set; } = Microsoft.ML.Trainers.Bundle.None; /// /// Maximum number of distinct values (bins) per feature @@ -14274,12 +14265,12 @@ public sealed class FastTreeBinaryClassificationFastTreeTrainer : FastTreeTraine /// /// Normalize option for the feature column /// - public Microsoft.ML.Models.NormalizeOption NormalizeFeatures { get; set; } = Microsoft.ML.Models.NormalizeOption.Auto; + public Microsoft.ML.Models.NormalizeOption NormalizeFeatures { get; set; } = Microsoft.ML.Models.NormalizeOption.Auto; /// /// Whether learner should cache input training data /// - public Microsoft.ML.Models.CachingOptions Caching { get; set; } = Microsoft.ML.Models.CachingOptions.Auto; + public Microsoft.ML.Models.CachingOptions Caching { get; set; } = Microsoft.ML.Models.CachingOptions.Auto; internal override string ComponentName => "FastTreeBinaryClassification"; } @@ -14354,7 +14345,7 @@ public sealed class FastTreeRankingFastTreeTrainer : FastTreeTrainer /// /// Optimization algorithm to be used (GradientDescent, AcceleratedGradientDescent) /// - public Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType OptimizationAlgorithm { get; set; } = Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType.GradientDescent; + public Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType OptimizationAlgorithm { get; set; } = Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType.GradientDescent; /// /// Early stopping rule. (Validation set (/valid) is required.) @@ -14519,7 +14510,7 @@ public sealed class FastTreeRankingFastTreeTrainer : FastTreeTrainer /// /// Bundle low population bins. Bundle.None(0): no bundling, Bundle.AggregateLowPopulation(1): Bundle low population, Bundle.Adjacent(2): Neighbor low population bundle. /// - public Microsoft.ML.Trainers.Bundle Bundling { get; set; } = Microsoft.ML.Trainers.Bundle.None; + public Microsoft.ML.Trainers.Bundle Bundling { get; set; } = Microsoft.ML.Trainers.Bundle.None; /// /// Maximum number of distinct values (bins) per feature @@ -14662,12 +14653,12 @@ public sealed class FastTreeRankingFastTreeTrainer : FastTreeTrainer /// /// Normalize option for the feature column /// - public Microsoft.ML.Models.NormalizeOption NormalizeFeatures { get; set; } = Microsoft.ML.Models.NormalizeOption.Auto; + public Microsoft.ML.Models.NormalizeOption NormalizeFeatures { get; set; } = Microsoft.ML.Models.NormalizeOption.Auto; /// /// Whether learner should cache input training data /// - public Microsoft.ML.Models.CachingOptions Caching { get; set; } = Microsoft.ML.Models.CachingOptions.Auto; + public Microsoft.ML.Models.CachingOptions Caching { get; set; } = Microsoft.ML.Models.CachingOptions.Auto; internal override string ComponentName => "FastTreeRanking"; } @@ -14702,7 +14693,7 @@ public sealed class FastTreeRegressionFastTreeTrainer : FastTreeTrainer /// /// Optimization algorithm to be used (GradientDescent, AcceleratedGradientDescent) /// - public Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType OptimizationAlgorithm { get; set; } = Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType.GradientDescent; + public Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType OptimizationAlgorithm { get; set; } = Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType.GradientDescent; /// /// Early stopping rule. (Validation set (/valid) is required.) @@ -14867,7 +14858,7 @@ public sealed class FastTreeRegressionFastTreeTrainer : FastTreeTrainer /// /// Bundle low population bins. Bundle.None(0): no bundling, Bundle.AggregateLowPopulation(1): Bundle low population, Bundle.Adjacent(2): Neighbor low population bundle. /// - public Microsoft.ML.Trainers.Bundle Bundling { get; set; } = Microsoft.ML.Trainers.Bundle.None; + public Microsoft.ML.Trainers.Bundle Bundling { get; set; } = Microsoft.ML.Trainers.Bundle.None; /// /// Maximum number of distinct values (bins) per feature @@ -15010,12 +15001,12 @@ public sealed class FastTreeRegressionFastTreeTrainer : FastTreeTrainer /// /// Normalize option for the feature column /// - public Microsoft.ML.Models.NormalizeOption NormalizeFeatures { get; set; } = Microsoft.ML.Models.NormalizeOption.Auto; + public Microsoft.ML.Models.NormalizeOption NormalizeFeatures { get; set; } = Microsoft.ML.Models.NormalizeOption.Auto; /// /// Whether learner should cache input training data /// - public Microsoft.ML.Models.CachingOptions Caching { get; set; } = Microsoft.ML.Models.CachingOptions.Auto; + public Microsoft.ML.Models.CachingOptions Caching { get; set; } = Microsoft.ML.Models.CachingOptions.Auto; internal override string ComponentName => "FastTreeRegression"; } @@ -15055,7 +15046,7 @@ public sealed class FastTreeTweedieRegressionFastTreeTrainer : FastTreeTrainer /// /// Optimization algorithm to be used (GradientDescent, AcceleratedGradientDescent) /// - public Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType OptimizationAlgorithm { get; set; } = Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType.GradientDescent; + public Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType OptimizationAlgorithm { get; set; } = Microsoft.ML.Trainers.BoostedTreeArgsOptimizationAlgorithmType.GradientDescent; /// /// Early stopping rule. (Validation set (/valid) is required.) @@ -15220,7 +15211,7 @@ public sealed class FastTreeTweedieRegressionFastTreeTrainer : FastTreeTrainer /// /// Bundle low population bins. Bundle.None(0): no bundling, Bundle.AggregateLowPopulation(1): Bundle low population, Bundle.Adjacent(2): Neighbor low population bundle. /// - public Microsoft.ML.Trainers.Bundle Bundling { get; set; } = Microsoft.ML.Trainers.Bundle.None; + public Microsoft.ML.Trainers.Bundle Bundling { get; set; } = Microsoft.ML.Trainers.Bundle.None; /// /// Maximum number of distinct values (bins) per feature @@ -15363,12 +15354,12 @@ public sealed class FastTreeTweedieRegressionFastTreeTrainer : FastTreeTrainer /// /// Normalize option for the feature column /// - public Microsoft.ML.Models.NormalizeOption NormalizeFeatures { get; set; } = Microsoft.ML.Models.NormalizeOption.Auto; + public Microsoft.ML.Models.NormalizeOption NormalizeFeatures { get; set; } = Microsoft.ML.Models.NormalizeOption.Auto; /// /// Whether learner should cache input training data /// - public Microsoft.ML.Models.CachingOptions Caching { get; set; } = Microsoft.ML.Models.CachingOptions.Auto; + public Microsoft.ML.Models.CachingOptions Caching { get; set; } = Microsoft.ML.Models.CachingOptions.Auto; internal override string ComponentName => "FastTreeTweedieRegression"; } @@ -15405,7 +15396,7 @@ public sealed class NGramNgramExtractor : NgramExtractor /// /// The weighting criteria /// - public Microsoft.ML.Transforms.NgramTransformWeightingCriteria Weighting { get; set; } = Microsoft.ML.Transforms.NgramTransformWeightingCriteria.Tf; + public Microsoft.ML.Transforms.NgramTransformWeightingCriteria Weighting { get; set; } = Microsoft.ML.Transforms.NgramTransformWeightingCriteria.Tf; internal override string ComponentName => "NGram"; } @@ -15490,7 +15481,7 @@ public sealed partial class PartitionedFileLoaderColumn /// /// Data type of the column. /// - public Microsoft.ML.Transforms.DataKind? Type { get; set; } + public Microsoft.ML.Data.DataKind? Type { get; set; } /// /// Index of the directory representing this column. @@ -15508,12 +15499,12 @@ public sealed class SimplePathParserPartitionedPathParser : PartitionedPathParse /// /// Column definitions used to override the Partitioned Path Parser. Expected with the format name:type:numeric-source, e.g. col=MyFeature:R4:1 /// - public Microsoft.ML.Runtime.PartitionedFileLoaderColumn[] Columns { get; set; } + public PartitionedFileLoaderColumn[] Columns { get; set; } /// /// Data type of each column. /// - public Microsoft.ML.Transforms.DataKind Type { get; set; } = Microsoft.ML.Transforms.DataKind.TX; + public Microsoft.ML.Data.DataKind Type { get; set; } = Microsoft.ML.Data.DataKind.TX; internal override string ComponentName => "SimplePathParser"; } diff --git a/src/Microsoft.ML/Runtime/EntryPoints/JsonUtils/JsonManifestUtils.cs b/src/Microsoft.ML/Runtime/EntryPoints/JsonUtils/JsonManifestUtils.cs index 289adc6f75..7950975a22 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/JsonUtils/JsonManifestUtils.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/JsonUtils/JsonManifestUtils.cs @@ -7,6 +7,7 @@ using System.Linq; using System.Reflection; using Microsoft.ML.Runtime.CommandLine; +using Microsoft.ML.Runtime.Internal.Tools; using Microsoft.ML.Runtime.Internal.Utilities; using Newtonsoft.Json.Linq; @@ -67,13 +68,7 @@ public static JObject BuildAllManifests(IExceptionContext ectx, ModuleCatalog ca { var jField = new JObject(); jField[FieldNames.Name] = fieldInfo.Name; - var type = fieldInfo.PropertyType; - // Dive inside Optional. - if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Optional<>)) - type = type.GetGenericArguments()[0]; - // Dive inside Nullable. - if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)) - type = type.GetGenericArguments()[0]; + var type = CSharpGeneratorUtils.ExtractOptionalOrNullableType(fieldInfo.PropertyType); // Dive inside Var. if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Var<>)) type = type.GetGenericArguments()[0]; @@ -308,14 +303,7 @@ private static JToken BuildTypeToken(IExceptionContext ectx, FieldInfo fieldInfo jo[FieldNames.ItemType] = typeString; return jo; } - - // Dive inside Optional. - if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Optional<>)) - type = type.GetGenericArguments()[0]; - - // Dive inside Nullable. - if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)) - type = type.GetGenericArguments()[0]; + type = CSharpGeneratorUtils.ExtractOptionalOrNullableType(type); // Dive inside Var. if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Var<>)) diff --git a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs index 234c87fade..b1f344873d 100644 --- a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs +++ b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs @@ -3,11 +3,9 @@ // See the LICENSE file in the project root for more information. using System; -using System.CodeDom; using System.Collections.Generic; using System.IO; using System.Linq; -using Microsoft.CSharp; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -34,381 +32,13 @@ public sealed class Arguments public string[] Exclude; } - private static class GeneratorUtils - { - public static string GetFullMethodName(ModuleCatalog.EntryPointInfo entryPointInfo) - { - return entryPointInfo.Name; - } - - public static Tuple GetClassAndMethodNames(ModuleCatalog.EntryPointInfo entryPointInfo) - { - var split = entryPointInfo.Name.Split('.'); - Contracts.Assert(split.Length == 2); - return new Tuple(split[0], split[1]); - } - - public static string GetCSharpTypeName(Type type) - { - if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)) - return GetCSharpTypeName(type.GetGenericArguments()[0]) + "?"; - - string name; - using (var p = new CSharpCodeProvider()) - name = p.GetTypeOutput(new CodeTypeReference(type)); - return name; - } - - public static string GetOutputType(Type outputType) - { - Contracts.Check(Var.CheckType(outputType)); - - if (outputType.IsArray) - return $"ArrayVar<{GetCSharpTypeName(outputType.GetElementType())}>"; - if (outputType.IsGenericType && outputType.GetGenericTypeDefinition() == typeof(Dictionary<,>) - && outputType.GetGenericTypeArgumentsEx()[0] == typeof(string)) - { - return $"DictionaryVar<{GetCSharpTypeName(outputType.GetGenericTypeArgumentsEx()[1])}>"; - } - - return $"Var<{GetCSharpTypeName(outputType)}>"; - } - - public static string GetInputType(ModuleCatalog catalog, Type inputType, - Dictionary typesSymbolTable, string rootNameSpace = "") - { - if (inputType.IsGenericType && inputType.GetGenericTypeDefinition() == typeof(Var<>)) - return $"Var<{GetCSharpTypeName(inputType.GetGenericTypeArgumentsEx()[0])}>"; - - if (inputType.IsArray && Var.CheckType(inputType.GetElementType())) - return $"ArrayVar<{GetCSharpTypeName(inputType.GetElementType())}>"; - - if (inputType.IsGenericType && inputType.GetGenericTypeDefinition() == typeof(Dictionary<,>) - && inputType.GetGenericTypeArgumentsEx()[0] == typeof(string)) - { - return $"DictionaryVar<{GetCSharpTypeName(inputType.GetGenericTypeArgumentsEx()[1])}>"; - } - - if (Var.CheckType(inputType)) - return $"Var<{GetCSharpTypeName(inputType)}>"; - - bool isNullable = false; - bool isOptional = false; - var type = inputType; - if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)) - { - type = type.GetGenericArguments()[0]; - isNullable = true; - } - else if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Optional<>)) - { - type = type.GetGenericArguments()[0]; - isOptional = true; - } - - var typeEnum = TlcModule.GetDataType(type); - switch (typeEnum) - { - case TlcModule.DataKind.Float: - case TlcModule.DataKind.Int: - case TlcModule.DataKind.UInt: - case TlcModule.DataKind.Char: - case TlcModule.DataKind.String: - case TlcModule.DataKind.Bool: - case TlcModule.DataKind.DataView: - case TlcModule.DataKind.TransformModel: - case TlcModule.DataKind.PredictorModel: - case TlcModule.DataKind.FileHandle: - return GetCSharpTypeName(inputType); - case TlcModule.DataKind.Array: - return GetInputType(catalog, inputType.GetElementType(), typesSymbolTable) + "[]"; - case TlcModule.DataKind.Component: - string kind; - bool success = catalog.TryGetComponentKind(type, out kind); - Contracts.Assert(success); - return $"{kind}"; - case TlcModule.DataKind.Enum: - var enumName = GetEnumName(type, typesSymbolTable, rootNameSpace); - if (isNullable) - return $"{enumName}?"; - if (isOptional) - return $"Optional<{enumName}>"; - return $"{enumName}"; - default: - if (isNullable) - return rootNameSpace + typesSymbolTable[type.FullName]; - if (isOptional) - return $"Optional<{rootNameSpace + typesSymbolTable[type.FullName]}>"; - if (typesSymbolTable.ContainsKey(type.FullName)) - return rootNameSpace + typesSymbolTable[type.FullName]; - else - return GetSymbolFromType(typesSymbolTable, type, rootNameSpace); - } - } - - public static bool IsComponent(Type inputType) - { - if (inputType.IsArray && Var.CheckType(inputType.GetElementType())) - return false; - - if (inputType.IsGenericType && inputType.GetGenericTypeDefinition() == typeof(Dictionary<,>) - && inputType.GetGenericTypeArgumentsEx()[0] == typeof(string)) - { - return false; - } - - if (Var.CheckType(inputType)) - return false; - - var type = inputType; - if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)) - type = type.GetGenericArguments()[0]; - else if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Optional<>)) - type = type.GetGenericArguments()[0]; - - var typeEnum = TlcModule.GetDataType(type); - return typeEnum == TlcModule.DataKind.Component; - } - - public static string Capitalize(string s) - { - if (string.IsNullOrEmpty(s)) - return s; - return char.ToUpperInvariant(s[0]) + s.Substring(1); - } - - private static string GetCharAsString(char value) - { - switch (value) - { - case '\t': - return "\\t"; - case '\n': - return "\\n"; - case '\r': - return "\\r"; - case '\\': - return "\\"; - case '\"': - return "\""; - case '\'': - return "\\'"; - case '\0': - return "\\0"; - case '\a': - return "\\a"; - case '\b': - return "\\b"; - case '\f': - return "\\f"; - case '\v': - return "\\v"; - default: - return value.ToString(); - } - } - - public static string GetValue(ModuleCatalog catalog, Type fieldType, object fieldValue, - Dictionary typesSymbolTable, string rootNameSpace = "") - { - if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Var<>)) - return $"new Var<{GetCSharpTypeName(fieldType.GetGenericTypeArgumentsEx()[0])}>()"; - - if (fieldType.IsArray && Var.CheckType(fieldType.GetElementType())) - return $"new ArrayVar<{GetCSharpTypeName(fieldType.GetElementType())}>()"; - - if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Dictionary<,>) - && fieldType.GetGenericTypeArgumentsEx()[0] == typeof(string)) - { - return $"new DictionaryVar<{GetCSharpTypeName(fieldType.GetGenericTypeArgumentsEx()[1])}>()"; - } - - if (Var.CheckType(fieldType)) - return $"new Var<{GetCSharpTypeName(fieldType)}>()"; - - if (fieldValue == null) - return null; - - if (!fieldType.IsInterface) - { - try - { - var defaultFieldValue = Activator.CreateInstance(fieldType); - if (defaultFieldValue == fieldValue) - return null; - } - catch (MissingMethodException) - { - // No parameterless constructor, ignore. - } - } - - var typeEnum = TlcModule.GetDataType(fieldType); - if (fieldType.IsGenericType && (fieldType.GetGenericTypeDefinition() == typeof(Optional<>) || fieldType.GetGenericTypeDefinition() == typeof(Nullable<>))) - fieldType = fieldType.GetGenericArguments()[0]; - switch (typeEnum) - { - case TlcModule.DataKind.Array: - var arr = fieldValue as Array; - if (arr != null && arr.GetLength(0) > 0) - return $"{{ {string.Join(", ", arr.Cast().Select(item => GetValue(catalog, fieldType.GetElementType(), item, typesSymbolTable)))} }}"; - return null; - case TlcModule.DataKind.String: - var strval = fieldValue as string; - if (strval != null) - return Quote(strval); - return null; - case TlcModule.DataKind.Float: - if (fieldValue is double d) - { - if (double.IsPositiveInfinity(d)) - return "double.PositiveInfinity"; - if (double.IsNegativeInfinity(d)) - return "double.NegativeInfinity"; - if (d != 0) - return d.ToString("R") + "d"; - } - else if (fieldValue is float f) - { - if (float.IsPositiveInfinity(f)) - return "float.PositiveInfinity"; - if (float.IsNegativeInfinity(f)) - return "float.NegativeInfinity"; - if (f != 0) - return f.ToString("R") + "f"; - } - return null; - case TlcModule.DataKind.Int: - if (fieldValue is int i) - { - if (i != 0) - return i.ToString(); - } - else if (fieldValue is long l) - { - if (l != 0) - return l.ToString(); - } - return null; - case TlcModule.DataKind.Bool: - return (bool)fieldValue ? "true" : "false"; - case TlcModule.DataKind.Enum: - return GetEnumName(fieldType, typesSymbolTable, rootNameSpace) + "." + fieldValue; - case TlcModule.DataKind.Char: - return $"'{GetCharAsString((char)fieldValue)}'"; - case TlcModule.DataKind.Component: - var type = fieldValue.GetType(); - ModuleCatalog.ComponentInfo componentInfo; - if (!catalog.TryFindComponent(fieldType, type, out componentInfo)) - return null; - object defaultComponent = null; - try - { - defaultComponent = Activator.CreateInstance(componentInfo.ArgumentType); - } - catch (MissingMethodException) - { - // No parameterless constructor, ignore. - } - var propertyBag = new List(); - if (defaultComponent != null) - { - foreach (var fieldInfo in componentInfo.ArgumentType.GetFields()) - { - var inputAttr = fieldInfo.GetCustomAttributes(typeof(ArgumentAttribute), false).FirstOrDefault() as ArgumentAttribute; - if (inputAttr == null || inputAttr.Visibility == ArgumentAttribute.VisibilityType.CmdLineOnly) - continue; - if (fieldInfo.FieldType == typeof(JArray) || fieldInfo.FieldType == typeof(JObject)) - continue; - - var propertyValue = GetValue(catalog, fieldInfo.FieldType, fieldInfo.GetValue(fieldValue), typesSymbolTable); - var defaultPropertyValue = GetValue(catalog, fieldInfo.FieldType, fieldInfo.GetValue(defaultComponent), typesSymbolTable); - if (propertyValue != defaultPropertyValue) - propertyBag.Add($"{GeneratorUtils.Capitalize(inputAttr.Name ?? fieldInfo.Name)} = {propertyValue}"); - } - } - var properties = propertyBag.Count > 0 ? $" {{ {string.Join(", ", propertyBag)} }}" : ""; - return $"new {GetComponentName(componentInfo)}(){properties}"; - case TlcModule.DataKind.Unknown: - return $"new {rootNameSpace + typesSymbolTable[fieldType.FullName]}()"; - default: - return fieldValue.ToString(); - } - } - - private static string Quote(string src) - { - var dst = src.Replace("\\", @"\\").Replace("\"", "\\\"").Replace("\n", @"\n").Replace("\r", @"\r"); - return "\"" + dst + "\""; - } - - public static string GetComponentName(ModuleCatalog.ComponentInfo component) - { - return $"{Capitalize(component.Name)}{component.Kind}"; - } - - public static string GetEnumName(Type type, Dictionary typesSymbolTable, string rootNamespace = "") - { - if (typesSymbolTable.ContainsKey(type.FullName)) - return rootNamespace + typesSymbolTable[type.FullName]; - else - return GetSymbolFromType(typesSymbolTable, type, rootNamespace); - } - - public static string GetJsonFromField(string fieldName, Type fieldType) - { - if (fieldType.IsArray && Var.CheckType(fieldType.GetElementType())) - return $"{{({fieldName}.IsValue ? {fieldName}.VarName : $\"'${{{fieldName}.VarName}}'\")}}"; - if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Dictionary<,>) - && fieldType.GetGenericTypeArgumentsEx()[0] == typeof(string)) - { - return $"'${{{fieldName}.VarName}}'"; - } - if (Var.CheckType(fieldType)) - return $"'${{{fieldName}.VarName}}'"; - - var isNullable = false; - var type = fieldType; - if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)) - { - type = type.GetGenericArguments()[0]; - isNullable = true; - } - if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Optional<>)) - type = type.GetGenericArguments()[0]; - - var typeEnum = TlcModule.GetDataType(type); - switch (typeEnum) - { - default: - if (isNullable) - return $"{{(!{fieldName}.HasValue ? \"null\" : $\"{{{fieldName}.Value}}\")}}"; - return $"{{{fieldName}}}"; - case TlcModule.DataKind.Enum: - if (isNullable) - return $"{{(!{fieldName}.HasValue ? \"null\" : $\"'{{{fieldName}.Value}}'\")}}"; - return $"'{{{fieldName}}}'"; - case TlcModule.DataKind.String: - return $"{{({fieldName} == null ? \"null\" : $\"'{{{fieldName}}}'\")}}"; - case TlcModule.DataKind.Bool: - if (isNullable) - return $"{{(!{fieldName}.HasValue ? \"null\" : {fieldName}.Value ? \"true\" : \"false\")}}"; - return $"'{{({fieldName} ? \"true\" : \"false\")}}'"; - case TlcModule.DataKind.Component: - case TlcModule.DataKind.Unknown: - return $"{{({fieldName} == null ? \"null\" : {fieldName}.ToJson())}}"; - case TlcModule.DataKind.Array: - return $"[{{({fieldName} == null ? \"\" : string.Join(\",\", {fieldName}.Select(f => $\"{GetJsonFromField("f", type.GetElementType())}\")))}}]"; - } - } - } - private readonly IHost _host; private readonly string _csFilename; private readonly string _regenerate; private readonly HashSet _excludedSet; private const string RegistrationName = "CSharpApiGenerator"; - public Dictionary TypesSymbolTable = new Dictionary(); + private const string _defaultNamespace = "Microsoft.ML."; + private readonly GeneratedClasses _generatedClasses; public CSharpApiGenerator(IHostEnvironment env, Arguments args, string regenerate) { @@ -423,6 +53,7 @@ public CSharpApiGenerator(IHostEnvironment env, Arguments args, string regenerat _csFilename = "CSharpApi.cs"; _regenerate = regenerate; _excludedSet = new HashSet(args.Exclude); + _generatedClasses = new GeneratedClasses(); } public void Generate(IEnumerable infos) @@ -434,17 +65,17 @@ public void Generate(IEnumerable infos) var writer = IndentingTextWriter.Wrap(sw, " "); // Generate header - GenerateHeader(writer); + CSharpGeneratorUtils.GenerateHeader(writer); foreach (var entryPointInfo in catalog.AllEntryPoints().Where(x => !_excludedSet.Contains(x.Name)).OrderBy(x => x.Name)) { // Generate method - GenerateMethod(writer, entryPointInfo, catalog); + CSharpGeneratorUtils.GenerateMethod(writer, entryPointInfo.Name, _defaultNamespace); } // Generate footer - GenerateFooter(writer); - GenerateFooter(writer); + CSharpGeneratorUtils.GenerateFooter(writer); + CSharpGeneratorUtils.GenerateFooter(writer); foreach (var entryPointInfo in catalog.AllEntryPoints().Where(x => !_excludedSet.Contains(x.Name)).OrderBy(x => x.Name)) { @@ -456,68 +87,28 @@ public void Generate(IEnumerable infos) writer.WriteLine("{"); writer.Indent(); - foreach (var kind in catalog.GetAllComponentKinds().OrderBy(x => x)) + foreach (var kind in catalog.GetAllComponentKinds()) { // Generate kind base class GenerateComponentKind(writer, kind); - foreach (var component in catalog.GetAllComponents(kind).OrderBy(x => x.Name)) + foreach (var component in catalog.GetAllComponents(kind)) { // Generate component GenerateComponent(writer, component, catalog); } } - GenerateFooter(writer); - GenerateFooter(writer); + CSharpGeneratorUtils.GenerateFooter(writer); + CSharpGeneratorUtils.GenerateFooter(writer); writer.WriteLine("#pragma warning restore"); } } - private void GenerateHeader(IndentingTextWriter writer) + private void GenerateInputOutput(IndentingTextWriter writer, ModuleCatalog.EntryPointInfo entryPointInfo, ModuleCatalog catalog) { - writer.WriteLine("//------------------------------------------------------------------------------"); - writer.WriteLine("// "); - writer.WriteLine("// This code was generated by a tool."); - writer.WriteLine("//"); - writer.WriteLine("// Changes to this file may cause incorrect behavior and will be lost if"); - writer.WriteLine("// the code is regenerated."); - writer.WriteLine("// "); - writer.WriteLine("//------------------------------------------------------------------------------"); - //writer.WriteLine($"// This file is auto generated. To regenerate it, run: {_regenerate}"); - writer.WriteLine("#pragma warning disable"); - writer.WriteLine("using System.Collections.Generic;"); - writer.WriteLine("using Microsoft.ML.Runtime;"); - writer.WriteLine("using Microsoft.ML.Runtime.Data;"); - writer.WriteLine("using Microsoft.ML.Runtime.EntryPoints;"); - writer.WriteLine("using Newtonsoft.Json;"); - writer.WriteLine("using System;"); - writer.WriteLine("using System.Linq;"); - writer.WriteLine("using Microsoft.ML.Runtime.CommandLine;"); - writer.WriteLine(); - writer.WriteLine("namespace Microsoft.ML"); - writer.WriteLine("{"); - writer.Indent(); - writer.WriteLine("namespace Runtime"); - writer.WriteLine("{"); - writer.Indent(); - writer.WriteLine("public sealed partial class Experiment"); - writer.WriteLine("{"); - writer.Indent(); - } - - private void GenerateFooter(IndentingTextWriter writer) - { - writer.Outdent(); - writer.WriteLine("}"); - } - - private void GenerateInputOutput(IndentingTextWriter writer, - ModuleCatalog.EntryPointInfo entryPointInfo, - ModuleCatalog catalog) - { - var classAndMethod = GeneratorUtils.GetClassAndMethodNames(entryPointInfo); - writer.WriteLine($"namespace {classAndMethod.Item1}"); + var classAndMethod = CSharpGeneratorUtils.GetEntryPointMetadata(entryPointInfo); + writer.WriteLine($"namespace {classAndMethod.Namespace}"); writer.WriteLine("{"); writer.Indent(); GenerateInput(writer, entryPointInfo, catalog); @@ -526,78 +117,6 @@ private void GenerateInputOutput(IndentingTextWriter writer, writer.WriteLine(); } - /// - /// This methods creates a unique name for a class/struct/enum, given a type and a namespace. - /// It generates the name based on the property of the type - /// (see description here https://msdn.microsoft.com/en-us/library/system.type.fullname(v=vs.110).aspx). - /// Example: Assume we have the following structure in namespace X.Y: - /// class A { - /// class B { - /// enum C { - /// Value1, - /// Value2 - /// } - /// } - /// } - /// The full name of C would be X.Y.A+B+C. This method will generate the name "ABC" from it. In case - /// A is generic with one generic type, then the full name of typeof(A<float>.B.C) would be X.Y.A`1+B+C[[System.Single]]. - /// In this case, this method will generate the name "ASingleBC". - /// - /// A dictionary containing the names of the classes already generated. - /// This parameter is only used to ensure that the newly generated name is unique. - /// The type for which to generate the new name. - /// The namespace prefix to the new name. - /// A unique name derived from the given type and namespace. - private static string GetSymbolFromType(Dictionary typesSymbolTable, Type type, string currentNamespace) - { - var fullTypeName = type.FullName; - string name = currentNamespace != "" ? currentNamespace + '.' : ""; - - int bracketIndex = fullTypeName.IndexOf('['); - Type[] genericTypes = null; - if (type.IsGenericType) - genericTypes = type.GetGenericArguments(); - if (bracketIndex > 0) - { - Contracts.AssertValue(genericTypes); - fullTypeName = fullTypeName.Substring(0, bracketIndex); - } - - // When the type is nested, the names of the outer types are concatenated with a '+'. - var nestedNames = fullTypeName.Split('+'); - var baseName = nestedNames[0]; - - // We currently only handle generic types in the outer most class, support for generic inner classes - // can be added if needed. - int backTickIndex = baseName.LastIndexOf('`'); - int dotIndex = baseName.LastIndexOf('.'); - Contracts.Assert(dotIndex >= 0); - if (backTickIndex < 0) - name += baseName.Substring(dotIndex + 1); - else - { - name += baseName.Substring(dotIndex + 1, backTickIndex - dotIndex - 1); - Contracts.AssertValue(genericTypes); - if (genericTypes != null) - { - foreach (var genType in genericTypes) - { - var splitNames = genType.FullName.Split('+'); - if (splitNames[0].LastIndexOf('.') >= 0) - splitNames[0] = splitNames[0].Substring(splitNames[0].LastIndexOf('.') + 1); - name += string.Join("", splitNames); - } - } - } - - for (int i = 1; i < nestedNames.Length; i++) - name += nestedNames[i]; - - Contracts.Assert(typesSymbolTable.Select(kvp => kvp.Value).All(str => string.Compare(str, name) != 0)); - - return "Microsoft.ML." + name; - } - private void GenerateEnums(IndentingTextWriter writer, Type inputType, string currentNamespace) { foreach (var fieldInfo in inputType.GetFields()) @@ -605,14 +124,8 @@ private void GenerateEnums(IndentingTextWriter writer, Type inputType, string cu var inputAttr = fieldInfo.GetCustomAttributes(typeof(ArgumentAttribute), false).FirstOrDefault() as ArgumentAttribute; if (inputAttr == null || inputAttr.Visibility == ArgumentAttribute.VisibilityType.CmdLineOnly) continue; - - var type = fieldInfo.FieldType; - if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)) - type = type.GetGenericArguments()[0]; - if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Optional<>)) - type = type.GetGenericArguments()[0]; - - if (TypesSymbolTable.ContainsKey(type.FullName)) + var type = CSharpGeneratorUtils.ExtractOptionalOrNullableType(fieldInfo.FieldType); + if (_generatedClasses.IsGenerated(type.FullName)) continue; if (!type.IsEnum) @@ -625,15 +138,16 @@ private void GenerateEnums(IndentingTextWriter writer, Type inputType, string cu var enumType = Enum.GetUnderlyingType(type); - TypesSymbolTable[type.FullName] = GetSymbolFromType(TypesSymbolTable, type, currentNamespace); + var apiName = _generatedClasses.GetApiName(type, currentNamespace); if (enumType == typeof(int)) - writer.WriteLine($"public enum {TypesSymbolTable[type.FullName].Substring(TypesSymbolTable[type.FullName].LastIndexOf('.') + 1)}"); + writer.WriteLine($"public enum {apiName}"); else { Contracts.Assert(enumType == typeof(byte)); - writer.WriteLine($"public enum {TypesSymbolTable[type.FullName].Substring(TypesSymbolTable[type.FullName].LastIndexOf('.') + 1)} : byte"); + writer.WriteLine($"public enum {apiName} : byte"); } + _generatedClasses.MarkAsGenerated(type.FullName); writer.Write("{"); writer.Indent(); var names = Enum.GetNames(type); @@ -660,25 +174,7 @@ private void GenerateEnums(IndentingTextWriter writer, Type inputType, string cu } } - string GetFriendlyTypeName(string currentNameSpace, string typeName) - { - Contracts.Assert(typeName.Length >= currentNameSpace.Length); - - int index = 0; - for (index = 0; index < currentNameSpace.Length && currentNameSpace[index] == typeName[index]; index++) ; - - if (index == 0) - return typeName; - if (typeName[index - 1] == '.') - return typeName.Substring(index); - - return typeName; - } - - private void GenerateStructs(IndentingTextWriter writer, - Type inputType, - ModuleCatalog catalog, - string currentNamespace) + private void GenerateClasses(IndentingTextWriter writer, Type inputType, ModuleCatalog catalog, string currentNamespace) { foreach (var fieldInfo in inputType.GetFields()) { @@ -687,10 +183,7 @@ private void GenerateStructs(IndentingTextWriter writer, continue; var type = fieldInfo.FieldType; - if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)) - type = type.GetGenericArguments()[0]; - if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Optional<>)) - type = type.GetGenericArguments()[0]; + type = CSharpGeneratorUtils.ExtractOptionalOrNullableType(type); if (type.IsArray) type = type.GetElementType(); if (type == typeof(JArray) || type == typeof(JObject)) @@ -707,265 +200,210 @@ private void GenerateStructs(IndentingTextWriter writer, if (typeEnum != TlcModule.DataKind.Unknown) continue; - if (TypesSymbolTable.ContainsKey(type.FullName)) + if (_generatedClasses.IsGenerated(type.FullName)) continue; + GenerateEnums(writer, type, currentNamespace); + GenerateClasses(writer, type, catalog, currentNamespace); - TypesSymbolTable[type.FullName] = GetSymbolFromType(TypesSymbolTable, type, currentNamespace); + var apiName = _generatedClasses.GetApiName(type, currentNamespace); string classBase = ""; if (type.IsSubclassOf(typeof(OneToOneColumn))) - classBase = $" : OneToOneColumn<{TypesSymbolTable[type.FullName].Substring(TypesSymbolTable[type.FullName].LastIndexOf('.') + 1)}>, IOneToOneColumn"; + classBase = $" : OneToOneColumn<{apiName}>, IOneToOneColumn"; else if (type.IsSubclassOf(typeof(ManyToOneColumn))) - classBase = $" : ManyToOneColumn<{TypesSymbolTable[type.FullName].Substring(TypesSymbolTable[type.FullName].LastIndexOf('.') + 1)}>, IManyToOneColumn"; - writer.WriteLine($"public sealed partial class {TypesSymbolTable[type.FullName].Substring(TypesSymbolTable[type.FullName].LastIndexOf('.') + 1)}{classBase}"); + classBase = $" : ManyToOneColumn<{apiName}>, IManyToOneColumn"; + writer.WriteLine($"public sealed partial class {apiName}{classBase}"); writer.WriteLine("{"); writer.Indent(); - GenerateInputFields(writer, type, catalog, TypesSymbolTable); + _generatedClasses.MarkAsGenerated(type.FullName); + GenerateInputFields(writer, type, catalog, currentNamespace); writer.Outdent(); writer.WriteLine("}"); writer.WriteLine(); - GenerateStructs(writer, type, catalog, currentNamespace); } } - private void GenerateLoaderAddInputMethod(IndentingTextWriter writer, string className) + private void GenerateColumnAddMethods(IndentingTextWriter writer, Type inputType, ModuleCatalog catalog, + string className, out Type columnType) { - //Constructor. - writer.WriteLine("[JsonIgnore]"); - writer.WriteLine("private string _inputFilePath = null;"); - writer.WriteLine($"public {className}(string filePath)"); + columnType = null; + foreach (var fieldInfo in inputType.GetFields()) + { + var inputAttr = fieldInfo.GetCustomAttributes(typeof(ArgumentAttribute), false).FirstOrDefault() as ArgumentAttribute; + if (inputAttr == null || inputAttr.Visibility == ArgumentAttribute.VisibilityType.CmdLineOnly) + continue; + var type = CSharpGeneratorUtils.ExtractOptionalOrNullableType(fieldInfo.FieldType); + var isArray = type.IsArray; + if (isArray) + type = type.GetElementType(); + if (type == typeof(JArray) || type == typeof(JObject)) + continue; + if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Var<>)) + continue; + var typeEnum = TlcModule.GetDataType(type); + if (typeEnum != TlcModule.DataKind.Unknown) + continue; + + if (type.IsSubclassOf(typeof(OneToOneColumn))) + columnType = GenerateOneToOneColumn(writer, className, columnType, fieldInfo, inputAttr, type, isArray); + else if (type.IsSubclassOf(typeof(ManyToOneColumn))) + columnType = GenerateManyToOneColumn(writer, className, columnType, fieldInfo, inputAttr, type, isArray); + } + } + + private Type GenerateManyToOneColumn(IndentingTextWriter writer, string className, Type columnType, + System.Reflection.FieldInfo fieldInfo, ArgumentAttribute inputAttr, Type type, bool isArray) + { + var fieldName = CSharpGeneratorUtils.Capitalize(inputAttr.Name ?? fieldInfo.Name); + var apiName = _generatedClasses.GetApiName(type, ""); + writer.WriteLine($"public {className}()"); + writer.WriteLine("{"); + writer.WriteLine("}"); + writer.WriteLine(""); + writer.WriteLine($"public {className}(string output{fieldName}, params string[] input{fieldName}s)"); writer.WriteLine("{"); writer.Indent(); - writer.WriteLine("_inputFilePath = filePath;"); + writer.WriteLine($"Add{fieldName}(output{fieldName}, input{fieldName}s);"); writer.Outdent(); writer.WriteLine("}"); writer.WriteLine(""); - - //SetInput. - writer.WriteLine($"public void SetInput(IHostEnvironment env, Experiment experiment)"); + writer.WriteLine($"public void Add{fieldName}(string name, params string[] source)"); writer.WriteLine("{"); writer.Indent(); - writer.WriteLine("IFileHandle inputFile = new SimpleFileHandle(env, _inputFilePath, false, false);"); - writer.WriteLine("experiment.SetInput(InputFile, inputFile);"); + if (isArray) + { + writer.WriteLine($"var list = {fieldName} == null ? new List<{apiName}>() : new List<{apiName}>({fieldName});"); + writer.WriteLine($"list.Add(ManyToOneColumn<{apiName}>.Create(name, source));"); + writer.WriteLine($"{fieldName} = list.ToArray();"); + } + else + writer.WriteLine($"{fieldName} = ManyToOneColumn<{apiName}>.Create(name, source);"); writer.Outdent(); writer.WriteLine("}"); - writer.WriteLine(""); + writer.WriteLine(); - //GetInputData - writer.WriteLine("public Var GetInputData() => null;"); - writer.WriteLine(""); + Contracts.Assert(columnType == null); + + columnType = type; + return columnType; + } - //Apply. - writer.WriteLine($"public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)"); + private Type GenerateOneToOneColumn(IndentingTextWriter writer, string className, Type columnType, + System.Reflection.FieldInfo fieldInfo, ArgumentAttribute inputAttr, Type type, bool isArray) + { + var fieldName = CSharpGeneratorUtils.Capitalize(inputAttr.Name ?? fieldInfo.Name); + var generatedType = _generatedClasses.GetApiName(type, ""); + writer.WriteLine($"public {className}()"); writer.WriteLine("{"); - writer.Indent(); - writer.WriteLine("Contracts.Assert(previousStep == null);"); + writer.WriteLine("}"); writer.WriteLine(""); - writer.WriteLine($"return new {className}PipelineStep(experiment.Add(this));"); + writer.WriteLine($"public {className}(params string[] input{fieldName}s)"); + writer.WriteLine("{"); + writer.Indent(); + writer.WriteLine($"if (input{fieldName}s != null)"); + writer.WriteLine("{"); + writer.Indent(); + writer.WriteLine($"foreach (string input in input{fieldName}s)"); + writer.WriteLine("{"); + writer.Indent(); + writer.WriteLine($"Add{fieldName}(input);"); + writer.Outdent(); + writer.WriteLine("}"); + writer.Outdent(); + writer.WriteLine("}"); writer.Outdent(); writer.WriteLine("}"); writer.WriteLine(""); - - //Pipelinestep class. - writer.WriteLine($"private class {className}PipelineStep : ILearningPipelineDataStep"); + writer.WriteLine($"public {className}(params ValueTuple[] inputOutput{fieldName}s)"); + writer.WriteLine("{"); + writer.Indent(); + writer.WriteLine($"if (inputOutput{fieldName}s != null)"); writer.WriteLine("{"); writer.Indent(); - writer.WriteLine($"public {className}PipelineStep (Output output)"); + writer.WriteLine($"foreach (ValueTuple inputOutput in inputOutput{fieldName}s)"); writer.WriteLine("{"); writer.Indent(); - writer.WriteLine("Data = output.Data;"); - writer.WriteLine("Model = null;"); + writer.WriteLine($"Add{fieldName}(inputOutput.Item2, inputOutput.Item1);"); writer.Outdent(); writer.WriteLine("}"); - writer.WriteLine(); - writer.WriteLine("public Var Data { get; }"); - writer.WriteLine("public Var Model { get; }"); writer.Outdent(); writer.WriteLine("}"); - } - - private void GenerateColumnAddMethods(IndentingTextWriter writer, - Type inputType, - ModuleCatalog catalog, - string className, - out Type columnType) - { - columnType = null; - foreach (var fieldInfo in inputType.GetFields()) + writer.Outdent(); + writer.WriteLine("}"); + writer.WriteLine(""); + writer.WriteLine($"public void Add{fieldName}(string source)"); + writer.WriteLine("{"); + writer.Indent(); + if (isArray) { - var inputAttr = fieldInfo.GetCustomAttributes(typeof(ArgumentAttribute), false).FirstOrDefault() as ArgumentAttribute; - if (inputAttr == null || inputAttr.Visibility == ArgumentAttribute.VisibilityType.CmdLineOnly) - continue; - - var type = fieldInfo.FieldType; - if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)) - type = type.GetGenericArguments()[0]; - if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Optional<>)) - type = type.GetGenericArguments()[0]; - var isArray = type.IsArray; - if (isArray) - type = type.GetElementType(); - if (type == typeof(JArray) || type == typeof(JObject)) - continue; - if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Var<>)) - continue; - var typeEnum = TlcModule.GetDataType(type); - if (typeEnum != TlcModule.DataKind.Unknown) - continue; - - if (type.IsSubclassOf(typeof(OneToOneColumn))) - { - var fieldName = GeneratorUtils.Capitalize(inputAttr.Name ?? fieldInfo.Name); - writer.WriteLine($"public {className}()"); - writer.WriteLine("{"); - writer.WriteLine("}"); - writer.WriteLine(""); - writer.WriteLine($"public {className}(params string[] input{fieldName}s)"); - writer.WriteLine("{"); - writer.Indent(); - writer.WriteLine($"if (input{fieldName}s != null)"); - writer.WriteLine("{"); - writer.Indent(); - writer.WriteLine($"foreach (string input in input{fieldName}s)"); - writer.WriteLine("{"); - writer.Indent(); - writer.WriteLine($"Add{fieldName}(input);"); - writer.Outdent(); - writer.WriteLine("}"); - writer.Outdent(); - writer.WriteLine("}"); - writer.Outdent(); - writer.WriteLine("}"); - writer.WriteLine(""); - writer.WriteLine($"public {className}(params ValueTuple[] inputOutput{fieldName}s)"); - writer.WriteLine("{"); - writer.Indent(); - writer.WriteLine($"if (inputOutput{fieldName}s != null)"); - writer.WriteLine("{"); - writer.Indent(); - writer.WriteLine($"foreach (ValueTuple inputOutput in inputOutput{fieldName}s)"); - writer.WriteLine("{"); - writer.Indent(); - writer.WriteLine($"Add{fieldName}(inputOutput.Item2, inputOutput.Item1);"); - writer.Outdent(); - writer.WriteLine("}"); - writer.Outdent(); - writer.WriteLine("}"); - writer.Outdent(); - writer.WriteLine("}"); - writer.WriteLine(""); - writer.WriteLine($"public void Add{fieldName}(string source)"); - writer.WriteLine("{"); - writer.Indent(); - if (isArray) - { - writer.WriteLine($"var list = {fieldName} == null ? new List<{TypesSymbolTable[type.FullName]}>() : new List<{TypesSymbolTable[type.FullName]}>({fieldName});"); - writer.WriteLine($"list.Add(OneToOneColumn<{TypesSymbolTable[type.FullName]}>.Create(source));"); - writer.WriteLine($"{fieldName} = list.ToArray();"); - } - else - writer.WriteLine($"{fieldName} = OneToOneColumn<{TypesSymbolTable[type.FullName]}>.Create(source);"); - writer.Outdent(); - writer.WriteLine("}"); - writer.WriteLine(); - writer.WriteLine($"public void Add{fieldName}(string name, string source)"); - writer.WriteLine("{"); - writer.Indent(); - if (isArray) - { - writer.WriteLine($"var list = {fieldName} == null ? new List<{TypesSymbolTable[type.FullName]}>() : new List<{TypesSymbolTable[type.FullName]}>({fieldName});"); - writer.WriteLine($"list.Add(OneToOneColumn<{TypesSymbolTable[type.FullName]}>.Create(name, source));"); - writer.WriteLine($"{fieldName} = list.ToArray();"); - } - else - writer.WriteLine($"{fieldName} = OneToOneColumn<{TypesSymbolTable[type.FullName]}>.Create(name, source);"); - writer.Outdent(); - writer.WriteLine("}"); - writer.WriteLine(); - - Contracts.Assert(columnType == null); - - columnType = type; - } - else if (type.IsSubclassOf(typeof(ManyToOneColumn))) - { - var fieldName = GeneratorUtils.Capitalize(inputAttr.Name ?? fieldInfo.Name); - writer.WriteLine($"public {className}()"); - writer.WriteLine("{"); - writer.WriteLine("}"); - writer.WriteLine(""); - writer.WriteLine($"public {className}(string output{fieldName}, params string[] input{fieldName}s)"); - writer.WriteLine("{"); - writer.Indent(); - writer.WriteLine($"Add{fieldName}(output{fieldName}, input{fieldName}s);"); - writer.Outdent(); - writer.WriteLine("}"); - writer.WriteLine(""); - writer.WriteLine($"public void Add{fieldName}(string name, params string[] source)"); - writer.WriteLine("{"); - writer.Indent(); - if (isArray) - { - writer.WriteLine($"var list = {fieldName} == null ? new List<{TypesSymbolTable[type.FullName]}>() : new List<{TypesSymbolTable[type.FullName]}>({fieldName});"); - writer.WriteLine($"list.Add(ManyToOneColumn<{TypesSymbolTable[type.FullName]}>.Create(name, source));"); - writer.WriteLine($"{fieldName} = list.ToArray();"); - } - else - writer.WriteLine($"{fieldName} = ManyToOneColumn<{TypesSymbolTable[type.FullName]}>.Create(name, source);"); - writer.Outdent(); - writer.WriteLine("}"); - writer.WriteLine(); + writer.WriteLine($"var list = {fieldName} == null ? new List<{generatedType}>() : new List<{generatedType}>({fieldName});"); + writer.WriteLine($"list.Add(OneToOneColumn<{generatedType}>.Create(source));"); + writer.WriteLine($"{fieldName} = list.ToArray();"); + } + else + writer.WriteLine($"{fieldName} = OneToOneColumn<{generatedType}>.Create(source);"); + writer.Outdent(); + writer.WriteLine("}"); + writer.WriteLine(); + writer.WriteLine($"public void Add{fieldName}(string name, string source)"); + writer.WriteLine("{"); + writer.Indent(); + if (isArray) + { + writer.WriteLine($"var list = {fieldName} == null ? new List<{generatedType}>() : new List<{generatedType}>({fieldName});"); + writer.WriteLine($"list.Add(OneToOneColumn<{generatedType}>.Create(name, source));"); + writer.WriteLine($"{fieldName} = list.ToArray();"); + } + else + writer.WriteLine($"{fieldName} = OneToOneColumn<{generatedType}>.Create(name, source);"); + writer.Outdent(); + writer.WriteLine("}"); + writer.WriteLine(); - Contracts.Assert(columnType == null); + Contracts.Assert(columnType == null); - columnType = type; - } - } + columnType = type; + return columnType; } - private void GenerateInput(IndentingTextWriter writer, - ModuleCatalog.EntryPointInfo entryPointInfo, - ModuleCatalog catalog) + private void GenerateInput(IndentingTextWriter writer, ModuleCatalog.EntryPointInfo entryPointInfo, ModuleCatalog catalog) { - var classAndMethod = GeneratorUtils.GetClassAndMethodNames(entryPointInfo); + var entryPointMetadata = CSharpGeneratorUtils.GetEntryPointMetadata(entryPointInfo); string classBase = ""; if (entryPointInfo.InputKinds != null) { - classBase += $" : {string.Join(", ", entryPointInfo.InputKinds.Select(GeneratorUtils.GetCSharpTypeName))}"; + classBase += $" : {string.Join(", ", entryPointInfo.InputKinds.Select(CSharpGeneratorUtils.GetCSharpTypeName))}"; if (entryPointInfo.InputKinds.Any(t => typeof(ITrainerInput).IsAssignableFrom(t) || typeof(ITransformInput).IsAssignableFrom(t))) classBase += ", Microsoft.ML.ILearningPipelineItem"; } - GenerateEnums(writer, entryPointInfo.InputType, classAndMethod.Item1); + GenerateEnums(writer, entryPointInfo.InputType, _defaultNamespace + entryPointMetadata.Namespace); writer.WriteLine(); - GenerateStructs(writer, entryPointInfo.InputType, catalog, classAndMethod.Item1); - writer.WriteLine("/// "); - foreach (var line in entryPointInfo.Description.Split(new[] { Environment.NewLine }, StringSplitOptions.RemoveEmptyEntries)) - writer.WriteLine($"/// {line}"); - writer.WriteLine("/// "); + GenerateClasses(writer, entryPointInfo.InputType, catalog, _defaultNamespace + entryPointMetadata.Namespace); + CSharpGeneratorUtils.GenerateSummary(writer, entryPointInfo.Description); if (entryPointInfo.ObsoleteAttribute != null) writer.WriteLine($"[Obsolete(\"{entryPointInfo.ObsoleteAttribute.Message}\")]"); - writer.WriteLine($"public sealed partial class {classAndMethod.Item2}{classBase}"); + writer.WriteLine($"public sealed partial class {entryPointMetadata.ClassName}{classBase}"); writer.WriteLine("{"); writer.Indent(); writer.WriteLine(); if (entryPointInfo.InputKinds != null && entryPointInfo.InputKinds.Any(t => typeof(ILearningPipelineLoader).IsAssignableFrom(t))) - GenerateLoaderAddInputMethod(writer, classAndMethod.Item2); + CSharpGeneratorUtils.GenerateLoaderAddInputMethod(writer, entryPointMetadata.ClassName); - GenerateColumnAddMethods(writer, entryPointInfo.InputType, catalog, classAndMethod.Item2, out Type transformType); + GenerateColumnAddMethods(writer, entryPointInfo.InputType, catalog, entryPointMetadata.ClassName, out Type transformType); writer.WriteLine(); - GenerateInputFields(writer, entryPointInfo.InputType, catalog, TypesSymbolTable); + GenerateInputFields(writer, entryPointInfo.InputType, catalog, _defaultNamespace + entryPointMetadata.Namespace); writer.WriteLine(); GenerateOutput(writer, entryPointInfo, out HashSet outputVariableNames); - GenerateApplyFunction(writer, entryPointInfo, transformType, outputVariableNames, entryPointInfo.InputKinds); + GenerateApplyFunction(writer, entryPointMetadata.ClassName, transformType, outputVariableNames, entryPointInfo.InputKinds); writer.Outdent(); writer.WriteLine("}"); } - private static void GenerateApplyFunction(IndentingTextWriter writer, ModuleCatalog.EntryPointInfo entryPointInfo, - Type type, HashSet outputVariableNames, Type[] inputKinds) + private static void GenerateApplyFunction(IndentingTextWriter writer, string className, Type type, + HashSet outputVariableNames, Type[] inputKinds) { if (inputKinds == null) return; @@ -987,7 +425,6 @@ private static void GenerateApplyFunction(IndentingTextWriter writer, ModuleCata writer.WriteLine("public Var GetInputData() => TrainingData;"); writer.WriteLine(""); - string className = GeneratorUtils.GetClassAndMethodNames(entryPointInfo).Item2; writer.WriteLine("public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)"); writer.WriteLine("{"); @@ -1057,8 +494,7 @@ private static void GenerateApplyFunction(IndentingTextWriter writer, ModuleCata writer.WriteLine("}"); } - private static void GenerateInputFields(IndentingTextWriter writer, - Type inputType, ModuleCatalog catalog, Dictionary typesSymbolTable, string rootNameSpace = "") + private void GenerateInputFields(IndentingTextWriter writer, Type inputType, ModuleCatalog catalog, string rootNameSpace) { var defaults = Activator.CreateInstance(inputType); foreach (var fieldInfo in inputType.GetFields()) @@ -1070,20 +506,18 @@ private static void GenerateInputFields(IndentingTextWriter writer, if (fieldInfo.FieldType == typeof(JObject)) continue; - writer.WriteLine("/// "); - writer.WriteLine($"/// {inputAttr.HelpText}"); - writer.WriteLine("/// "); + CSharpGeneratorUtils.GenerateSummary(writer, inputAttr.HelpText); if (fieldInfo.FieldType == typeof(JArray)) { - writer.WriteLine($"public Experiment {GeneratorUtils.Capitalize(inputAttr.Name ?? fieldInfo.Name)} {{ get; set; }}"); + writer.WriteLine($"public Experiment {CSharpGeneratorUtils.Capitalize(inputAttr.Name ?? fieldInfo.Name)} {{ get; set; }}"); writer.WriteLine(); continue; } - var inputTypeString = GeneratorUtils.GetInputType(catalog, fieldInfo.FieldType, typesSymbolTable, rootNameSpace); - if (GeneratorUtils.IsComponent(fieldInfo.FieldType)) + var inputTypeString = CSharpGeneratorUtils.GetInputType(catalog, fieldInfo.FieldType, _generatedClasses, rootNameSpace); + if (CSharpGeneratorUtils.IsComponent(fieldInfo.FieldType)) writer.WriteLine("[JsonConverter(typeof(ComponentSerializer))]"); - if (GeneratorUtils.Capitalize(inputAttr.Name ?? fieldInfo.Name) != (inputAttr.Name ?? fieldInfo.Name)) + if (CSharpGeneratorUtils.Capitalize(inputAttr.Name ?? fieldInfo.Name) != (inputAttr.Name ?? fieldInfo.Name)) writer.WriteLine($"[JsonProperty(\"{inputAttr.Name ?? fieldInfo.Name}\")]"); // For range attributes on properties @@ -1105,8 +539,8 @@ private static void GenerateInputFields(IndentingTextWriter writer, writer.WriteLine(sweepableParamAttr.ToString()); } - writer.Write($"public {inputTypeString} {GeneratorUtils.Capitalize(inputAttr.Name ?? fieldInfo.Name)} {{ get; set; }}"); - var defaultValue = GeneratorUtils.GetValue(catalog, fieldInfo.FieldType, fieldInfo.GetValue(defaults), typesSymbolTable, rootNameSpace); + writer.Write($"public {inputTypeString} {CSharpGeneratorUtils.Capitalize(inputAttr.Name ?? fieldInfo.Name)} {{ get; set; }}"); + var defaultValue = CSharpGeneratorUtils.GetValue(catalog, fieldInfo.FieldType, fieldInfo.GetValue(defaults), _generatedClasses, rootNameSpace); if (defaultValue != null) writer.Write($" = {defaultValue};"); writer.WriteLine(); @@ -1114,14 +548,12 @@ private static void GenerateInputFields(IndentingTextWriter writer, } } - private void GenerateOutput(IndentingTextWriter writer, - ModuleCatalog.EntryPointInfo entryPointInfo, - out HashSet outputVariableNames) + private void GenerateOutput(IndentingTextWriter writer, ModuleCatalog.EntryPointInfo entryPointInfo, out HashSet outputVariableNames) { outputVariableNames = new HashSet(); string classBase = ""; if (entryPointInfo.OutputKinds != null) - classBase = $" : {string.Join(", ", entryPointInfo.OutputKinds.Select(GeneratorUtils.GetCSharpTypeName))}"; + classBase = $" : {string.Join(", ", entryPointInfo.OutputKinds.Select(CSharpGeneratorUtils.GetCSharpTypeName))}"; writer.WriteLine($"public sealed class Output{classBase}"); writer.WriteLine("{"); writer.Indent(); @@ -1136,12 +568,10 @@ private void GenerateOutput(IndentingTextWriter writer, if (outputAttr == null) continue; - writer.WriteLine("/// "); - writer.WriteLine($"/// {outputAttr.Desc}"); - writer.WriteLine("/// "); - var outputTypeString = GeneratorUtils.GetOutputType(fieldInfo.FieldType); - outputVariableNames.Add(GeneratorUtils.Capitalize(outputAttr.Name ?? fieldInfo.Name)); - writer.WriteLine($"public {outputTypeString} {GeneratorUtils.Capitalize(outputAttr.Name ?? fieldInfo.Name)} {{ get; set; }} = new {outputTypeString}();"); + CSharpGeneratorUtils.GenerateSummary(writer, outputAttr.Desc); + var outputTypeString = CSharpGeneratorUtils.GetOutputType(fieldInfo.FieldType); + outputVariableNames.Add(CSharpGeneratorUtils.Capitalize(outputAttr.Name ?? fieldInfo.Name)); + writer.WriteLine($"public {outputTypeString} {CSharpGeneratorUtils.Capitalize(outputAttr.Name ?? fieldInfo.Name)} {{ get; set; }} = new {outputTypeString}();"); writer.WriteLine(); } @@ -1149,30 +579,6 @@ private void GenerateOutput(IndentingTextWriter writer, writer.WriteLine("}"); } - private void GenerateMethod(IndentingTextWriter writer, - ModuleCatalog.EntryPointInfo entryPointInfo, - ModuleCatalog catalog) - { - var inputOuputClassName = GeneratorUtils.GetFullMethodName(entryPointInfo); - inputOuputClassName = "Microsoft.ML." + inputOuputClassName; - writer.WriteLine($"public {inputOuputClassName}.Output Add({inputOuputClassName} input)"); - writer.WriteLine("{"); - writer.Indent(); - writer.WriteLine($"var output = new {inputOuputClassName}.Output();"); - writer.WriteLine("Add(input, output);"); - writer.WriteLine("return output;"); - writer.Outdent(); - writer.WriteLine("}"); - writer.WriteLine(); - writer.WriteLine($"public void Add({inputOuputClassName} input, {inputOuputClassName}.Output output)"); - writer.WriteLine("{"); - writer.Indent(); - writer.WriteLine($"_jsonNodes.Add(Serialize(\"{entryPointInfo.Name}\", input, output));"); - writer.Outdent(); - writer.WriteLine("}"); - writer.WriteLine(); - } - private void GenerateComponentKind(IndentingTextWriter writer, string kind) { writer.WriteLine($"public abstract class {kind} : ComponentKind {{}}"); @@ -1183,15 +589,13 @@ private void GenerateComponent(IndentingTextWriter writer, ModuleCatalog.Compone { GenerateEnums(writer, component.ArgumentType, "Runtime"); writer.WriteLine(); - GenerateStructs(writer, component.ArgumentType, catalog, "Runtime"); + GenerateClasses(writer, component.ArgumentType, catalog, "Runtime"); writer.WriteLine(); - writer.WriteLine("/// "); - writer.WriteLine($"/// {component.Description}"); - writer.WriteLine("/// "); - writer.WriteLine($"public sealed class {GeneratorUtils.GetComponentName(component)} : {component.Kind}"); + CSharpGeneratorUtils.GenerateSummary(writer, component.Description); + writer.WriteLine($"public sealed class {CSharpGeneratorUtils.GetComponentName(component)} : {component.Kind}"); writer.WriteLine("{"); writer.Indent(); - GenerateInputFields(writer, component.ArgumentType, catalog, TypesSymbolTable, "Microsoft.ML."); + GenerateInputFields(writer, component.ArgumentType, catalog, "Runtime"); writer.WriteLine($"internal override string ComponentName => \"{component.Name}\";"); writer.Outdent(); writer.WriteLine("}"); diff --git a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpGeneratorUtils.cs b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpGeneratorUtils.cs new file mode 100644 index 0000000000..11e2116b18 --- /dev/null +++ b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpGeneratorUtils.cs @@ -0,0 +1,460 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.CodeDom; +using System.Collections.Generic; +using System.Linq; +using Microsoft.CSharp; +using Microsoft.ML.Runtime.CommandLine; +using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.Internal.Utilities; +using Newtonsoft.Json.Linq; + +namespace Microsoft.ML.Runtime.Internal.Tools +{ + internal static class CSharpGeneratorUtils + { + public sealed class EntryPointGenerationMetadata + { + public string Namespace { get; } + public string ClassName { get; } + public EntryPointGenerationMetadata(string classNamespace, string className) + { + Namespace = classNamespace; + ClassName = className; + } + } + + public static EntryPointGenerationMetadata GetEntryPointMetadata(ModuleCatalog.EntryPointInfo entryPointInfo) + { + var split = entryPointInfo.Name.Split('.'); + Contracts.Check(split.Length == 2); + return new EntryPointGenerationMetadata(split[0], split[1]); + } + + public static Type ExtractOptionalOrNullableType(Type type) + { + if (type.IsGenericType && (type.GetGenericTypeDefinition() == typeof(Optional<>) || type.GetGenericTypeDefinition() == typeof(Nullable<>))) + type = type.GetGenericArguments()[0]; + + return type; + } + + public static Type ExtractOptionalOrNullableType(Type type, out bool isNullable, out bool isOptional) + { + isNullable = false; + isOptional = false; + if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)) + { + type = type.GetGenericArguments()[0]; + isNullable = true; + } + else if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Optional<>)) + { + type = type.GetGenericArguments()[0]; + isOptional = true; + } + return type; + } + + public static string GetCSharpTypeName(Type type) + { + if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)) + return GetCSharpTypeName(type.GetGenericArguments()[0]) + "?"; + + using (var p = new CSharpCodeProvider()) + return p.GetTypeOutput(new CodeTypeReference(type)); + } + + public static string GetOutputType(Type outputType) + { + Contracts.Check(Var.CheckType(outputType)); + + if (outputType.IsArray) + return $"ArrayVar<{GetCSharpTypeName(outputType.GetElementType())}>"; + if (outputType.IsGenericType && outputType.GetGenericTypeDefinition() == typeof(Dictionary<,>) + && outputType.GetGenericTypeArgumentsEx()[0] == typeof(string)) + { + return $"DictionaryVar<{GetCSharpTypeName(outputType.GetGenericTypeArgumentsEx()[1])}>"; + } + + return $"Var<{GetCSharpTypeName(outputType)}>"; + } + + public static string GetInputType(ModuleCatalog catalog, Type inputType, GeneratedClasses generatedClasses, string rootNameSpace) + { + if (inputType.IsGenericType && inputType.GetGenericTypeDefinition() == typeof(Var<>)) + return $"Var<{GetCSharpTypeName(inputType.GetGenericTypeArgumentsEx()[0])}>"; + + if (inputType.IsArray && Var.CheckType(inputType.GetElementType())) + return $"ArrayVar<{GetCSharpTypeName(inputType.GetElementType())}>"; + + if (inputType.IsGenericType && inputType.GetGenericTypeDefinition() == typeof(Dictionary<,>) + && inputType.GetGenericTypeArgumentsEx()[0] == typeof(string)) + { + return $"DictionaryVar<{GetCSharpTypeName(inputType.GetGenericTypeArgumentsEx()[1])}>"; + } + + if (Var.CheckType(inputType)) + return $"Var<{GetCSharpTypeName(inputType)}>"; + + var type = ExtractOptionalOrNullableType(inputType, out bool isNullable, out bool isOptional); + var typeEnum = TlcModule.GetDataType(type); + switch (typeEnum) + { + case TlcModule.DataKind.Float: + case TlcModule.DataKind.Int: + case TlcModule.DataKind.UInt: + case TlcModule.DataKind.Char: + case TlcModule.DataKind.String: + case TlcModule.DataKind.Bool: + case TlcModule.DataKind.DataView: + case TlcModule.DataKind.TransformModel: + case TlcModule.DataKind.PredictorModel: + case TlcModule.DataKind.FileHandle: + return GetCSharpTypeName(inputType); + case TlcModule.DataKind.Array: + return GetInputType(catalog, inputType.GetElementType(), generatedClasses, rootNameSpace) + "[]"; + case TlcModule.DataKind.Component: + string kind; + bool success = catalog.TryGetComponentKind(type, out kind); + Contracts.Assert(success); + return $"{kind}"; + case TlcModule.DataKind.Enum: + var enumName = generatedClasses.GetApiName(type, rootNameSpace); + if (isNullable) + return $"{enumName}?"; + if (isOptional) + return $"Optional<{enumName}>"; + return $"{enumName}"; + default: + if (isNullable) + return generatedClasses.GetApiName(type, rootNameSpace) + "?"; + if (isOptional) + return $"Optional<{generatedClasses.GetApiName(type, rootNameSpace)}>"; + return generatedClasses.GetApiName(type, rootNameSpace); + } + } + + public static bool IsComponent(Type inputType) + { + if (inputType.IsArray && Var.CheckType(inputType.GetElementType())) + return false; + + if (inputType.IsGenericType && inputType.GetGenericTypeDefinition() == typeof(Dictionary<,>) + && inputType.GetGenericTypeArgumentsEx()[0] == typeof(string)) + { + return false; + } + + if (Var.CheckType(inputType)) + return false; + + var type = ExtractOptionalOrNullableType(inputType); + var typeEnum = TlcModule.GetDataType(type); + return typeEnum == TlcModule.DataKind.Component; + } + + public static string Capitalize(string s) + { + if (string.IsNullOrEmpty(s)) + return s; + return char.ToUpperInvariant(s[0]) + s.Substring(1); + } + + private static string GetCharAsString(char value) + { + switch (value) + { + case '\t': + return "\\t"; + case '\n': + return "\\n"; + case '\r': + return "\\r"; + case '\\': + return "\\"; + case '\"': + return "\""; + case '\'': + return "\\'"; + case '\0': + return "\\0"; + case '\a': + return "\\a"; + case '\b': + return "\\b"; + case '\f': + return "\\f"; + case '\v': + return "\\v"; + default: + return value.ToString(); + } + } + + public static string GetValue(ModuleCatalog catalog, Type fieldType, object fieldValue, + GeneratedClasses generatedClasses, string rootNameSpace) + { + if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Var<>)) + return $"new Var<{GetCSharpTypeName(fieldType.GetGenericTypeArgumentsEx()[0])}>()"; + + if (fieldType.IsArray && Var.CheckType(fieldType.GetElementType())) + return $"new ArrayVar<{GetCSharpTypeName(fieldType.GetElementType())}>()"; + + if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Dictionary<,>) + && fieldType.GetGenericTypeArgumentsEx()[0] == typeof(string)) + { + return $"new DictionaryVar<{GetCSharpTypeName(fieldType.GetGenericTypeArgumentsEx()[1])}>()"; + } + + if (Var.CheckType(fieldType)) + return $"new Var<{GetCSharpTypeName(fieldType)}>()"; + + if (fieldValue == null) + return null; + + if (!fieldType.IsInterface) + { + try + { + var defaultFieldValue = Activator.CreateInstance(fieldType); + if (defaultFieldValue == fieldValue) + return null; + } + catch (MissingMethodException) + { + // No parameterless constructor, ignore. + } + } + + var typeEnum = TlcModule.GetDataType(fieldType); + fieldType = ExtractOptionalOrNullableType(fieldType, out bool isNullable, out bool isOptional); + switch (typeEnum) + { + case TlcModule.DataKind.Array: + var arr = fieldValue as Array; + if (arr != null && arr.GetLength(0) > 0) + return $"{{ {string.Join(", ", arr.Cast().Select(item => GetValue(catalog, fieldType.GetElementType(), item, generatedClasses, rootNameSpace)))} }}"; + return null; + case TlcModule.DataKind.String: + var strval = fieldValue as string; + if (strval != null) + return Quote(strval); + return null; + case TlcModule.DataKind.Float: + if (fieldValue is double d) + { + if (double.IsPositiveInfinity(d)) + return "double.PositiveInfinity"; + if (double.IsNegativeInfinity(d)) + return "double.NegativeInfinity"; + if (d != 0) + return d.ToString("R") + "d"; + } + else if (fieldValue is float f) + { + if (float.IsPositiveInfinity(f)) + return "float.PositiveInfinity"; + if (float.IsNegativeInfinity(f)) + return "float.NegativeInfinity"; + if (f != 0) + return f.ToString("R") + "f"; + } + return null; + case TlcModule.DataKind.Int: + if (fieldValue is int i) + { + if (i != 0) + return i.ToString(); + } + else if (fieldValue is long l) + { + if (l != 0) + return l.ToString(); + } + return null; + case TlcModule.DataKind.Bool: + return (bool)fieldValue ? "true" : "false"; + case TlcModule.DataKind.Enum: + return generatedClasses.GetApiName(fieldType, rootNameSpace) + "." + fieldValue; + case TlcModule.DataKind.Char: + return $"'{GetCharAsString((char)fieldValue)}'"; + case TlcModule.DataKind.Component: + var type = fieldValue.GetType(); + ModuleCatalog.ComponentInfo componentInfo; + if (!catalog.TryFindComponent(fieldType, type, out componentInfo)) + return null; + object defaultComponent = null; + try + { + defaultComponent = Activator.CreateInstance(componentInfo.ArgumentType); + } + catch (MissingMethodException) + { + // No parameterless constructor, ignore. + } + var propertyBag = new List(); + if (defaultComponent != null) + { + foreach (var fieldInfo in componentInfo.ArgumentType.GetFields()) + { + var inputAttr = fieldInfo.GetCustomAttributes(typeof(ArgumentAttribute), false).FirstOrDefault() as ArgumentAttribute; + if (inputAttr == null || inputAttr.Visibility == ArgumentAttribute.VisibilityType.CmdLineOnly) + continue; + if (fieldInfo.FieldType == typeof(JArray) || fieldInfo.FieldType == typeof(JObject)) + continue; + + var propertyValue = GetValue(catalog, fieldInfo.FieldType, fieldInfo.GetValue(fieldValue), generatedClasses, rootNameSpace); + var defaultPropertyValue = GetValue(catalog, fieldInfo.FieldType, fieldInfo.GetValue(defaultComponent), generatedClasses, rootNameSpace); + if (propertyValue != defaultPropertyValue) + propertyBag.Add($"{Capitalize(inputAttr.Name ?? fieldInfo.Name)} = {propertyValue}"); + } + } + var properties = propertyBag.Count > 0 ? $" {{ {string.Join(", ", propertyBag)} }}" : ""; + return $"new {GetComponentName(componentInfo)}(){properties}"; + case TlcModule.DataKind.Unknown: + return $"new {generatedClasses.GetApiName(fieldType, rootNameSpace)}()"; + default: + return fieldValue.ToString(); + } + } + + private static string Quote(string src) + { + var dst = src.Replace("\\", @"\\").Replace("\"", "\\\"").Replace("\n", @"\n").Replace("\r", @"\r"); + return "\"" + dst + "\""; + } + + public static string GetComponentName(ModuleCatalog.ComponentInfo component) + { + return $"{Capitalize(component.Name)}{component.Kind}"; + } + + public static void GenerateSummary(IndentingTextWriter writer, string summary) + { + if (string.IsNullOrEmpty(summary)) + return; + writer.WriteLine("/// "); + foreach (var line in summary.Split(new[] { Environment.NewLine }, StringSplitOptions.RemoveEmptyEntries)) + writer.WriteLine($"/// {line}"); + writer.WriteLine("/// "); + } + + public static void GenerateHeader(IndentingTextWriter writer) + { + writer.WriteLine("//------------------------------------------------------------------------------"); + writer.WriteLine("// "); + writer.WriteLine("// This code was generated by a tool."); + writer.WriteLine("//"); + writer.WriteLine("// Changes to this file may cause incorrect behavior and will be lost if"); + writer.WriteLine("// the code is regenerated."); + writer.WriteLine("// "); + writer.WriteLine("//------------------------------------------------------------------------------"); + writer.WriteLine("#pragma warning disable"); + writer.WriteLine("using System.Collections.Generic;"); + writer.WriteLine("using Microsoft.ML.Runtime;"); + writer.WriteLine("using Microsoft.ML.Runtime.Data;"); + writer.WriteLine("using Microsoft.ML.Runtime.EntryPoints;"); + writer.WriteLine("using Newtonsoft.Json;"); + writer.WriteLine("using System;"); + writer.WriteLine("using System.Linq;"); + writer.WriteLine("using Microsoft.ML.Runtime.CommandLine;"); + writer.WriteLine(); + writer.WriteLine("namespace Microsoft.ML"); + writer.WriteLine("{"); + writer.Indent(); + writer.WriteLine("namespace Runtime"); + writer.WriteLine("{"); + writer.Indent(); + writer.WriteLine("public sealed partial class Experiment"); + writer.WriteLine("{"); + writer.Indent(); + } + + public static void GenerateFooter(IndentingTextWriter writer) + { + writer.Outdent(); + writer.WriteLine("}"); + } + + public static void GenerateMethod(IndentingTextWriter writer, string className, string defaultNamespace) + { + var inputOuputClassName = defaultNamespace + className; + writer.WriteLine($"public {inputOuputClassName}.Output Add({inputOuputClassName} input)"); + writer.WriteLine("{"); + writer.Indent(); + writer.WriteLine($"var output = new {inputOuputClassName}.Output();"); + writer.WriteLine("Add(input, output);"); + writer.WriteLine("return output;"); + writer.Outdent(); + writer.WriteLine("}"); + writer.WriteLine(); + writer.WriteLine($"public void Add({inputOuputClassName} input, {inputOuputClassName}.Output output)"); + writer.WriteLine("{"); + writer.Indent(); + writer.WriteLine($"_jsonNodes.Add(Serialize(\"{className}\", input, output));"); + writer.Outdent(); + writer.WriteLine("}"); + writer.WriteLine(); + } + + public static void GenerateLoaderAddInputMethod(IndentingTextWriter writer, string className) + { + //Constructor. + writer.WriteLine("[JsonIgnore]"); + writer.WriteLine("private string _inputFilePath = null;"); + writer.WriteLine($"public {className}(string filePath)"); + writer.WriteLine("{"); + writer.Indent(); + writer.WriteLine("_inputFilePath = filePath;"); + writer.Outdent(); + writer.WriteLine("}"); + writer.WriteLine(""); + + //SetInput. + writer.WriteLine($"public void SetInput(IHostEnvironment env, Experiment experiment)"); + writer.WriteLine("{"); + writer.Indent(); + writer.WriteLine("IFileHandle inputFile = new SimpleFileHandle(env, _inputFilePath, false, false);"); + writer.WriteLine("experiment.SetInput(InputFile, inputFile);"); + writer.Outdent(); + writer.WriteLine("}"); + writer.WriteLine(""); + + //GetInputData + writer.WriteLine("public Var GetInputData() => null;"); + writer.WriteLine(""); + + //Apply. + writer.WriteLine($"public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)"); + writer.WriteLine("{"); + writer.Indent(); + writer.WriteLine("Contracts.Assert(previousStep == null);"); + writer.WriteLine(""); + writer.WriteLine($"return new {className}PipelineStep(experiment.Add(this));"); + writer.Outdent(); + writer.WriteLine("}"); + writer.WriteLine(""); + + //Pipelinestep class. + writer.WriteLine($"private class {className}PipelineStep : ILearningPipelineDataStep"); + writer.WriteLine("{"); + writer.Indent(); + writer.WriteLine($"public {className}PipelineStep (Output output)"); + writer.WriteLine("{"); + writer.Indent(); + writer.WriteLine("Data = output.Data;"); + writer.WriteLine("Model = null;"); + writer.Outdent(); + writer.WriteLine("}"); + writer.WriteLine(); + writer.WriteLine("public Var Data { get; }"); + writer.WriteLine("public Var Model { get; }"); + writer.Outdent(); + writer.WriteLine("}"); + } + } +} diff --git a/src/Microsoft.ML/Runtime/Internal/Tools/GeneratedClasses.cs b/src/Microsoft.ML/Runtime/Internal/Tools/GeneratedClasses.cs new file mode 100644 index 0000000000..fe8adf35fc --- /dev/null +++ b/src/Microsoft.ML/Runtime/Internal/Tools/GeneratedClasses.cs @@ -0,0 +1,102 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Microsoft.ML.Runtime.Internal.Tools +{ + internal sealed class GeneratedClasses + { + private sealed class ApiClass + { + public string OriginalName { get; set; } + public string NewName { get; set; } + public bool Generated { get; set; } + } + + private readonly Dictionary _typesSymbolTable; + + public GeneratedClasses() + { + _typesSymbolTable = new Dictionary(); + } + + public string GetApiName(Type type, string rootNamespace) + { + string apiName = ""; + if (!_typesSymbolTable.TryGetValue(type.FullName, out ApiClass apiClass)) + apiName = GenerateIntenalName(type, rootNamespace); + else + apiName = apiClass.NewName; + + if (!string.IsNullOrEmpty(rootNamespace)&& apiName.StartsWith(rootNamespace)) + return apiName.Substring(rootNamespace.Length + 1); + else return apiName; + } + + private string GenerateIntenalName(Type type, string currentNamespace) + { + var fullTypeName = type.FullName; + string name = currentNamespace != "" ? currentNamespace + '.' : ""; + + int bracketIndex = fullTypeName.IndexOf('['); + Type[] genericTypes = null; + if (type.IsGenericType) + genericTypes = type.GetGenericArguments(); + if (bracketIndex > 0) + { + Contracts.AssertValue(genericTypes); + fullTypeName = fullTypeName.Substring(0, bracketIndex); + } + + // When the type is nested, the names of the outer types are concatenated with a '+'. + var nestedNames = fullTypeName.Split('+'); + var baseName = nestedNames[0]; + + // We currently only handle generic types in the outer most class, support for generic inner classes + // can be added if needed. + int backTickIndex = baseName.LastIndexOf('`'); + int dotIndex = baseName.LastIndexOf('.'); + Contracts.Assert(dotIndex >= 0); + if (backTickIndex < 0) + name += baseName.Substring(dotIndex + 1); + else + { + name += baseName.Substring(dotIndex + 1, backTickIndex - dotIndex - 1); + Contracts.AssertValue(genericTypes); + if (genericTypes != null) + { + foreach (var genType in genericTypes) + { + var splitNames = genType.FullName.Split('+'); + if (splitNames[0].LastIndexOf('.') >= 0) + splitNames[0] = splitNames[0].Substring(splitNames[0].LastIndexOf('.') + 1); + name += string.Join("", splitNames); + } + } + } + + for (int i = 1; i < nestedNames.Length; i++) + name += nestedNames[i]; + + Contracts.Assert(_typesSymbolTable.Values.All(apiclass => string.Compare(apiclass.NewName, name) != 0)); + _typesSymbolTable[type.FullName] = new ApiClass { OriginalName = type.FullName, Generated = false, NewName = name }; + return name; + } + + internal bool IsGenerated(string fullName) + { + if (!_typesSymbolTable.ContainsKey(fullName)) + return false; + return _typesSymbolTable[fullName].Generated; + } + + internal void MarkAsGenerated(string fullName) + { + _typesSymbolTable[fullName].Generated = true; + } + } +} diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs index b42dee2d52..aa4abe1752 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs @@ -303,14 +303,14 @@ public void TestCrossValidationMacro() { Name = "Label", Source = new [] { new TextLoaderRange(11) }, - Type = DataKind.Num + Type = ML.Data.DataKind.Num }, new TextLoaderColumn() { Name = "Features", Source = new [] { new TextLoaderRange(0,10) }, - Type = DataKind.Num + Type = ML.Data.DataKind.Num } } } @@ -666,7 +666,7 @@ public void TestCrossValidationMacroWithNonDefaultNames() importInput.Arguments.Column = new TextLoaderColumn[] { new TextLoaderColumn { Name = "Label", Source = new[] { new TextLoaderRange(0) } }, - new TextLoaderColumn { Name = "Workclass", Source = new[] { new TextLoaderRange(1) }, Type = DataKind.Text }, + new TextLoaderColumn { Name = "Workclass", Source = new[] { new TextLoaderRange(1) }, Type = ML.Data.DataKind.Text }, new TextLoaderColumn { Name = "Features", Source = new[] { new TextLoaderRange(9, 14) } } }; var importOutput = experiment.Add(importInput); diff --git a/test/Microsoft.ML.TestFramework/ModelHelper.cs b/test/Microsoft.ML.TestFramework/ModelHelper.cs index edf4408bcb..42c684e51a 100644 --- a/test/Microsoft.ML.TestFramework/ModelHelper.cs +++ b/test/Microsoft.ML.TestFramework/ModelHelper.cs @@ -70,147 +70,147 @@ private static ITransformModel CreateKcHousePricePredictorModel(string dataPath) { Name = "Id", Source = new [] { new TextLoaderRange(0) }, - Type = Runtime.Data.DataKind.Text + Type = Data.DataKind.Text }, new TextLoaderColumn() { Name = "Date", Source = new [] { new TextLoaderRange(1) }, - Type = Runtime.Data.DataKind.Text + Type = Data.DataKind.Text }, new TextLoaderColumn() { Name = "Label", Source = new [] { new TextLoaderRange(2) }, - Type = Runtime.Data.DataKind.Num + Type = Data.DataKind.Num }, new TextLoaderColumn() { Name = "Bedrooms", Source = new [] { new TextLoaderRange(3) }, - Type = Runtime.Data.DataKind.Num + Type = Data.DataKind.Num }, new TextLoaderColumn() { Name = "Bathrooms", Source = new [] { new TextLoaderRange(4) }, - Type = Runtime.Data.DataKind.Num + Type = Data.DataKind.Num }, new TextLoaderColumn() { Name = "SqftLiving", Source = new [] { new TextLoaderRange(5) }, - Type = Runtime.Data.DataKind.Num + Type = Data.DataKind.Num }, new TextLoaderColumn() { Name = "SqftLot", Source = new [] { new TextLoaderRange(6) }, - Type = Runtime.Data.DataKind.Num + Type = Data.DataKind.Num }, new TextLoaderColumn() { Name = "Floors", Source = new [] { new TextLoaderRange(7) }, - Type = Runtime.Data.DataKind.Num + Type = Data.DataKind.Num }, new TextLoaderColumn() { Name = "Waterfront", Source = new [] { new TextLoaderRange(8) }, - Type = Runtime.Data.DataKind.Num + Type = Data.DataKind.Num }, new TextLoaderColumn() { Name = "View", Source = new [] { new TextLoaderRange(9) }, - Type = Runtime.Data.DataKind.Num + Type = Data.DataKind.Num }, new TextLoaderColumn() { Name = "Condition", Source = new [] { new TextLoaderRange(10) }, - Type = Runtime.Data.DataKind.Num + Type = Data.DataKind.Num }, new TextLoaderColumn() { Name = "Grade", Source = new [] { new TextLoaderRange(11) }, - Type = Runtime.Data.DataKind.Num + Type = Data.DataKind.Num }, new TextLoaderColumn() { Name = "SqftAbove", Source = new [] { new TextLoaderRange(12) }, - Type = Runtime.Data.DataKind.Num + Type = Data.DataKind.Num }, new TextLoaderColumn() { Name = "SqftBasement", Source = new [] { new TextLoaderRange(13) }, - Type = Runtime.Data.DataKind.Num + Type = Data.DataKind.Num }, new TextLoaderColumn() { Name = "YearBuilt", Source = new [] { new TextLoaderRange(14) }, - Type = Runtime.Data.DataKind.Num + Type = Data.DataKind.Num }, new TextLoaderColumn() { Name = "YearRenovated", Source = new [] { new TextLoaderRange(15) }, - Type = Runtime.Data.DataKind.Num + Type = Data.DataKind.Num }, new TextLoaderColumn() { Name = "Zipcode", Source = new [] { new TextLoaderRange(16) }, - Type = Runtime.Data.DataKind.Num + Type = Data.DataKind.Num }, new TextLoaderColumn() { Name = "Lat", Source = new [] { new TextLoaderRange(17) }, - Type = Runtime.Data.DataKind.Num + Type = Data.DataKind.Num }, new TextLoaderColumn() { Name = "Long", Source = new [] { new TextLoaderRange(18) }, - Type = Runtime.Data.DataKind.Num + Type = Data.DataKind.Num }, new TextLoaderColumn() { Name = "SqftLiving15", Source = new [] { new TextLoaderRange(19) }, - Type = Runtime.Data.DataKind.Num + Type = Data.DataKind.Num }, new TextLoaderColumn() { Name = "SqftLot15", Source = new [] { new TextLoaderRange(20) }, - Type = Runtime.Data.DataKind.Num + Type = Data.DataKind.Num }, } } diff --git a/test/Microsoft.ML.Tests/CSharpCodeGen.cs b/test/Microsoft.ML.Tests/CSharpCodeGen.cs index c647110702..678edac461 100644 --- a/test/Microsoft.ML.Tests/CSharpCodeGen.cs +++ b/test/Microsoft.ML.Tests/CSharpCodeGen.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.TestFramework; using System.IO; using Xunit; @@ -9,17 +10,39 @@ namespace Microsoft.ML.Tests { - public class CSharpCodeGen : BaseTestClass + public class CSharpCodeGen : BaseTestBaseline { public CSharpCodeGen(ITestOutputHelper output) : base(output) { } - [Fact(Skip = "Temporary solution(Windows ONLY) to regenerate codegenerated CSharpAPI.cs")] - public void GenerateCSharpAPI() + [Fact(Skip = "Execute this test if you want to regenerate CSharpApi file")] + public void RegenerateCSharpApi() { - var cSharpAPIPath = Path.Combine(RootDir, @"src\\Microsoft.ML\\CSharpApi.cs"); - Runtime.Tools.Maml.Main(new[] { $"? generator=cs{{csFilename={cSharpAPIPath}}}" }); + var basePath = GetDataPath("../../src/Microsoft.ML/CSharpApi.cs"); + Runtime.Tools.Maml.Main(new[] { $"? generator=cs{{csFilename={basePath}}}" }); + } + + [Fact] + public void TestGeneratedCSharpAPI() + { + var dataPath = GetOutputPath("Api.cs"); + Runtime.Tools.Maml.Main(new[] { $"? generator=cs{{csFilename={dataPath}}}" }); + + var basePath = GetDataPath("../../src/Microsoft.ML/CSharpApi.cs"); + using (StreamReader baseline = OpenReader(basePath)) + using (StreamReader result = OpenReader(dataPath)) + { + for (; ; ) + { + string line1 = baseline.ReadLine(); + string line2 = result.ReadLine(); + + if (line1 == null && line2 == null) + break; + Assert.Equal(line1, line2); + } + } } } } diff --git a/test/Microsoft.ML.Tests/OnnxTests.cs b/test/Microsoft.ML.Tests/OnnxTests.cs index 6910aba70b..477a9c6fa6 100644 --- a/test/Microsoft.ML.Tests/OnnxTests.cs +++ b/test/Microsoft.ML.Tests/OnnxTests.cs @@ -52,14 +52,14 @@ public void BinaryClassificationSaveModelToOnnxTest() { Name = "Label", Source = new [] { new TextLoaderRange(0) }, - Type = DataKind.Num + Type = Data.DataKind.Num }, new TextLoaderColumn() { Name = "Features", Source = new [] { new TextLoaderRange(1, 9) }, - Type = DataKind.Num + Type = Data.DataKind.Num } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs index 7a31f17d96..1ebc2489ec 100644 --- a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs @@ -211,14 +211,14 @@ private LearningPipeline PreparePipeline() { Name = "Label", Source = new [] { new TextLoaderRange(0) }, - Type = Runtime.Data.DataKind.Num + Type = Data.DataKind.Num }, new TextLoaderColumn() { Name = "SentimentText", Source = new [] { new TextLoaderRange(1) }, - Type = Runtime.Data.DataKind.Text + Type = Data.DataKind.Text } } } @@ -265,14 +265,14 @@ private Data.TextLoader PrepareTextLoaderTestData() { Name = "Label", Source = new [] { new TextLoaderRange(0) }, - Type = Runtime.Data.DataKind.Num + Type = Data.DataKind.Num }, new TextLoaderColumn() { Name = "SentimentText", Source = new [] { new TextLoaderRange(1) }, - Type = Runtime.Data.DataKind.Text + Type = Data.DataKind.Text } } }