forked from dotnet/machinelearning
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add SrCnn entire API by implementing function (dotnet#5135)
* Draft PR for SrCnn batch detection API interface (#1) * POC Batch transform * SrCnn batch interface * Removed comment * Handled some APIreview comments. * Handled other review comments. * Resolved review comments. Added sample. Co-authored-by: Yael Dekel <yaeld@microsoft.com> * Implement SrCnn entire API by function * Fix bugs and add test * Resolve comments * Change names and add documentation * Handling review comments * Resolve the array allocating issue * Move modeler initializing to CreateBatch and other minor fix. * Fix 3 remaining comments * Fixed code analysis issue. * Fixed minor comments Co-authored-by: klausmh <klausmh@microsoft.com> Co-authored-by: Yael Dekel <yaeld@microsoft.com>
- Loading branch information
1 parent
d58e8d1
commit a4af0ec
Showing
5 changed files
with
1,238 additions
and
5 deletions.
There are no files selected for viewing
99 changes: 99 additions & 0 deletions
99
.../samples/Microsoft.ML.Samples/Dynamic/Transforms/TimeSeries/DetectEntireAnomalyBySrCnn.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using Microsoft.ML; | ||
using Microsoft.ML.Data; | ||
using Microsoft.ML.TimeSeries; | ||
|
||
namespace Samples.Dynamic | ||
{ | ||
public static class DetectEntireAnomalyBySrCnn | ||
{ | ||
public static void Example() | ||
{ | ||
// Create a new ML context, for ML.NET operations. It can be used for | ||
// exception tracking and logging, | ||
// as well as the source of randomness. | ||
var ml = new MLContext(); | ||
|
||
// Generate sample series data with an anomaly | ||
var data = new List<TimeSeriesData>(); | ||
for (int index = 0; index < 20; index++) | ||
{ | ||
data.Add(new TimeSeriesData { Value = 5 }); | ||
} | ||
data.Add(new TimeSeriesData { Value = 10 }); | ||
for (int index = 0; index < 5; index++) | ||
{ | ||
data.Add(new TimeSeriesData { Value = 5 }); | ||
} | ||
|
||
// Convert data to IDataView. | ||
var dataView = ml.Data.LoadFromEnumerable(data); | ||
|
||
// Setup the detection arguments | ||
string outputColumnName = nameof(SrCnnAnomalyDetection.Prediction); | ||
string inputColumnName = nameof(TimeSeriesData.Value); | ||
|
||
// Do batch anomaly detection | ||
var outputDataView = ml.AnomalyDetection.DetectEntireAnomalyBySrCnn(dataView, outputColumnName, inputColumnName, | ||
threshold: 0.35, batchSize: 512, sensitivity: 90.0, detectMode: SrCnnDetectMode.AnomalyAndMargin); | ||
|
||
// Getting the data of the newly created column as an IEnumerable of | ||
// SrCnnAnomalyDetection. | ||
var predictionColumn = ml.Data.CreateEnumerable<SrCnnAnomalyDetection>( | ||
outputDataView, reuseRowObject: false); | ||
|
||
Console.WriteLine("Index\tData\tAnomaly\tAnomalyScore\tMag\tExpectedValue\tBoundaryUnit\tUpperBoundary\tLowerBoundary"); | ||
|
||
int k = 0; | ||
foreach (var prediction in predictionColumn) | ||
{ | ||
PrintPrediction(k, data[k].Value, prediction); | ||
k++; | ||
} | ||
//Index Data Anomaly AnomalyScore Mag ExpectedValue BoundaryUnit UpperBoundary LowerBoundary | ||
//0 5.00 0 0.00 0.21 5.00 5.00 5.01 4.99 | ||
//1 5.00 0 0.00 0.11 5.00 5.00 5.01 4.99 | ||
//2 5.00 0 0.00 0.03 5.00 5.00 5.01 4.99 | ||
//3 5.00 0 0.00 0.01 5.00 5.00 5.01 4.99 | ||
//4 5.00 0 0.00 0.03 5.00 5.00 5.01 4.99 | ||
//5 5.00 0 0.00 0.06 5.00 5.00 5.01 4.99 | ||
//6 5.00 0 0.00 0.02 5.00 5.00 5.01 4.99 | ||
//7 5.00 0 0.00 0.01 5.00 5.00 5.01 4.99 | ||
//8 5.00 0 0.00 0.01 5.00 5.00 5.01 4.99 | ||
//9 5.00 0 0.00 0.01 5.00 5.00 5.01 4.99 | ||
//10 5.00 0 0.00 0.00 5.00 5.00 5.01 4.99 | ||
//11 5.00 0 0.00 0.01 5.00 5.00 5.01 4.99 | ||
//12 5.00 0 0.00 0.01 5.00 5.00 5.01 4.99 | ||
//13 5.00 0 0.00 0.02 5.00 5.00 5.01 4.99 | ||
//14 5.00 0 0.00 0.07 5.00 5.00 5.01 4.99 | ||
//15 5.00 0 0.00 0.08 5.00 5.00 5.01 4.99 | ||
//16 5.00 0 0.00 0.02 5.00 5.00 5.01 4.99 | ||
//17 5.00 0 0.00 0.05 5.00 5.00 5.01 4.99 | ||
//18 5.00 0 0.00 0.12 5.00 5.00 5.01 4.99 | ||
//19 5.00 0 0.00 0.17 5.00 5.00 5.01 4.99 | ||
//20 10.00 1 0.50 0.80 5.00 5.00 5.01 4.99 | ||
//21 5.00 0 0.00 0.16 5.00 5.00 5.01 4.99 | ||
//22 5.00 0 0.00 0.11 5.00 5.00 5.01 4.99 | ||
//23 5.00 0 0.00 0.05 5.00 5.00 5.01 4.99 | ||
//24 5.00 0 0.00 0.11 5.00 5.00 5.01 4.99 | ||
//25 5.00 0 0.00 0.19 5.00 5.00 5.01 4.99 | ||
} | ||
|
||
private static void PrintPrediction(int idx, double value, SrCnnAnomalyDetection prediction) => | ||
Console.WriteLine("{0}\t{1:0.00}\t{2}\t\t{3:0.00}\t{4:0.00}\t\t{5:0.00}\t\t{6:0.00}\t\t{7:0.00}\t\t{8:0.00}", | ||
idx, value, prediction.Prediction[0], prediction.Prediction[1], prediction.Prediction[2], | ||
prediction.Prediction[3], prediction.Prediction[4], prediction.Prediction[5], prediction.Prediction[6]); | ||
|
||
private class TimeSeriesData | ||
{ | ||
public double Value { get; set; } | ||
} | ||
|
||
private class SrCnnAnomalyDetection | ||
{ | ||
[VectorType] | ||
public double[] Prediction { get; set; } | ||
} | ||
} | ||
} |
173 changes: 173 additions & 0 deletions
173
src/Microsoft.ML.Data/DataView/BatchDataViewMapperBase.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using Microsoft.ML.Runtime; | ||
|
||
namespace Microsoft.ML.Data.DataView | ||
{ | ||
internal abstract class BatchDataViewMapperBase<TInput, TBatch> : IDataView | ||
{ | ||
public bool CanShuffle => false; | ||
|
||
public DataViewSchema Schema => SchemaBindings.AsSchema; | ||
|
||
private readonly IDataView _source; | ||
protected readonly IHost Host; | ||
|
||
protected BatchDataViewMapperBase(IHostEnvironment env, string registrationName, IDataView input) | ||
{ | ||
Contracts.CheckValue(env, nameof(env)); | ||
Host = env.Register(registrationName); | ||
_source = input; | ||
} | ||
|
||
public long? GetRowCount() => _source.GetRowCount(); | ||
|
||
public DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null) | ||
{ | ||
Host.CheckValue(columnsNeeded, nameof(columnsNeeded)); | ||
Host.CheckValueOrNull(rand); | ||
|
||
var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, SchemaBindings.AsSchema); | ||
|
||
// If we aren't selecting any of the output columns, don't construct our cursor. | ||
// Note that because we cannot support random due to the inherently | ||
// stratified nature, neither can we allow the base data to be shuffled, | ||
// even if it supports shuffling. | ||
if (!SchemaBindings.AnyNewColumnsActive(predicate)) | ||
{ | ||
var activeInput = SchemaBindings.GetActiveInput(predicate); | ||
var inputCursor = _source.GetRowCursor(_source.Schema.Where(c => activeInput[c.Index]), null); | ||
return new BindingsWrappedRowCursor(Host, inputCursor, SchemaBindings); | ||
} | ||
var active = SchemaBindings.GetActive(predicate); | ||
Contracts.Assert(active.Length == SchemaBindings.ColumnCount); | ||
|
||
// REVIEW: We can get a different input predicate for the input cursor and for the lookahead cursor. The lookahead | ||
// cursor is only used for getting the values from the input column, so it only needs that column activated. The | ||
// other cursor is used to get source columns, so it needs the rest of them activated. | ||
var predInput = GetSchemaBindingDependencies(predicate); | ||
var inputCols = _source.Schema.Where(c => predInput(c.Index)); | ||
return new Cursor(this, _source.GetRowCursor(inputCols), _source.GetRowCursor(inputCols), active); | ||
} | ||
|
||
public DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null) | ||
{ | ||
return new[] { GetRowCursor(columnsNeeded, rand) }; | ||
} | ||
|
||
protected abstract ColumnBindingsBase SchemaBindings { get; } | ||
protected abstract TBatch CreateBatch(DataViewRowCursor input); | ||
protected abstract void ProcessBatch(TBatch currentBatch); | ||
protected abstract void ProcessExample(TBatch currentBatch, TInput currentInput); | ||
protected abstract Func<bool> GetLastInBatchDelegate(DataViewRowCursor lookAheadCursor); | ||
protected abstract Func<bool> GetIsNewBatchDelegate(DataViewRowCursor lookAheadCursor); | ||
protected abstract ValueGetter<TInput> GetLookAheadGetter(DataViewRowCursor lookAheadCursor); | ||
protected abstract Delegate[] CreateGetters(DataViewRowCursor input, TBatch currentBatch, bool[] active); | ||
protected abstract Func<int, bool> GetSchemaBindingDependencies(Func<int, bool> predicate); | ||
|
||
private sealed class Cursor : RootCursorBase | ||
{ | ||
private readonly BatchDataViewMapperBase<TInput, TBatch> _parent; | ||
private readonly DataViewRowCursor _lookAheadCursor; | ||
private readonly DataViewRowCursor _input; | ||
|
||
private readonly bool[] _active; | ||
private readonly Delegate[] _getters; | ||
|
||
private readonly TBatch _currentBatch; | ||
private readonly Func<bool> _lastInBatchInLookAheadCursorDel; | ||
private readonly Func<bool> _firstInBatchInInputCursorDel; | ||
private readonly ValueGetter<TInput> _inputGetterInLookAheadCursor; | ||
private TInput _currentInput; | ||
|
||
public override long Batch => 0; | ||
|
||
public override DataViewSchema Schema => _parent.Schema; | ||
|
||
public Cursor(BatchDataViewMapperBase<TInput, TBatch> parent, DataViewRowCursor input, DataViewRowCursor lookAheadCursor, bool[] active) | ||
: base(parent.Host) | ||
{ | ||
_parent = parent; | ||
_input = input; | ||
_lookAheadCursor = lookAheadCursor; | ||
_active = active; | ||
|
||
_currentBatch = _parent.CreateBatch(_input); | ||
|
||
_getters = _parent.CreateGetters(_input, _currentBatch, _active); | ||
|
||
_lastInBatchInLookAheadCursorDel = _parent.GetLastInBatchDelegate(_lookAheadCursor); | ||
_firstInBatchInInputCursorDel = _parent.GetIsNewBatchDelegate(_input); | ||
_inputGetterInLookAheadCursor = _parent.GetLookAheadGetter(_lookAheadCursor); | ||
} | ||
|
||
public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column) | ||
{ | ||
Contracts.CheckParam(IsColumnActive(column), nameof(column), "requested column is not active"); | ||
|
||
var col = _parent.SchemaBindings.MapColumnIndex(out bool isSrc, column.Index); | ||
if (isSrc) | ||
{ | ||
Contracts.AssertValue(_input); | ||
return _input.GetGetter<TValue>(_input.Schema[col]); | ||
} | ||
|
||
Ch.AssertValue(_getters); | ||
var getter = _getters[col]; | ||
Ch.Assert(getter != null); | ||
var fn = getter as ValueGetter<TValue>; | ||
if (fn == null) | ||
throw Ch.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue)); | ||
return fn; | ||
} | ||
|
||
public override ValueGetter<DataViewRowId> GetIdGetter() | ||
{ | ||
return | ||
(ref DataViewRowId val) => | ||
{ | ||
Ch.Check(IsGood, "Cannot call ID getter in current state"); | ||
val = new DataViewRowId((ulong)Position, 0); | ||
}; | ||
} | ||
|
||
public override bool IsColumnActive(DataViewSchema.Column column) | ||
{ | ||
Ch.Check(column.Index < _parent.SchemaBindings.AsSchema.Count); | ||
return _active[column.Index]; | ||
} | ||
|
||
protected override bool MoveNextCore() | ||
{ | ||
if (!_input.MoveNext()) | ||
return false; | ||
if (!_firstInBatchInInputCursorDel()) | ||
return true; | ||
|
||
// If we are here, this means that _input.MoveNext() has gotten us to the beginning of the next batch, | ||
// so now we need to look ahead at the entire next batch in the _lookAheadCursor. | ||
// The _lookAheadCursor's position should be on the last row of the previous batch (or -1). | ||
Ch.Assert(_lastInBatchInLookAheadCursorDel()); | ||
|
||
var good = _lookAheadCursor.MoveNext(); | ||
// The two cursors should have the same number of elements, so if _input.MoveNext() returned true, | ||
// then it must return true here too. | ||
Ch.Assert(good); | ||
|
||
do | ||
{ | ||
_inputGetterInLookAheadCursor(ref _currentInput); | ||
_parent.ProcessExample(_currentBatch, _currentInput); | ||
} while (!_lastInBatchInLookAheadCursorDel() && _lookAheadCursor.MoveNext()); | ||
|
||
_parent.ProcessBatch(_currentBatch); | ||
return true; | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.