Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoML] Rename classes in generated project. #3634

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ public void GenerateOutput()
var dataModelsDir = Path.Combine(modelprojectDir, "DataModels");
var modelProjectName = $"{settings.OutputName}.Model.csproj";

Utils.WriteOutputToFiles(modelProjectContents.ObservationCSFileContent, "SampleObservation.cs", dataModelsDir);
Utils.WriteOutputToFiles(modelProjectContents.PredictionCSFileContent, "SamplePrediction.cs", dataModelsDir);
Utils.WriteOutputToFiles(modelProjectContents.ModelInputCSFileContent, "ModelInput.cs", dataModelsDir);
Utils.WriteOutputToFiles(modelProjectContents.ModelOutputCSFileContent, "ModelOutput.cs", dataModelsDir);
Utils.WriteOutputToFiles(modelProjectContents.ModelProjectFileContent, modelProjectName, modelprojectDir);

// Generate ConsoleApp Project
Expand Down Expand Up @@ -116,15 +116,15 @@ private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, re
return (predictProgramCSFileContent, predictProjectFileContent, modelBuilderCSFileContent);
}

internal (string ObservationCSFileContent, string PredictionCSFileContent, string ModelProjectFileContent) GenerateModelProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage, bool includeFastTreePackage)
internal (string ModelInputCSFileContent, string ModelOutputCSFileContent, string ModelProjectFileContent) GenerateModelProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage, bool includeFastTreePackage)
{
var classLabels = this.GenerateClassLabels();
var observationCSFileContent = GenerateObservationCSFileContent(namespaceValue, classLabels);
observationCSFileContent = Utils.FormatCode(observationCSFileContent);
var predictionCSFileContent = GeneratePredictionCSFileContent(labelTypeCsharp.Name, namespaceValue);
predictionCSFileContent = Utils.FormatCode(predictionCSFileContent);
var modelInputCSFileContent = GenerateModelInputCSFileContent(namespaceValue, classLabels);
modelInputCSFileContent = Utils.FormatCode(modelInputCSFileContent);
var modelOutputCSFileContent = GenerateModelOutputCSFileContent(labelTypeCsharp.Name, namespaceValue);
modelOutputCSFileContent = Utils.FormatCode(modelOutputCSFileContent);
var modelProjectFileContent = GenerateModelProjectFileContent(includeLightGbmPackage, includeMklComponentsPackage, includeFastTreePackage);
return (observationCSFileContent, predictionCSFileContent, modelProjectFileContent);
return (modelInputCSFileContent, modelOutputCSFileContent, modelProjectFileContent);
}

internal (string Usings, string TrainerMethod, List<string> PreTrainerTransforms, List<string> PostTrainerTransforms) GenerateTransformsAndTrainers()
Expand Down Expand Up @@ -261,16 +261,16 @@ private static string GenerateModelProjectFileContent(bool includeLightGbmPackag
return modelProject.TransformText();
}

private string GeneratePredictionCSFileContent(string predictionLabelType, string namespaceValue)
private string GenerateModelOutputCSFileContent(string predictionLabelType, string namespaceValue)
{
PredictionClass predictionClass = new PredictionClass() { TaskType = settings.MlTask.ToString(), PredictionLabelType = predictionLabelType, Namespace = namespaceValue };
return predictionClass.TransformText();
ModelOutputClass modelOutputClass = new ModelOutputClass() { TaskType = settings.MlTask.ToString(), PredictionLabelType = predictionLabelType, Namespace = namespaceValue };
return modelOutputClass.TransformText();
}

private string GenerateObservationCSFileContent(string namespaceValue, IList<string> classLabels)
private string GenerateModelInputCSFileContent(string namespaceValue, IList<string> classLabels)
{
ObservationClass observationClass = new ObservationClass() { Namespace = namespaceValue, ClassLabels = classLabels };
return observationClass.TransformText();
ModelInputClass modelInputClass = new ModelInputClass() { Namespace = namespaceValue, ClassLabels = classLabels };
return modelInputClass.TransformText();
}
#endregion

Expand Down
8 changes: 4 additions & 4 deletions src/mlnet/Templates/Console/ModelBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public virtual string TransformText()
public static void CreateModel()
{
// Load Data
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<ModelInput>(
path: TRAIN_DATA_FILEPATH,
hasHeader : ");
this.Write(this.ToStringHelper.ToStringWithCulture(HasHeader.ToString().ToLowerInvariant()));
Expand All @@ -77,9 +77,9 @@ public static void CreateModel()
this.Write(this.ToStringHelper.ToStringWithCulture(AllowSparse.ToString().ToLowerInvariant()));
this.Write(");\r\n\r\n");
if(!string.IsNullOrEmpty(TestPath)){
this.Write(" IDataView testDataView = mlContext.Data.LoadFromTextFile<SampleObserv" +
"ation>(\r\n path: TEST_DATA_FILEPATH,\r\n" +
" hasHeader : ");
this.Write(" IDataView testDataView = mlContext.Data.LoadFromTextFile<ModelInput>(" +
"\r\n path: TEST_DATA_FILEPATH,\r\n " +
" hasHeader : ");
this.Write(this.ToStringHelper.ToStringWithCulture(HasHeader.ToString().ToLowerInvariant()));
this.Write(",\r\n separatorChar : \'");
this.Write(this.ToStringHelper.ToStringWithCulture(Regex.Escape(Separator.ToString())));
Expand Down
4 changes: 2 additions & 2 deletions src/mlnet/Templates/Console/ModelBuilder.tt
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ namespace <#= Namespace #>.ConsoleApp
public static void CreateModel()
{
// Load Data
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<ModelInput>(
path: TRAIN_DATA_FILEPATH,
hasHeader : <#= HasHeader.ToString().ToLowerInvariant() #>,
separatorChar : '<#= Regex.Escape(Separator.ToString()) #>',
allowQuoting : <#= AllowQuoting.ToString().ToLowerInvariant() #>,
allowSparse: <#= AllowSparse.ToString().ToLowerInvariant() #>);

<# if(!string.IsNullOrEmpty(TestPath)){ #>
IDataView testDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
IDataView testDataView = mlContext.Data.LoadFromTextFile<ModelInput>(
path: TEST_DATA_FILEPATH,
hasHeader : <#= HasHeader.ToString().ToLowerInvariant() #>,
separatorChar : '<#= Regex.Escape(Separator.ToString()) #>',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace Microsoft.ML.CLI.Templates.Console
/// Class to produce the template output
/// </summary>
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")]
public partial class ObservationClass : ObservationClassBase
public partial class ModelInputClass : ModelInputClassBase
{
/// <summary>
/// Create the template output
Expand All @@ -35,7 +35,7 @@ public virtual string TransformText()

namespace ");
this.Write(this.ToStringHelper.ToStringWithCulture(Namespace));
this.Write(".Model.DataModels\r\n{\r\n public class SampleObservation\r\n {\r\n");
this.Write(".Model.DataModels\r\n{\r\n public class ModelInput\r\n {\r\n");
foreach(var label in ClassLabels){
this.Write(" ");
this.Write(this.ToStringHelper.ToStringWithCulture(label));
Expand All @@ -54,7 +54,7 @@ namespace ");
/// Base class for this transformation
/// </summary>
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")]
public class ObservationClassBase
public class ModelInputClassBase
{
#region Fields
private global::System.Text.StringBuilder generationEnvironmentField;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using Microsoft.ML.Data;

namespace <#= Namespace #>.Model.DataModels
{
public class SampleObservation
public class ModelInput
{
<#foreach(var label in ClassLabels){#>
<#=label#>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace Microsoft.ML.CLI.Templates.Console
/// Class to produce the template output
/// </summary>
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")]
public partial class PredictionClass : PredictionClassBase
public partial class ModelOutputClass : ModelOutputClassBase
{
/// <summary>
/// Create the template output
Expand All @@ -36,7 +36,7 @@ public virtual string TransformText()

namespace ");
this.Write(this.ToStringHelper.ToStringWithCulture(Namespace));
this.Write(".Model.DataModels\r\n{\r\n public class SamplePrediction\r\n {\r\n");
this.Write(".Model.DataModels\r\n{\r\n public class ModelOutput\r\n {\r\n");
if("BinaryClassification".Equals(TaskType)){
this.Write(" // ColumnName attribute is used to change the column name from\r\n /" +
"/ its default value, which is the name of the field.\r\n [ColumnName(\"Predi" +
Expand Down Expand Up @@ -67,7 +67,7 @@ namespace ");
/// Base class for this transformation
/// </summary>
[global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")]
public class PredictionClassBase
public class ModelOutputClassBase
{
#region Fields
private global::System.Text.StringBuilder generationEnvironmentField;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ using Microsoft.ML.Data;

namespace <#= Namespace #>.Model.DataModels
{
public class SamplePrediction
public class ModelOutput
{
<#if("BinaryClassification".Equals(TaskType)){ #>
// ColumnName attribute is used to change the column name from
Expand Down
14 changes: 7 additions & 7 deletions src/mlnet/Templates/Console/PredictProgram.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@ static void Main(string[] args)
//ModelBuilder.CreateModel();

ITransformer mlModel = mlContext.Model.Load(GetAbsolutePath(MODEL_FILEPATH), out DataViewSchema inputSchema);
var predEngine = mlContext.Model.CreatePredictionEngine<SampleObservation, SamplePrediction>(mlModel);
var predEngine = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(mlModel);

// Create sample data to do a single prediction with it
SampleObservation sampleData = CreateSingleDataSample(mlContext, DATA_FILEPATH);
ModelInput sampleData = CreateSingleDataSample(mlContext, DATA_FILEPATH);

// Try a single prediction
SamplePrediction predictionResult = predEngine.Predict(sampleData);
ModelOutput predictionResult = predEngine.Predict(sampleData);

");
if("BinaryClassification".Equals(TaskType)){
Expand All @@ -92,10 +92,10 @@ static void Main(string[] args)

// Method to load single row of data to try a single prediction
// You can change this code and create your own sample data here (Hardcoded or from any source)
private static SampleObservation CreateSingleDataSample(MLContext mlContext, string dataFilePath)
private static ModelInput CreateSingleDataSample(MLContext mlContext, string dataFilePath)
{
// Read dataset to get a single row for trying a prediction
IDataView dataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
IDataView dataView = mlContext.Data.LoadFromTextFile<ModelInput>(
path: dataFilePath,
hasHeader : ");
this.Write(this.ToStringHelper.ToStringWithCulture(HasHeader.ToString().ToLowerInvariant()));
Expand All @@ -107,8 +107,8 @@ private static SampleObservation CreateSingleDataSample(MLContext mlContext, str
this.Write(this.ToStringHelper.ToStringWithCulture(AllowSparse.ToString().ToLowerInvariant()));
this.Write(@");

// Here (SampleObservation object) you could provide new test data, hardcoded or from the end-user application, instead of the row from the file.
SampleObservation sampleForPrediction = mlContext.Data.CreateEnumerable<SampleObservation>(dataView, false)
// Here (ModelInput object) you could provide new test data, hardcoded or from the end-user application, instead of the row from the file.
ModelInput sampleForPrediction = mlContext.Data.CreateEnumerable<ModelInput>(dataView, false)
.First();
return sampleForPrediction;
}
Expand Down
14 changes: 7 additions & 7 deletions src/mlnet/Templates/Console/PredictProgram.tt
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ namespace <#= Namespace #>.ConsoleApp
//ModelBuilder.CreateModel();

ITransformer mlModel = mlContext.Model.Load(GetAbsolutePath(MODEL_FILEPATH), out DataViewSchema inputSchema);
var predEngine = mlContext.Model.CreatePredictionEngine<SampleObservation, SamplePrediction>(mlModel);
var predEngine = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(mlModel);

// Create sample data to do a single prediction with it
SampleObservation sampleData = CreateSingleDataSample(mlContext, DATA_FILEPATH);
ModelInput sampleData = CreateSingleDataSample(mlContext, DATA_FILEPATH);

// Try a single prediction
SamplePrediction predictionResult = predEngine.Predict(sampleData);
ModelOutput predictionResult = predEngine.Predict(sampleData);

<#if("BinaryClassification".Equals(TaskType)){ #>
Console.WriteLine($"Single Prediction --> Actual value: {sampleData.<#= Utils.Normalize(LabelName) #>} | Predicted value: {predictionResult.Prediction}");
Expand All @@ -62,18 +62,18 @@ namespace <#= Namespace #>.ConsoleApp

// Method to load single row of data to try a single prediction
// You can change this code and create your own sample data here (Hardcoded or from any source)
private static SampleObservation CreateSingleDataSample(MLContext mlContext, string dataFilePath)
private static ModelInput CreateSingleDataSample(MLContext mlContext, string dataFilePath)
{
// Read dataset to get a single row for trying a prediction
IDataView dataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
IDataView dataView = mlContext.Data.LoadFromTextFile<ModelInput>(
path: dataFilePath,
hasHeader : <#= HasHeader.ToString().ToLowerInvariant() #>,
separatorChar : '<#= Regex.Escape(Separator.ToString()) #>',
allowQuoting : <#= AllowQuoting.ToString().ToLowerInvariant() #>,
allowSparse: <#= AllowSparse.ToString().ToLowerInvariant() #>);

// Here (SampleObservation object) you could provide new test data, hardcoded or from the end-user application, instead of the row from the file.
SampleObservation sampleForPrediction = mlContext.Data.CreateEnumerable<SampleObservation>(dataView, false)
// Here (ModelInput object) you could provide new test data, hardcoded or from the end-user application, instead of the row from the file.
ModelInput sampleForPrediction = mlContext.Data.CreateEnumerable<ModelInput>(dataView, false)
.First();
return sampleForPrediction;
}
Expand Down
16 changes: 8 additions & 8 deletions src/mlnet/mlnet.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@
<AutoGen>True</AutoGen>
<DependentUpon>ModelProject.tt</DependentUpon>
</Compile>
<Compile Update="Templates\Console\ObservationClass.cs">
<Compile Update="Templates\Console\ModelInputClass.cs">
<DesignTime>True</DesignTime>
<AutoGen>True</AutoGen>
<DependentUpon>ObservationClass.tt</DependentUpon>
<DependentUpon>ModelInputClass.tt</DependentUpon>
</Compile>
<Compile Update="Templates\Console\PredictionClass.cs">
<Compile Update="Templates\Console\ModelOutputClass.cs">
<DesignTime>True</DesignTime>
<AutoGen>True</AutoGen>
<DependentUpon>PredictionClass.tt</DependentUpon>
<DependentUpon>ModelOutputClass.tt</DependentUpon>
</Compile>
<Compile Update="Templates\Console\PredictProgram.cs">
<DesignTime>True</DesignTime>
Expand Down Expand Up @@ -90,13 +90,13 @@
<Generator>TextTemplatingFilePreprocessor</Generator>
<LastGenOutput>ModelProject.cs</LastGenOutput>
</None>
<None Update="Templates\Console\ObservationClass.tt">
<None Update="Templates\Console\ModelInputClass.tt">
<Generator>TextTemplatingFilePreprocessor</Generator>
<LastGenOutput>ObservationClass.cs</LastGenOutput>
<LastGenOutput>ModelInputClass.cs</LastGenOutput>
</None>
<None Update="Templates\Console\PredictionClass.tt">
<None Update="Templates\Console\ModelOutputClass.tt">
<Generator>TextTemplatingFilePreprocessor</Generator>
<LastGenOutput>PredictionClass.cs</LastGenOutput>
<LastGenOutput>ModelOutputClass.cs</LastGenOutput>
</None>
<None Update="Templates\Console\PredictProgram.tt">
<Generator>TextTemplatingFilePreprocessor</Generator>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ namespace TestNamespace.ConsoleApp
public static void CreateModel()
{
// Load Data
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<ModelInput>(
path: TRAIN_DATA_FILEPATH,
hasHeader: true,
separatorChar: ',',
allowQuoting: true,
allowSparse: true);

IDataView testDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
IDataView testDataView = mlContext.Data.LoadFromTextFile<ModelInput>(
path: TEST_DATA_FILEPATH,
hasHeader: true,
separatorChar: ',',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ namespace TestNamespace.ConsoleApp
public static void CreateModel()
{
// Load Data
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<ModelInput>(
path: TRAIN_DATA_FILEPATH,
hasHeader: true,
separatorChar: ',',
allowQuoting: true,
allowSparse: true);

IDataView testDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
IDataView testDataView = mlContext.Data.LoadFromTextFile<ModelInput>(
path: TEST_DATA_FILEPATH,
hasHeader: true,
separatorChar: ',',
Expand Down
Loading