diff --git a/src/Microsoft.ML/LearningPipeline.cs b/src/Microsoft.ML/LearningPipeline.cs index 02a13b702a..0637c5e65e 100644 --- a/src/Microsoft.ML/LearningPipeline.cs +++ b/src/Microsoft.ML/LearningPipeline.cs @@ -66,7 +66,7 @@ public LearningPipeline() /// /// Specify seed for random generator /// Specify concurrency factor (default value - autoselection) - internal LearningPipeline(int? seed=null, int conc=0) + internal LearningPipeline(int? seed = null, int conc = 0) { _seed = seed; _conc = conc; @@ -109,6 +109,16 @@ internal LearningPipeline(int? seed=null, int conc=0) /// Any ML component (data loader, transform or trainer) defined as . public void Add(ILearningPipelineItem item) => Items.Add(item); + /// + /// Add a data loader, transform or trainer into the pipeline. + /// + /// Any ML component (data loader, transform or trainer) defined as . + /// Pipeline with added item + public LearningPipeline Append(ILearningPipelineItem item) + { + Add(item); + return this; + } /// /// Remove all the loaders/transforms/trainers from the pipeline. /// @@ -152,7 +162,7 @@ public PredictionModel Train() where TInput : class where TOutput : class, new() { - using (var environment = new TlcEnvironment(seed:_seed, conc:_conc)) + using (var environment = new TlcEnvironment(seed: _seed, conc: _conc)) { Experiment experiment = environment.CreateExperiment(); ILearningPipelineStep step = null; diff --git a/test/Microsoft.ML.Tests/LearningPipelineTests.cs b/test/Microsoft.ML.Tests/LearningPipelineTests.cs index 63a8db17ee..165b8d8fd2 100644 --- a/test/Microsoft.ML.Tests/LearningPipelineTests.cs +++ b/test/Microsoft.ML.Tests/LearningPipelineTests.cs @@ -167,5 +167,22 @@ public void NullableBooleanLabelPipeline() pipeline.Add(new FastForestBinaryClassifier()); var model = pipeline.Train(); } + + [Fact] + public void AppendPipeline() + { + var pipeline = new LearningPipeline(); + pipeline.Append(new CategoricalOneHotVectorizer("String1", "String2")) + .Append(new ColumnConcatenator(outputColumn: "Features", "String1", "String2", "Number1", "Number2")) + .Append(new StochasticDualCoordinateAscentRegressor()); + Assert.NotNull(pipeline); + Assert.Equal(3, pipeline.Count); + + pipeline.Remove(pipeline.ElementAt(2)); + Assert.Equal(2, pipeline.Count); + + pipeline.Append(new StochasticDualCoordinateAscentRegressor()); + Assert.Equal(3, pipeline.Count); + } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs index 5dcbf3a588..6612dfea69 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs @@ -18,7 +18,7 @@ public void TrainAndPredictIrisModelTest() { string dataPath = GetDataPath("iris.txt"); - var pipeline = new LearningPipeline(); + var pipeline = new LearningPipeline(seed:1, conc:1); pipeline.Add(new TextLoader(dataPath).CreateFrom(useHeader: false)); pipeline.Add(new ColumnConcatenator(outputColumn: "Features",