-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
ComponentCreation.cs
453 lines (415 loc) · 24.2 KB
/
ComponentCreation.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Model;
namespace Microsoft.ML.Runtime.Api
{
/// <summary>
/// This class defines extension methods for an <see cref="IHostEnvironment"/> to facilitate creating
/// components (loaders, transforms, trainers, scorers, evaluators, savers).
/// </summary>
public static class ComponentCreation
{
/// <summary>
/// Builds a single data view by placing the columns of every source data view side by side.
/// When the sources have different lengths, the result is only as long as the shortest source.
/// </summary>
/// <param name="env">The host environment to use.</param>
/// <param name="sources">The data views to zip together; the collection must be non-empty.</param>
/// <returns>The zipped data view.</returns>
public static IDataView Zip(this IHostEnvironment env, IEnumerable<IDataView> sources)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(sources, nameof(sources));

    var zipped = ZipDataView.Create(env, sources);
    return zipped;
}
/// <summary>
/// Generate training examples for training a predictor or instantiating a scorer.
/// </summary>
/// <param name="env">The host environment to use.</param>
/// <param name="data">The data to use for training or scoring. Must not be null.</param>
/// <param name="features">The name of the features column. Can be null.</param>
/// <param name="label">The name of the label column. Can be null.</param>
/// <param name="group">The name of the group ID column (for ranking). Can be null.</param>
/// <param name="weight">The name of the weight column. Can be null.</param>
/// <param name="custom">Additional column mapping to be passed to the trainer or scorer (specific to the prediction type). Can be null or empty.</param>
/// <returns>The constructed examples.</returns>
public static RoleMappedData CreateExamples(this IHostEnvironment env, IDataView data, string features, string label = null,
    string group = null, string weight = null, IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> custom = null)
{
    Contracts.CheckValue(env, nameof(env));
    // The data view is the only mandatory input; fail here, through the host environment,
    // rather than relying on whatever the RoleMappedData constructor does with null.
    env.CheckValue(data, nameof(data));
    env.CheckValueOrNull(label);
    env.CheckValueOrNull(features);
    env.CheckValueOrNull(group);
    env.CheckValueOrNull(weight);
    env.CheckValueOrNull(custom);

    // Note the argument order: RoleMappedData receives the label before the features.
    return new RoleMappedData(data, label, features, group, weight, name: null, custom: custom);
}
/// <summary>
/// Wraps an in-memory list of user-defined items in a new <see cref="IDataView"/>.
/// Ownership of <paramref name="data"/> stays with the caller, and the data view never
/// changes the contents of <paramref name="data"/>.
/// Because an <see cref="IDataView"/> is assumed to be immutable, the caller is expected not to
/// modify <paramref name="data"/> while the data view is being actively cursored.
///
/// A typical usage pattern: create the data view and train a predictor; once training has
/// completed, change the underlying collection and train another predictor.
/// </summary>
/// <typeparam name="TRow">The user-defined item type.</typeparam>
/// <param name="env">The host environment to use for data view creation.</param>
/// <param name="data">The list to wrap.</param>
/// <param name="schemaDefinition">Optional schema definition for the new data view. When <c>null</c>,
/// the schema definition is inferred from <typeparamref name="TRow"/>.</param>
/// <returns>The constructed <see cref="IDataView"/>.</returns>
public static IDataView CreateDataView<TRow>(this IHostEnvironment env, IList<TRow> data, SchemaDefinition schemaDefinition = null)
    where TRow : class
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(data, nameof(data));
    env.CheckValueOrNull(schemaDefinition);

    var view = DataViewConstructionUtils.CreateFromList(env, data, schemaDefinition);
    return view;
}
/// <summary>
/// Wraps an enumerable of user-defined items in a new <see cref="IDataView"/>.
/// Ownership of <paramref name="data"/> stays with the caller, and the data view never
/// changes the contents of <paramref name="data"/>.
/// Because an <see cref="IDataView"/> is assumed to be immutable, <paramref name="data"/> is expected
/// to support being enumerated multiple times with identical results, unless the caller knows
/// the data will be cursored only once.
///
/// A typical usage pattern: create a data view that loads data lazily, apply pre-trained
/// transformations to it, and cursor through the transformed results. This is how
/// <see cref="BatchPredictionEngine{TSrc,TDst}"/> is implemented.
/// </summary>
/// <typeparam name="TRow">The user-defined item type.</typeparam>
/// <param name="env">The host environment to use for data view creation.</param>
/// <param name="data">The enumerable to wrap.</param>
/// <param name="schemaDefinition">Optional schema definition for the new data view. When <c>null</c>,
/// the schema definition is inferred from <typeparamref name="TRow"/>.</param>
/// <returns>The constructed <see cref="IDataView"/>.</returns>
public static IDataView CreateStreamingDataView<TRow>(this IHostEnvironment env, IEnumerable<TRow> data, SchemaDefinition schemaDefinition = null)
    where TRow : class
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(data, nameof(data));
    env.CheckValueOrNull(schemaDefinition);

    var view = DataViewConstructionUtils.CreateFromEnumerable(env, data, schemaDefinition);
    return view;
}
/// <summary>
/// Creates a batch prediction engine from a serialized model.
/// </summary>
/// <typeparam name="TSrc">The type the input schema is inferred from when no input schema definition is given.</typeparam>
/// <typeparam name="TDst">The type the output schema is inferred from when no output schema definition is given.</typeparam>
/// <param name="env">The host environment to use.</param>
/// <param name="modelStream">The stream to deserialize the pipeline (transforms and predictor) from.</param>
/// <param name="ignoreMissingColumns">Whether to ignore missing columns in the data view.</param>
/// <param name="inputSchemaDefinition">Optional input schema; <c>null</c> means it is inferred from <typeparamref name="TSrc"/>.</param>
/// <param name="outputSchemaDefinition">Optional output schema; <c>null</c> means it is inferred from <typeparamref name="TDst"/>.</param>
/// <returns>The constructed batch prediction engine.</returns>
public static BatchPredictionEngine<TSrc, TDst> CreateBatchPredictionEngine<TSrc, TDst>(this IHostEnvironment env, Stream modelStream,
    bool ignoreMissingColumns = false, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
    where TSrc : class
    where TDst : class, new()
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(modelStream, nameof(modelStream));
    env.CheckValueOrNull(inputSchemaDefinition);
    env.CheckValueOrNull(outputSchemaDefinition);

    var engine = new BatchPredictionEngine<TSrc, TDst>(
        env, modelStream, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition);
    return engine;
}
/// <summary>
/// Creates a batch prediction engine on top of an existing transformation pipe.
/// </summary>
/// <typeparam name="TSrc">The type the input schema is inferred from when no input schema definition is given.</typeparam>
/// <typeparam name="TDst">The type the output schema is inferred from when no output schema definition is given.</typeparam>
/// <param name="env">The host environment to use.</param>
/// <param name="dataPipe">The transformation pipe that may or may not include a scorer.</param>
/// <param name="ignoreMissingColumns">Whether to ignore missing columns in the data view.</param>
/// <param name="inputSchemaDefinition">Optional input schema; <c>null</c> means it is inferred from <typeparamref name="TSrc"/>.</param>
/// <param name="outputSchemaDefinition">Optional output schema; <c>null</c> means it is inferred from <typeparamref name="TDst"/>.</param>
/// <returns>The constructed batch prediction engine.</returns>
public static BatchPredictionEngine<TSrc, TDst> CreateBatchPredictionEngine<TSrc, TDst>(this IHostEnvironment env, IDataView dataPipe,
    bool ignoreMissingColumns = false, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
    where TSrc : class
    where TDst : class, new()
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(dataPipe, nameof(dataPipe));
    env.CheckValueOrNull(inputSchemaDefinition);
    env.CheckValueOrNull(outputSchemaDefinition);

    var engine = new BatchPredictionEngine<TSrc, TDst>(
        env, dataPipe, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition);
    return engine;
}
/// <summary>
/// Creates an on-demand prediction engine from a serialized model.
/// </summary>
/// <typeparam name="TSrc">The type the input schema is inferred from when no input schema definition is given.</typeparam>
/// <typeparam name="TDst">The type the output schema is inferred from when no output schema definition is given.</typeparam>
/// <param name="env">The host environment to use.</param>
/// <param name="modelStream">The stream to deserialize the pipeline (transforms and predictor) from.</param>
/// <param name="ignoreMissingColumns">Whether to ignore missing columns in the data view.</param>
/// <param name="inputSchemaDefinition">Optional input schema; <c>null</c> means it is inferred from <typeparamref name="TSrc"/>.</param>
/// <param name="outputSchemaDefinition">Optional output schema; <c>null</c> means it is inferred from <typeparamref name="TDst"/>.</param>
/// <returns>The constructed prediction engine.</returns>
public static PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(this IHostEnvironment env, Stream modelStream,
    bool ignoreMissingColumns = false, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
    where TSrc : class
    where TDst : class, new()
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(modelStream, nameof(modelStream));
    env.CheckValueOrNull(inputSchemaDefinition);
    env.CheckValueOrNull(outputSchemaDefinition);

    var engine = new PredictionEngine<TSrc, TDst>(
        env, modelStream, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition);
    return engine;
}
/// <summary>
/// Creates an on-demand prediction engine on top of an existing transformation pipe.
/// </summary>
/// <typeparam name="TSrc">The type the input schema is inferred from when no input schema definition is given.</typeparam>
/// <typeparam name="TDst">The type the output schema is inferred from when no output schema definition is given.</typeparam>
/// <param name="env">The host environment to use.</param>
/// <param name="dataPipe">The transformation pipe that may or may not include a scorer.</param>
/// <param name="ignoreMissingColumns">Whether to ignore missing columns in the data view.</param>
/// <param name="inputSchemaDefinition">Optional input schema; <c>null</c> means it is inferred from <typeparamref name="TSrc"/>.</param>
/// <param name="outputSchemaDefinition">Optional output schema; <c>null</c> means it is inferred from <typeparamref name="TDst"/>.</param>
/// <returns>The constructed prediction engine.</returns>
public static PredictionEngine<TSrc, TDst> CreatePredictionEngine<TSrc, TDst>(this IHostEnvironment env, IDataView dataPipe,
    bool ignoreMissingColumns = false, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
    where TSrc : class
    where TDst : class, new()
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(dataPipe, nameof(dataPipe));
    env.CheckValueOrNull(inputSchemaDefinition);
    env.CheckValueOrNull(outputSchemaDefinition);

    var engine = new PredictionEngine<TSrc, TDst>(
        env, dataPipe, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition);
    return engine;
}
/// <summary>
/// Creates a prediction engine for the 'classic' prediction problem: the input is a float array
/// of features and the output is a float score. For binary classification predictors that can
/// output probability, there are output fields reporting the predicted label and probability.
/// </summary>
/// <param name="env">The host environment to use.</param>
/// <param name="modelStream">The model stream to load the pipeline from.</param>
/// <param name="nFeatures">Number of features; must be greater than zero.</param>
/// <returns>The constructed simple prediction engine.</returns>
public static SimplePredictionEngine CreateSimplePredictionEngine(this IHostEnvironment env, Stream modelStream, int nFeatures)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(modelStream, nameof(modelStream));
    env.CheckParam(nFeatures > 0, nameof(nFeatures), "Number of features must be positive.");

    var engine = new SimplePredictionEngine(env, modelStream, nFeatures);
    return engine;
}
/// <summary>
/// Loads the transforms (but not the loader) from the model stream and applies them to the given data.
/// A model stream without any transforms is acceptable: the original <paramref name="data"/>
/// is returned in that case.
/// </summary>
/// <param name="env">The host environment to use.</param>
/// <param name="modelStream">The model stream to load from.</param>
/// <param name="data">The data to apply the transforms to.</param>
/// <returns>The transformed data.</returns>
public static IDataView LoadTransforms(this IHostEnvironment env, Stream modelStream, IDataView data)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(modelStream, nameof(modelStream));
    env.CheckValue(data, nameof(data));

    // ModelFileUtils takes the data before the stream, the reverse of this method's parameter order.
    var transformed = ModelFileUtils.LoadTransforms(env, data, modelStream);
    return transformed;
}
// REVIEW: Add one more overload that works off SubComponents.
/// <summary>
/// Creates a data loader from a strongly-typed arguments object.
/// </summary>
/// <typeparam name="TArgs">The type of the arguments object.</typeparam>
/// <param name="env">The host environment to use.</param>
/// <param name="arguments">The arguments object describing the loader.</param>
/// <param name="files">The files the loader reads from.</param>
/// <returns>The created data loader.</returns>
public static IDataLoader CreateLoader<TArgs>(this IHostEnvironment env, TArgs arguments, IMultiStreamSource files)
    where TArgs : class, new()
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(files, nameof(files));

    var loader = CreateCore<IDataLoader, TArgs, SignatureDataLoader>(env, arguments, files);
    return loader;
}
/// <summary>
/// Creates a data loader from the 'LoadName{settings}' string.
/// </summary>
/// <param name="env">The host environment to use.</param>
/// <param name="settings">The loader specification in 'LoadName{settings}' form.</param>
/// <param name="files">The files the loader reads from. Must not be null.</param>
/// <returns>The created data loader.</returns>
public static IDataLoader CreateLoader(this IHostEnvironment env, string settings, IMultiStreamSource files)
{
    Contracts.CheckValue(env, nameof(env));
    // Validate through the host environment (not the static Contracts class), as the
    // arguments-object overload of CreateLoader and the other methods in this class do.
    env.CheckValue(files, nameof(files));
    Type factoryType = typeof(IComponentFactory<IMultiStreamSource, IDataLoader>);
    return CreateCore<IDataLoader>(env, factoryType, typeof(SignatureDataLoader), settings, files);
}
/// <summary>
/// Creates a data saver from a strongly-typed arguments object.
/// </summary>
/// <typeparam name="TArgs">The type of the arguments object.</typeparam>
/// <param name="env">The host environment to use.</param>
/// <param name="arguments">The arguments object describing the saver.</param>
/// <returns>The created data saver.</returns>
public static IDataSaver CreateSaver<TArgs>(this IHostEnvironment env, TArgs arguments)
    where TArgs : class, new()
{
    Contracts.CheckValue(env, nameof(env));

    var saver = CreateCore<IDataSaver, TArgs, SignatureDataSaver>(env, arguments);
    return saver;
}
/// <summary>
/// Creates a data saver described by a 'LoadName{settings}' string.
/// </summary>
/// <param name="env">The host environment to use.</param>
/// <param name="settings">The saver specification in 'LoadName{settings}' form.</param>
/// <returns>The created data saver.</returns>
public static IDataSaver CreateSaver(this IHostEnvironment env, string settings)
{
    Contracts.CheckValue(env, nameof(env));

    var saver = CreateCore<IDataSaver>(env, typeof(SignatureDataSaver), settings);
    return saver;
}
/// <summary>
/// Creates a data transform from a strongly-typed arguments object.
/// </summary>
/// <typeparam name="TArgs">The type of the arguments object.</typeparam>
/// <param name="env">The host environment to use.</param>
/// <param name="arguments">The arguments object describing the transform.</param>
/// <param name="source">The data view the transform is applied to.</param>
/// <returns>The created data transform.</returns>
public static IDataTransform CreateTransform<TArgs>(this IHostEnvironment env, TArgs arguments, IDataView source)
    where TArgs : class, new()
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(source, nameof(source));

    var transform = CreateCore<IDataTransform, TArgs, SignatureDataTransform>(env, arguments, source);
    return transform;
}
/// <summary>
/// Creates a data transform described by a 'LoadName{settings}' string.
/// </summary>
/// <param name="env">The host environment to use.</param>
/// <param name="settings">The transform specification in 'LoadName{settings}' form.</param>
/// <param name="source">The data view the transform is applied to.</param>
/// <returns>The created data transform.</returns>
public static IDataTransform CreateTransform(this IHostEnvironment env, string settings, IDataView source)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(source, nameof(source));

    var transformFactoryType = typeof(IComponentFactory<IDataView, IDataTransform>);
    var transformSignature = typeof(SignatureDataTransform);
    return CreateCore<IDataTransform>(env, transformFactoryType, transformSignature, settings, source);
}
/// <summary>
/// Creates a data scorer described by a 'LoadName{settings}' string.
/// </summary>
/// <param name="env">The host environment to use.</param>
/// <param name="settings">The settings string.</param>
/// <param name="data">The data to score.</param>
/// <param name="predictor">The predictor to score.</param>
/// <param name="trainSchema">The training data schema, from which the scorer can optionally extract
/// additional information such as label names. When <c>null</c>, no information is extracted.</param>
/// <returns>The scored data.</returns>
public static IDataScorerTransform CreateScorer(this IHostEnvironment env, string settings,
    RoleMappedData data, Predictor predictor, RoleMappedSchema trainSchema = null)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(data, nameof(data));
    env.CheckValue(predictor, nameof(predictor));
    env.CheckValueOrNull(trainSchema);

    var scorerFactoryType = typeof(IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>);
    var scorerSignature = typeof(SignatureDataScorer);

    // The settings are parsed here so that the parsed factory can be handed to the bindable-mapper lookup.
    ICommandLineComponentFactory parsedFactory = CmdParser.CreateComponentFactory(scorerFactoryType, scorerSignature, settings);
    var bindableMapper = ScoreUtils.GetSchemaBindableMapper(env, predictor.Pred, scorerFactorySettings: parsedFactory);

    // Bind the mapper to the schema of the data being scored, then create the scorer over it.
    var boundMapper = bindableMapper.Bind(env, data.Schema);
    return CreateCore<IDataScorerTransform>(env, scorerFactoryType, scorerSignature, settings, data.Data, boundMapper, trainSchema);
}
/// <summary>
/// Creates the default data scorer appropriate to the prediction kind of the predictor.
/// </summary>
/// <param name="env">The host environment to use.</param>
/// <param name="data">The data to score.</param>
/// <param name="predictor">The predictor to score.</param>
/// <param name="trainSchema">The training data schema, from which the scorer can optionally extract
/// additional information such as label names. When <c>null</c>, no information is extracted.</param>
/// <returns>The scored data.</returns>
public static IDataScorerTransform CreateDefaultScorer(this IHostEnvironment env, RoleMappedData data,
    Predictor predictor, RoleMappedSchema trainSchema = null)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(data, nameof(data));
    env.CheckValue(predictor, nameof(predictor));
    env.CheckValueOrNull(trainSchema);

    var scorer = ScoreUtils.GetScorer(predictor.Pred, data, env, trainSchema);
    return scorer;
}
/// <summary>
/// Creates an evaluator from the settings string.
/// </summary>
/// <param name="env">The host environment to use.</param>
/// <param name="settings">The settings string; validated with <c>CheckNonWhiteSpace</c> before use.</param>
/// <returns>The created evaluator.</returns>
public static IEvaluator CreateEvaluator(this IHostEnvironment env, string settings)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckNonWhiteSpace(settings, nameof(settings));

    var evaluator = CreateCore<IEvaluator>(env, typeof(SignatureEvaluator), settings);
    return evaluator;
}
/// <summary>
/// Loads a predictor from the model stream. Returns null iff there's no predictor.
/// </summary>
/// <param name="env">The host environment to use. Must not be null.</param>
/// <param name="modelStream">The model stream. Must not be null.</param>
/// <returns>The loaded predictor wrapped in a <see cref="Predictor"/>, or <c>null</c> when the
/// model stream contains no predictor.</returns>
public static Predictor LoadPredictorOrNull(this IHostEnvironment env, Stream modelStream)
{
    // env is an extension-method receiver and can be null at the call site; validate it
    // before it is used, as the other public methods of this class do.
    Contracts.CheckValue(env, nameof(env));
    // Report a bad stream through the host environment, consistent with the rest of the class.
    env.CheckValue(modelStream, nameof(modelStream));
    var p = ModelFileUtils.LoadPredictorOrNull(env, modelStream);
    return p == null ? null : new Predictor(p);
}
/// <summary>
/// Creates a trainer from a strongly-typed arguments object.
/// </summary>
/// <typeparam name="TArgs">The type of the arguments object.</typeparam>
/// <param name="env">The host environment to use.</param>
/// <param name="arguments">The arguments object describing the trainer.</param>
/// <param name="loadName">Receives the load name of the trainer that was created.</param>
/// <returns>The created trainer.</returns>
internal static ITrainer CreateTrainer<TArgs>(this IHostEnvironment env, TArgs arguments, out string loadName)
    where TArgs : class, new()
{
    Contracts.CheckValue(env, nameof(env));

    var trainer = CreateCore<ITrainer, TArgs, SignatureTrainer>(env, arguments, out loadName);
    return trainer;
}
/// <summary>
/// Creates a trainer from a settings string.
/// </summary>
/// <param name="env">The host environment to use.</param>
/// <param name="settings">The settings string describing the trainer.</param>
/// <param name="loadName">Receives the name of the component factory parsed from <paramref name="settings"/>.</param>
/// <returns>The created trainer.</returns>
internal static ITrainer CreateTrainer(this IHostEnvironment env, string settings, out string loadName)
{
    Contracts.CheckValue(env, nameof(env));

    var trainer = CreateCore<ITrainer>(env, typeof(SignatureTrainer), settings, out loadName);
    return trainer;
}
// Convenience overload: creates a component from a settings string for callers that do not
// need the load name. Forwards to the overload with the 'out loadName' parameter.
private static TRes CreateCore<TRes>(
    IHostEnvironment env,
    Type signatureType,
    string settings,
    params object[] extraArgs)
    where TRes : class
{
    string ignoredLoadName;
    return CreateCore<TRes>(env, signatureType, settings, out ignoredLoadName, extraArgs);
}
// Creates a component from a settings string, using IComponentFactory<TRes> as the factory type,
// and reports the load name through 'loadName'.
private static TRes CreateCore<TRes>(
    IHostEnvironment env,
    Type signatureType,
    string settings,
    out string loadName,
    params object[] extraArgs)
    where TRes : class
{
    var defaultFactoryType = typeof(IComponentFactory<TRes>);
    return CreateCore<TRes>(env, defaultFactoryType, signatureType, settings, out loadName, extraArgs);
}
// Convenience overload: creates a component from a settings string with an explicit factory type,
// for callers that do not need the load name.
private static TRes CreateCore<TRes>(
    IHostEnvironment env,
    Type factoryType,
    Type signatureType,
    string settings,
    params object[] extraArgs)
    where TRes : class
{
    return CreateCore<TRes>(env, factoryType, signatureType, settings, out string ignoredLoadName, extraArgs);
}
// Convenience overload: creates a component from an arguments object for callers that do not
// need the load name.
private static TRes CreateCore<TRes, TArgs, TSig>(IHostEnvironment env, TArgs args, params object[] extraArgs)
    where TRes : class
    where TArgs : class, new()
{
    return CreateCore<TRes, TArgs, TSig>(env, args, out string ignoredLoadName, extraArgs);
}
// Core of the settings-string path: parses 'settings' into a component factory, reports the
// factory's name through 'loadName', and asks the component catalog to instantiate the component.
private static TRes CreateCore<TRes>(
    IHostEnvironment env,
    Type factoryType,
    Type signatureType,
    string settings,
    out string loadName,
    params object[] extraArgs)
    where TRes : class
{
    Contracts.AssertValue(env);
    env.AssertValue(factoryType);
    env.AssertValue(signatureType);
    env.AssertValue(settings, nameof(settings));

    var parsed = CmdParser.CreateComponentFactory(factoryType, signatureType, settings);
    loadName = parsed.Name;

    // The catalog is given the signature type, name and settings string from the parsed factory.
    var parsedSignature = parsed.SignatureType;
    var parsedName = parsed.Name;
    var parsedSettings = parsed.GetSettingsString();
    return ComponentCatalog.CreateInstance<TRes>(env, parsedSignature, parsedName, parsedSettings, extraArgs);
}
// Core of the arguments-object path: looks up the loadable class registered for the
// (TArgs, TSig) pair, requires exactly one match, reports its first load name through
// 'loadName', and instantiates it with 'args' and 'extraArgs'.
private static TRes CreateCore<TRes, TArgs, TSig>(IHostEnvironment env, TArgs args, out string loadName, params object[] extraArgs)
    where TRes : class
    where TArgs : class, new()
{
    env.CheckValue(args, nameof(args));

    var candidates = ComponentCatalog.FindLoadableClasses<TArgs, TSig>();
    if (candidates.Length != 1)
    {
        // Anything other than a single match is an error; pick the message by whether there were none or several.
        if (candidates.Length == 0)
            throw env.Except("Couldn't find a {0} class that accepts {1} as arguments.", typeof(TRes).Name, typeof(TArgs).FullName);
        throw env.Except("Found too many {0} classes that accept {1} as arguments.", typeof(TRes).Name, typeof(TArgs).FullName);
    }

    var match = candidates[0];
    loadName = match.LoadNames[0];
    return match.CreateInstance<TRes>(env, args, extraArgs);
}
}
}