Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoML] Minor changes to generated project in CLI based on feedback #3371

Merged
merged 3 commits into from
Apr 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 42 additions & 13 deletions src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ internal class CodeGenerator : IProjectGenerator
private readonly Pipeline pipeline;
private readonly CodeGeneratorSettings settings;
private readonly ColumnInferenceResults columnInferenceResult;
private readonly HashSet<string> LightGBMTrainers = new HashSet<string>() { TrainerName.LightGbmBinary.ToString(), TrainerName.LightGbmMulti.ToString(), TrainerName.LightGbmRegression.ToString() };
private readonly HashSet<string> mklComponentsTrainers = new HashSet<string>() { TrainerName.OlsRegression.ToString(), TrainerName.SymbolicSgdLogisticRegressionBinary.ToString() };

internal CodeGenerator(Pipeline pipeline, ColumnInferenceResults columnInferenceResult, CodeGeneratorSettings settings)
{
Expand All @@ -29,25 +31,32 @@ internal CodeGenerator(Pipeline pipeline, ColumnInferenceResults columnInference

public void GenerateOutput()
{
// Get the extra nuget packages to be included in the generated project.
var trainerNodes = pipeline.Nodes.Where(t => t.NodeType == PipelineNodeType.Trainer);

bool includeLightGbmPackage = false;
bool includeMklComponentsPackage = false;
SetRequiredNugetPackages(trainerNodes, ref includeLightGbmPackage, ref includeMklComponentsPackage);

// Get Namespace
var namespaceValue = Utils.Normalize(settings.OutputName);
var labelType = columnInferenceResult.TextLoaderOptions.Columns.Where(t => t.Name == columnInferenceResult.ColumnInformation.LabelColumnName).First().DataKind;
Type labelTypeCsharp = Utils.GetCSharpType(labelType);

// Generate Model Project
var modelProjectContents = GenerateModelProjectContents(namespaceValue, labelTypeCsharp);
var modelProjectContents = GenerateModelProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage);

// Write files to disk.
var modelprojectDir = Path.Combine(settings.OutputBaseDir, $"{settings.OutputName}.Model");
var dataModelsDir = Path.Combine(modelprojectDir, "DataModels");
var modelProjectName = $"{settings.OutputName}.Model.csproj";

Utils.WriteOutputToFiles(modelProjectContents.ObservationCSFileContent, "Observation.cs", dataModelsDir);
Utils.WriteOutputToFiles(modelProjectContents.PredictionCSFileContent, "Prediction.cs", dataModelsDir);
Utils.WriteOutputToFiles(modelProjectContents.ObservationCSFileContent, "SampleObservation.cs", dataModelsDir);
Utils.WriteOutputToFiles(modelProjectContents.PredictionCSFileContent, "SamplePrediction.cs", dataModelsDir);
Utils.WriteOutputToFiles(modelProjectContents.ModelProjectFileContent, modelProjectName, modelprojectDir);

// Generate ConsoleApp Project
var consoleAppProjectContents = GenerateConsoleAppProjectContents(namespaceValue, labelTypeCsharp);
var consoleAppProjectContents = GenerateConsoleAppProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage);

// Write files to disk.
var consoleAppProjectDir = Path.Combine(settings.OutputBaseDir, $"{settings.OutputName}.ConsoleApp");
Expand All @@ -65,12 +74,33 @@ public void GenerateOutput()
Utils.AddProjectsToSolution(modelprojectDir, modelProjectName, consoleAppProjectDir, consoleAppProjectName, solutionPath);
}

internal (string ConsoleAppProgramCSFileContent, string ConsoleAppProjectFileContent, string modelBuilderCSFileContent) GenerateConsoleAppProjectContents(string namespaceValue, Type labelTypeCsharp)
private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, ref bool includeLightGbmPackage, ref bool includeMklComponentsPackage)
{
foreach (var node in trainerNodes)
{
PipelineNode currentNode = node;
if (currentNode.Name == TrainerName.Ova.ToString())
{
currentNode = (PipelineNode)currentNode.Properties["BinaryTrainer"];
}

if (LightGBMTrainers.Contains(currentNode.Name))
{
includeLightGbmPackage = true;
}
else if (mklComponentsTrainers.Contains(currentNode.Name))
{
includeMklComponentsPackage = true;
}
}
}

internal (string ConsoleAppProgramCSFileContent, string ConsoleAppProjectFileContent, string modelBuilderCSFileContent) GenerateConsoleAppProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage)
{
var predictProgramCSFileContent = GeneratePredictProgramCSFileContent(namespaceValue);
predictProgramCSFileContent = Utils.FormatCode(predictProgramCSFileContent);

var predictProjectFileContent = GeneratPredictProjectFileContent(namespaceValue, true, true);
var predictProjectFileContent = GeneratPredictProjectFileContent(namespaceValue, includeLightGbmPackage, includeMklComponentsPackage);

var transformsAndTrainers = GenerateTransformsAndTrainers();
var modelBuilderCSFileContent = GenerateModelBuilderCSFileContent(transformsAndTrainers.Usings, transformsAndTrainers.TrainerMethod, transformsAndTrainers.PreTrainerTransforms, transformsAndTrainers.PostTrainerTransforms, namespaceValue, pipeline.CacheBeforeTrainer, labelTypeCsharp.Name);
Expand All @@ -79,14 +109,14 @@ public void GenerateOutput()
return (predictProgramCSFileContent, predictProjectFileContent, modelBuilderCSFileContent);
}

internal (string ObservationCSFileContent, string PredictionCSFileContent, string ModelProjectFileContent) GenerateModelProjectContents(string namespaceValue, Type labelTypeCsharp)
internal (string ObservationCSFileContent, string PredictionCSFileContent, string ModelProjectFileContent) GenerateModelProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage)
{
var classLabels = this.GenerateClassLabels();
var observationCSFileContent = GenerateObservationCSFileContent(namespaceValue, classLabels);
observationCSFileContent = Utils.FormatCode(observationCSFileContent);
var predictionCSFileContent = GeneratePredictionCSFileContent(labelTypeCsharp.Name, namespaceValue);
predictionCSFileContent = Utils.FormatCode(predictionCSFileContent);
var modelProjectFileContent = GenerateModelProjectFileContent();
var modelProjectFileContent = GenerateModelProjectFileContent(includeLightGbmPackage, includeMklComponentsPackage);
return (observationCSFileContent, predictionCSFileContent, modelProjectFileContent);
}

Expand Down Expand Up @@ -218,9 +248,9 @@ internal IList<string> GenerateClassLabels()
}

#region Model project
private static string GenerateModelProjectFileContent()
private static string GenerateModelProjectFileContent(bool includeLightGbmPackage, bool includeMklComponentsPackage)
{
ModelProject modelProject = new ModelProject();
ModelProject modelProject = new ModelProject() { IncludeLightGBMPackage = includeLightGbmPackage, IncludeMklComponentsPackage = includeMklComponentsPackage };
return modelProject.TransformText();
}

Expand All @@ -238,9 +268,9 @@ private string GenerateObservationCSFileContent(string namespaceValue, IList<str
#endregion

#region Predict Project
private static string GeneratPredictProjectFileContent(string namespaceValue, bool includeMklComponentsPackage, bool includeLightGBMPackage)
private static string GeneratPredictProjectFileContent(string namespaceValue, bool includeLightGbmPackage, bool includeMklComponentsPackage)
{
var predictProjectFileContent = new PredictProject() { Namespace = namespaceValue, IncludeMklComponentsPackage = includeMklComponentsPackage, IncludeLightGBMPackage = includeLightGBMPackage };
var predictProjectFileContent = new PredictProject() { Namespace = namespaceValue, IncludeMklComponentsPackage = includeMklComponentsPackage, IncludeLightGBMPackage = includeLightGbmPackage };
return predictProjectFileContent.TransformText();
}

Expand Down Expand Up @@ -290,6 +320,5 @@ private string GenerateModelBuilderCSFileContent(string usings,
return modelBuilder.TransformText();
}
#endregion

}
}
4 changes: 1 addition & 3 deletions src/mlnet/Templates/Console/ModelBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,7 @@ public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDat
{
// Save/persist the trained model to a .ZIP file
Console.WriteLine($""=============== Saving the model ==============="");
using (var fs = new FileStream(GetAbsolutePath(modelRelativePath), FileMode.Create, FileAccess.Write, FileShare.Write))
mlContext.Model.Save(mlModel, modelInputSchema, fs);

mlContext.Model.Save(mlModel, modelInputSchema, GetAbsolutePath(modelRelativePath));
Console.WriteLine(""The model is saved to {0}"", GetAbsolutePath(modelRelativePath));
}

Expand Down
4 changes: 1 addition & 3 deletions src/mlnet/Templates/Console/ModelBuilder.tt
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,7 @@ else{#>
{
// Save/persist the trained model to a .ZIP file
Console.WriteLine($"=============== Saving the model ===============");
using (var fs = new FileStream(GetAbsolutePath(modelRelativePath), FileMode.Create, FileAccess.Write, FileShare.Write))
mlContext.Model.Save(mlModel, modelInputSchema, fs);

mlContext.Model.Save(mlModel, modelInputSchema, GetAbsolutePath(modelRelativePath));
Console.WriteLine("The model is saved to {0}", GetAbsolutePath(modelRelativePath));
}

Expand Down
67 changes: 44 additions & 23 deletions src/mlnet/Templates/Console/ModelProject.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace Microsoft.ML.CLI.Templates.Console
/// Class to produce the template output
/// </summary>

#line 1 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\ModelProject.tt"
#line 1 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")]
public partial class ModelProject : ModelProjectBase
{
Expand All @@ -28,30 +28,51 @@ public partial class ModelProject : ModelProjectBase
/// </summary>
public virtual string TransformText()
{
this.Write(@"<Project Sdk=""Microsoft.NET.Sdk"">

<PropertyGroup>
<TargetFramework>netcoreapp2.1</TargetFramework>
</PropertyGroup>
<PropertyGroup>
<RestoreSources>
https://api.nuget.org/v3/index.json;
</RestoreSources>
</PropertyGroup>
<ItemGroup>
<PackageReference Include=""Microsoft.ML"" Version=""1.0.0-preview"" />
</ItemGroup>

<ItemGroup>
<None Update=""MLModel.zip"">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
</ItemGroup>

</Project>
");
this.Write("<Project Sdk=\"Microsoft.NET.Sdk\">\r\n\r\n <PropertyGroup>\r\n <TargetFramework>netc" +
"oreapp2.1</TargetFramework>\r\n </PropertyGroup>\r\n <ItemGroup>\r\n <PackageRefe" +
"rence Include=\"Microsoft.ML\" Version=\"1.0.0-preview\" />\r\n");

#line 13 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
if(IncludeLightGBMPackage){

#line default
#line hidden
this.Write(" <PackageReference Include=\"Microsoft.ML.LightGBM\" Version=\"1.0.0-preview\" />\r" +
"\n");

#line 15 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
}

#line default
#line hidden

#line 16 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
if(IncludeMklComponentsPackage){

#line default
#line hidden
this.Write(" <PackageReference Include=\"Microsoft.ML.Mkl.Components\" Version=\"1.0.0-previe" +
"w\" />\r\n");

#line 18 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
}

#line default
#line hidden
this.Write(" </ItemGroup>\r\n\r\n <ItemGroup>\r\n <None Update=\"MLModel.zip\">\r\n <CopyToOu" +
"tputDirectory>PreserveNewest</CopyToOutputDirectory>\r\n </None>\r\n </ItemGroup" +
">\r\n \r\n</Project>\r\n");
return this.GenerationEnvironment.ToString();
}

#line 28 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"

public bool IncludeLightGBMPackage {get;set;}
public bool IncludeMklComponentsPackage {get;set;}


#line default
#line hidden
}

#line default
Expand Down
15 changes: 10 additions & 5 deletions src/mlnet/Templates/Console/ModelProject.tt
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@

<PropertyGroup>
<TargetFramework>netcoreapp2.1</TargetFramework>
</PropertyGroup>
<PropertyGroup>
<RestoreSources>
https://api.nuget.org/v3/index.json;
</RestoreSources>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.ML" Version="1.0.0-preview" />
<# if(IncludeLightGBMPackage){ #>
<PackageReference Include="Microsoft.ML.LightGBM" Version="1.0.0-preview" />
<#}#>
<# if(IncludeMklComponentsPackage){ #>
<PackageReference Include="Microsoft.ML.Mkl.Components" Version="1.0.0-preview" />
<#}#>
</ItemGroup>

<ItemGroup>
Expand All @@ -24,3 +25,7 @@
</ItemGroup>

</Project>
<#+
public bool IncludeLightGBMPackage {get;set;}
public bool IncludeMklComponentsPackage {get;set;}
#>
Loading