Skip to content

Commit

Permalink
Merge pull request dotnet#1 from dotnet/master
Browse files Browse the repository at this point in the history
Update to latest dotnet/master
  • Loading branch information
abgoswam authored Jul 18, 2018
2 parents 3053f3d + 0e37508 commit f6baa5b
Show file tree
Hide file tree
Showing 82 changed files with 1,257 additions and 1,325 deletions.
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public interface IMessageSource

/// <summary>
/// A <see cref="IHostEnvironment"/> that is also a channel listener can attach
/// listeners for messages, as sent through <see cref="IChannelProvider.StartPipe"/>.
/// listeners for messages, as sent through <see cref="IChannelProvider.StartPipe{TMessage}"/>.
/// </summary>
public interface IMessageDispatcher : IHostEnvironment
{
Expand Down
155 changes: 41 additions & 114 deletions src/Microsoft.ML.Core/Prediction/ITrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.ML.Runtime.Data;

namespace Microsoft.ML.Runtime
{
Expand All @@ -27,151 +26,79 @@ namespace Microsoft.ML.Runtime
public delegate void SignatureSequenceTrainer();
public delegate void SignatureMatrixRecommendingTrainer();

/// <summary>
/// Interface to provide extra information about a trainer.
/// </summary>
public interface ITrainerEx : ITrainer
{
// REVIEW: Ideally trainers should be able to communicate
// something about the type of data they are capable of being trained
// on, e.g., what ColumnKinds they want, how many of each, of what type,
// etc. This interface seems like the most natural conduit for that sort
// of extra information.

// REVIEW: Can we please have consistent naming here?
// 'Need' vs. 'Want' looks arbitrary to me, and it's grammatically more correct to
// be 'Needs' / 'Wants' anyway.

/// <summary>
/// Whether the trainer needs to see data in normalized form.
/// </summary>
bool NeedNormalization { get; }

/// <summary>
/// Whether the trainer needs calibration to produce probabilities.
/// </summary>
bool NeedCalibration { get; }

/// <summary>
/// Whether this trainer could benefit from a cached view of the data.
/// </summary>
bool WantCaching { get; }
}

public interface ITrainerHost
{
Random Rand { get; }
int Verbosity { get; }

TextWriter StdOut { get; }
TextWriter StdErr { get; }
}

// The Trainer (of Factory) can optionally implement this.
public interface IModelCombiner<TModel, TPredictor>
where TPredictor : IPredictor
{
TPredictor CombineModels(IEnumerable<TModel> models);
}

public delegate void SignatureModelCombiner(PredictionKind kind);

/// <summary>
/// Weakly typed interface for a trainer "session" that produces a predictor.
/// The base interface for trainers. Implementors should not implement this interface directly,
/// but rather implement the more specific <see cref="ITrainer{TPredictor}"/>.
/// </summary>
public interface ITrainer
{
/// <summary>
/// Return the type of prediction task for the produced predictor.
/// Auxiliary information about the trainer in terms of its capabilities
/// and requirements.
/// </summary>
PredictionKind PredictionKind { get; }
TrainerInfo Info { get; }

/// <summary>
/// Returns the trained predictor.
/// REVIEW: Consider removing this.
/// Return the type of prediction task for the produced predictor.
/// </summary>
IPredictor CreatePredictor();
}

/// <summary>
/// Interface implemented by the MetalinearLearners base class.
/// Used to distinguish the MetaLinear Learners from the other learners
/// </summary>
public interface IMetaLinearTrainer
{

}
PredictionKind PredictionKind { get; }

public interface ITrainer<in TDataSet> : ITrainer
{
/// <summary>
/// Trains a predictor using the specified dataset.
/// Trains a predictor.
/// </summary>
/// <param name="data"> Training dataset </param>
void Train(TDataSet data);
/// <param name="context">A context containing at least the training data</param>
/// <returns>The trained predictor</returns>
/// <seealso cref="ITrainer{TPredictor}.Train(TrainContext)"/>
IPredictor Train(TrainContext context);
}

/// <summary>
/// Strongly typed generic interface for a trainer. A trainer object takes
/// supervision data and produces a predictor.
/// Strongly typed generic interface for a trainer. A trainer object takes training data
/// and produces a predictor.
/// </summary>
/// <typeparam name="TDataSet"> Type of the training dataset</typeparam>
/// <typeparam name="TPredictor"> Type of predictor produced</typeparam>
public interface ITrainer<in TDataSet, out TPredictor> : ITrainer<TDataSet>
public interface ITrainer<out TPredictor> : ITrainer
where TPredictor : IPredictor
{
/// <summary>
/// Returns the trained predictor.
/// </summary>
/// <returns>Trained predictor ready to make predictions</returns>
new TPredictor CreatePredictor();
}

/// <summary>
/// Trainers that want data to do their own validation implement this interface.
/// </summary>
public interface IValidatingTrainer<in TDataSet> : ITrainer<TDataSet>
{
/// <summary>
/// Trains a predictor using the specified dataset.
/// Trains a predictor.
/// </summary>
/// <param name="data">Training dataset</param>
/// <param name="validData">Validation dataset</param>
void Train(TDataSet data, TDataSet validData);
/// <param name="context">A context containing at least the training data</param>
/// <returns>The trained predictor</returns>
new TPredictor Train(TrainContext context);
}

public interface IIncrementalTrainer<in TDataSet, in TPredictor> : ITrainer<TDataSet>
public static class TrainerExtensions
{
/// <summary>
/// Trains a predictor using the specified dataset and a trained predictor.
/// Convenience train extension for the case where one has only a training set with no auxiliary information.
/// Equivalent to calling <see cref="ITrainer.Train(TrainContext)"/>
/// on a <see cref="TrainContext"/> constructed with <paramref name="trainData"/>.
/// </summary>
/// <param name="data">Training dataset</param>
/// <param name="predictor">A trained predictor</param>
void Train(TDataSet data, TPredictor predictor);
}
/// <param name="trainer">The trainer</param>
/// <param name="trainData">The training data.</param>
/// <returns>The trained predictor</returns>
public static IPredictor Train(this ITrainer trainer, RoleMappedData trainData)
=> trainer.Train(new TrainContext(trainData));

public interface IIncrementalValidatingTrainer<in TDataSet, in TPredictor> : ITrainer<TDataSet>
{
/// <summary>
/// Trains a predictor using the specified dataset and a trained predictor.
/// Convenience train extension for the case where one has only a training set with no auxiliary information.
/// Equivalent to calling <see cref="ITrainer{TPredictor}.Train(TrainContext)"/>
/// on a <see cref="TrainContext"/> constructed with <paramref name="trainData"/>.
/// </summary>
/// <param name="data">Training dataset</param>
/// <param name="validData">Validation dataset</param>
/// <param name="predictor">A trained predictor</param>
void Train(TDataSet data, TDataSet validData, TPredictor predictor);
/// <param name="trainer">The trainer</param>
/// <param name="trainData">The training data.</param>
/// <returns>The trained predictor</returns>
public static TPredictor Train<TPredictor>(this ITrainer<TPredictor> trainer, RoleMappedData trainData) where TPredictor : IPredictor
=> trainer.Train(new TrainContext(trainData));
}

#if FUTURE
public interface IMultiTrainer<in TDataSet, in TFeatures, out TResult> :
IMultiTrainer<TDataSet, TDataSet, TFeatures, TResult>
{
}

public interface IMultiTrainer<in TDataSet, in TDataBatch, in TFeatures, out TResult> :
ITrainer<TDataSet, TFeatures, TResult>
// A trainer can optionally implement this to indicate it can combine multiple models into a single predictor.
public interface IModelCombiner<TModel, TPredictor>
where TPredictor : IPredictor
{
void UpdatePredictor(TDataBatch trainInstance);
IPredictor<TFeatures, TResult> GetCurrentPredictor();
TPredictor CombineModels(IEnumerable<TModel> models);
}
#endif
}
57 changes: 57 additions & 0 deletions src/Microsoft.ML.Core/Prediction/TrainContext.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime.Data;

namespace Microsoft.ML.Runtime
{
/// <summary>
/// Holds information relevant to trainers. Instances of this class are meant to be constructed and passed
/// into <see cref="ITrainer{TPredictor}.Train(TrainContext)"/> or <see cref="ITrainer.Train(TrainContext)"/>.
/// This holds at least a training set, as well as optionally a validation set and an initial predictor.
/// </summary>
public sealed class TrainContext
{
    /// <summary>
    /// The training set. Cannot be <c>null</c>.
    /// </summary>
    public RoleMappedData TrainingSet { get; }

    /// <summary>
    /// The validation set. Can be <c>null</c>. Note that passing a non-<c>null</c> validation set into
    /// a trainer that does not support validation sets should not be considered an error condition. It
    /// should simply be ignored in that case.
    /// </summary>
    public RoleMappedData ValidationSet { get; }

    /// <summary>
    /// The initial predictor, for incremental training. Note that if an <see cref="ITrainer"/> implementor
    /// does not support incremental training, then it can ignore it similarly to how one would ignore
    /// <see cref="ValidationSet"/>. However, if the trainer does support incremental training and there
    /// is something wrong with a non-<c>null</c> value of this, then the trainer ought to throw an exception.
    /// </summary>
    public IPredictor InitialPredictor { get; }

    /// <summary>
    /// Constructor, given a training set and optional other arguments.
    /// </summary>
    /// <param name="trainingSet">Will set <see cref="TrainingSet"/> to this value. This must be specified</param>
    /// <param name="validationSet">Will set <see cref="ValidationSet"/> to this value if specified</param>
    /// <param name="initialPredictor">Will set <see cref="InitialPredictor"/> to this value if specified</param>
    public TrainContext(RoleMappedData trainingSet, RoleMappedData validationSet = null, IPredictor initialPredictor = null)
    {
        Contracts.CheckValue(trainingSet, nameof(trainingSet));
        Contracts.CheckValueOrNull(validationSet);
        Contracts.CheckValueOrNull(initialPredictor);

        // REVIEW: Should there be code here to ensure that the role mappings between the two are compatible?
        // That is, all the role mappings are the same and the columns between them have identical types?

        TrainingSet = trainingSet;
        ValidationSet = validationSet;
        InitialPredictor = initialPredictor;
    }
}
}
71 changes: 71 additions & 0 deletions src/Microsoft.ML.Core/Prediction/TrainerInfo.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

namespace Microsoft.ML.Runtime
{
/// <summary>
/// Instances of this class possess information about trainers, in terms of their requirements and capabilities.
/// The intended usage is as the value for <see cref="ITrainer.Info"/>.
/// </summary>
public sealed class TrainerInfo
{
    // REVIEW: Ideally trainers should be able to communicate
    // something about the type of data they are capable of being trained
    // on, e.g., what ColumnKinds they want, how many of each, of what type,
    // etc. This class seems like the most natural conduit for that sort
    // of extra information.

    /// <summary>
    /// Whether the trainer needs to see data in normalized form. Learners that are sensitive to the scale
    /// of the features (typically parametric learners, e.g., linear models) will tend to have a <c>true</c>
    /// value here, whereas scale-insensitive, non-parametric learners will not.
    /// </summary>
    public bool NeedNormalization { get; }

    /// <summary>
    /// Whether the trainer needs calibration to produce probabilities. As a general rule only trainers that produce
    /// binary classifier predictors that also do not have a natural probabilistic interpretation should have a
    /// <c>true</c> value here.
    /// </summary>
    public bool NeedCalibration { get; }

    /// <summary>
    /// Whether this trainer could benefit from a cached view of the data. Trainers that have few passes over the
    /// data, or that need to build their own custom data structure over the data, will have a <c>false</c> here.
    /// </summary>
    public bool WantCaching { get; }

    /// <summary>
    /// Whether the trainer supports validation sets via <see cref="TrainContext.ValidationSet"/>. A <c>false</c>
    /// value here indicates that the trainer does not support validation sets, and any validation set provided
    /// should simply be ignored.
    /// </summary>
    public bool SupportsValidation { get; }

    /// <summary>
    /// Whether the trainer supports incremental training via <see cref="TrainContext.InitialPredictor"/>. A
    /// <c>false</c> value here indicates that the trainer does not support incremental training, and any initial
    /// predictor provided should simply be ignored.
    /// </summary>
    public bool SupportsIncrementalTraining { get; }

    /// <summary>
    /// Initializes with the given parameters. The parameters have default values for the most typical values
    /// for most classical trainers.
    /// </summary>
    /// <param name="normalization">The value for the property <see cref="NeedNormalization"/></param>
    /// <param name="calibration">The value for the property <see cref="NeedCalibration"/></param>
    /// <param name="caching">The value for the property <see cref="WantCaching"/></param>
    /// <param name="supportValid">The value for the property <see cref="SupportsValidation"/></param>
    /// <param name="supportIncrementalTrain">The value for the property <see cref="SupportsIncrementalTraining"/></param>
    public TrainerInfo(bool normalization = true, bool calibration = false, bool caching = true,
        bool supportValid = false, bool supportIncrementalTrain = false)
    {
        NeedNormalization = normalization;
        NeedCalibration = calibration;
        WantCaching = caching;
        SupportsValidation = supportValid;
        SupportsIncrementalTraining = supportIncrementalTrain;
    }
}
}
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Core/Utilities/ObjectPool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public abstract class ObjectPoolBase<T>
public int Count => _pool.Count;
public int NumCreated { get { return _numCreated; } }

protected internal ObjectPoolBase()
private protected ObjectPoolBase()
{
_pool = new ConcurrentBag<T>();
}
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.Core/Utilities/VBufferUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1147,11 +1147,11 @@ private static void ApplyWithCoreCopy<TSrc, TDst>(ref VBuffer<TSrc> src, ref VBu
/// storing the result in <paramref name="dst"/>, overwriting any of its existing contents.
/// The contents of <paramref name="dst"/> do not affect calculation. If you instead wish
/// to calculate a function that reads and writes <paramref name="dst"/>, see
/// <see cref="ApplyWith"/> and <see cref="ApplyWithEitherDefined"/>. Post-operation,
/// <see cref="ApplyWith{TSrc,TDst}"/> and <see cref="ApplyWithEitherDefined{TSrc,TDst}"/>. Post-operation,
/// <paramref name="dst"/> will be dense iff <paramref name="src"/> is dense.
/// </summary>
/// <seealso cref="ApplyWith"/>
/// <seealso cref="ApplyWithEitherDefined"/>
/// <seealso cref="ApplyWith{TSrc,TDst}"/>
/// <seealso cref="ApplyWithEitherDefined{TSrc,TDst}"/>
public static void ApplyIntoEitherDefined<TSrc, TDst>(ref VBuffer<TSrc> src, ref VBuffer<TDst> dst, Func<int, TSrc, TDst> func)
{
Contracts.CheckValue(func, nameof(func));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ private FoldResult RunFold(int fold)
if (_getValidationDataView != null)
{
ch.Assert(_applyTransformsToValidationData != null);
if (!TrainUtils.CanUseValidationData(trainer))
if (!trainer.Info.SupportsValidation)
ch.Warning("Trainer does not accept validation dataset.");
else
{
Expand Down
Loading

0 comments on commit f6baa5b

Please sign in to comment.