forked from dotnet/machinelearning
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
44 changed files
with
760 additions
and
828 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
// Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
namespace Microsoft.ML.Auto | ||
{ | ||
public class AutoInferenceCatalog | ||
{ | ||
private readonly MLContext _context; | ||
|
||
internal AutoInferenceCatalog(MLContext context) | ||
{ | ||
_context = context; | ||
} | ||
|
||
public RegressionExperiment CreateRegressionExperiment(uint maxInferenceTimeInSeconds) | ||
{ | ||
return new RegressionExperiment(_context, new RegressionExperimentSettings() | ||
{ | ||
MaxInferenceTimeInSeconds = maxInferenceTimeInSeconds | ||
}); | ||
} | ||
|
||
public RegressionExperiment CreateRegressionExperiment(RegressionExperimentSettings experimentSettings) | ||
{ | ||
return new RegressionExperiment(_context, experimentSettings); | ||
} | ||
|
||
public BinaryClassificationExperiment CreateBinaryClassificationExperiment(uint maxInferenceTimeInSeconds) | ||
{ | ||
return new BinaryClassificationExperiment(_context, new BinaryExperimentSettings() | ||
{ | ||
MaxInferenceTimeInSeconds = maxInferenceTimeInSeconds | ||
}); | ||
} | ||
|
||
public BinaryClassificationExperiment CreateBinaryClassificationExperiment(BinaryExperimentSettings experimentSettings) | ||
{ | ||
return new BinaryClassificationExperiment(_context, experimentSettings); | ||
} | ||
|
||
public MulticlassClassificationExperiment CreateMulticlassClassificationExperiment(uint maxInferenceTimeInSeconds) | ||
{ | ||
return new MulticlassClassificationExperiment(_context, new MulticlassExperimentSettings() | ||
{ | ||
MaxInferenceTimeInSeconds = maxInferenceTimeInSeconds | ||
}); | ||
} | ||
|
||
public MulticlassClassificationExperiment CreateMulticlassClassificationExperiment(MulticlassExperimentSettings experimentSettings) | ||
{ | ||
return new MulticlassClassificationExperiment(_context, experimentSettings); | ||
} | ||
|
||
public ColumnInferenceResults InferColumns(string path, string label,char? separatorChar = null, bool? allowQuotedStrings = null, | ||
bool? supportSparse = null, bool trimWhitespace = false, bool groupColumns = true) | ||
{ | ||
//UserInputValidationUtil.ValidateInferColumnsArgs(path, label); | ||
return ColumnInferenceApi.InferColumns(_context, path, label, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns); | ||
} | ||
|
||
public ColumnInferenceResults InferColumns(string path, uint labelColumnIndex, bool hasHeader = false, char? separatorChar = null, | ||
bool? allowQuotedStrings = null, bool? supportSparse = null, bool trimWhitespace = false, bool groupColumns = true) | ||
{ | ||
//UserInputValidationUtil.ValidateInferColumnsArgs(path); | ||
return ColumnInferenceApi.InferColumns(_context, path, labelColumnIndex, hasHeader, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns); | ||
} | ||
} | ||
} |
83 changes: 83 additions & 0 deletions
83
src/Microsoft.ML.Auto/API/BinaryClassificationExperiment.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
// Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using Microsoft.Data.DataView; | ||
using Microsoft.ML.Core.Data; | ||
using Microsoft.ML.Data; | ||
|
||
namespace Microsoft.ML.Auto | ||
{ | ||
public class BinaryExperimentSettings : ExperimentSettings | ||
{ | ||
public IProgress<RunResult<BinaryClassificationMetrics>> ProgressCallback; | ||
public BinaryClassificationMetric OptimizingMetric; | ||
public BinaryClassificationTrainer[] WhitelistedTrainers; | ||
} | ||
|
||
public enum BinaryClassificationMetric | ||
{ | ||
Accuracy | ||
} | ||
|
||
public enum BinaryClassificationTrainer | ||
{ | ||
LightGbm | ||
} | ||
|
||
public class BinaryClassificationExperiment | ||
{ | ||
private readonly MLContext _context; | ||
private readonly BinaryExperimentSettings _settings; | ||
|
||
internal BinaryClassificationExperiment(MLContext context, BinaryExperimentSettings settings) | ||
{ | ||
_context = context; | ||
_settings = settings; | ||
} | ||
|
||
public IEnumerable<RunResult<BinaryClassificationMetrics>> Execute(IDataView trainData, ColumnInformation columnInformation = null, IEstimator<ITransformer> preFeaturizers = null) | ||
{ | ||
return Execute(_context, trainData, columnInformation, null, preFeaturizers); | ||
} | ||
|
||
public IEnumerable<RunResult<BinaryClassificationMetrics>> Execute(IDataView trainData, IDataView validationData, ColumnInformation columnInformation = null, IEstimator<ITransformer> preFeaturizers = null) | ||
{ | ||
return Execute(_context, trainData, columnInformation, validationData, preFeaturizers); | ||
} | ||
|
||
internal RunResult<BinaryClassificationMetrics> Execute(IDataView trainData, uint numberOfCVFolds, ColumnInformation columnInformation = null, IEstimator<ITransformer> preFeaturizers = null) | ||
{ | ||
throw new NotImplementedException(); | ||
} | ||
|
||
internal IEnumerable<RunResult<BinaryClassificationMetrics>> Execute(MLContext context, | ||
IDataView trainData, | ||
ColumnInformation columnInfo, | ||
IDataView validationData = null, | ||
IEstimator<ITransformer> preFeaturizers = null) | ||
{ | ||
columnInfo = columnInfo ?? new ColumnInformation(); | ||
//UserInputValidationUtil.ValidateAutoFitArgs(trainData, labelColunName, validationData, settings, columnPurposes) | ||
|
||
// run autofit & get all pipelines run in that process | ||
var autoFitter = new AutoFitter<BinaryClassificationMetrics>(context, TaskKind.BinaryClassification, trainData, columnInfo, | ||
validationData, preFeaturizers, OptimizingMetric.Accuracy, _settings?.ProgressCallback, | ||
_settings); | ||
|
||
return autoFitter.Fit(); | ||
} | ||
} | ||
|
||
public static class BinaryExperimentResultExtensions | ||
{ | ||
public static RunResult<BinaryClassificationMetrics> Best(this IEnumerable<RunResult<BinaryClassificationMetrics>> results) | ||
{ | ||
double maxScore = results.Select(r => r.Metrics.Accuracy).Max(); | ||
return results.First(r => r.Metrics.Accuracy == maxScore); | ||
} | ||
} | ||
} |
87 changes: 0 additions & 87 deletions
87
src/Microsoft.ML.Auto/API/BinaryClassificationExtension.cs
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
// Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System.Collections.Generic; | ||
using Microsoft.ML.Data; | ||
|
||
namespace Microsoft.ML.Auto | ||
{ | ||
public class ColumnInferenceResults | ||
{ | ||
public TextLoader.Arguments TextLoaderArgs { get; set; } | ||
public ColumnInformation ColumnInformation { get; set; } | ||
} | ||
|
||
public class ColumnInformation | ||
{ | ||
public string LabelColumn = DefaultColumnNames.Label; | ||
public string NameColumn = DefaultColumnNames.Name; | ||
public string GroupIdColumn = DefaultColumnNames.GroupId; | ||
public string WeightColumn = DefaultColumnNames.Weight; | ||
public IEnumerable<string> CategoricalColumns { get; set; } | ||
public IEnumerable<string> NumericColumns { get; set; } | ||
public IEnumerable<string> TextColumns { get; set; } | ||
public IEnumerable<string> IgnoredColumns { get; set; } | ||
} | ||
} |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.