This repository has been archived by the owner on Nov 16, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 63
/
Bridge.cs
510 lines (431 loc) · 24.2 KB
/
Bridge.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
//------------------------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//------------------------------------------------------------------------------
using System;
using System.Globalization;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.FastTree;
using Microsoft.ML.Runtime.ImageAnalytics;
using Microsoft.ML.Runtime.KMeans;
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.Runtime.LightGBM;
using Microsoft.ML.Runtime.Model.Onnx;
using Microsoft.ML.Runtime.PCA;
using Microsoft.ML.Runtime.PipelineInference;
using Microsoft.ML.Runtime.SymSgd;
using Microsoft.ML.Transforms;
namespace Microsoft.MachineLearning.DotNetBridge
{
    /// <summary>
    /// The main entry point from native code. Note that GC / lifetime issues are critical to get correct.
    /// This code shares a bunch of information with native code via unsafe pointers. We need to carefully
    /// ensure that no native pointers are accessed after the invoked entry point returns. As an example
    /// of an implication of this: "returned" data is provided via a call back function (from managed to
    /// native). It cannot simply be returned by the entry point since source data that the output data depends
    /// on is not accessible once we return from managed code.
    /// </summary>
    public unsafe static partial class Bridge
    {
        // For getting float values from NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void R4Getter(DataSourceBlock* pv, int col, long index, out float dst);
        // For getting double values from NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void R8Getter(DataSourceBlock* pv, int col, long index, out double dst);
        // For getting bool values from NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void BLGetter(DataSourceBlock* pv, int col, long index, out sbyte dst);
        // For getting numpy.int8 values from NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void I1Getter(DataSourceBlock* pv, int col, long index, out sbyte dst);
        // For getting numpy.int16 values from NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void I2Getter(DataSourceBlock* pv, int col, long index, out short dst);
        // For getting numpy.int32 values from NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void I4Getter(DataSourceBlock* pv, int col, long index, out int dst);
        // For getting numpy.int64 values from NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void I8Getter(DataSourceBlock* pv, int col, long index, out long dst);
        // For getting numpy.uint8 values from NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void U1Getter(DataSourceBlock* pv, int col, long index, out byte dst);
        // For getting numpy.uint16 values from NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void U2Getter(DataSourceBlock* pv, int col, long index, out ushort dst);
        // For getting numpy.uint32 values from NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void U4Getter(DataSourceBlock* pv, int col, long index, out uint dst);
        // For getting numpy.uint64 values from NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void U8Getter(DataSourceBlock* pv, int col, long index, out ulong dst);
        // For getting string values from NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void TXGetter(DataSourceBlock* pv, int col, long index, out char* pch, out int size, out int missing);
        // For getting key-type labels. id specifies the column id, count is the key type cardinality, and buffer
        // must be large enough for count pointers. Returns success.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate bool KeyNamesGetter(DataSourceBlock* pdata, int col, int count, sbyte** buffer);
        // For getting numpy.int64 vectors from NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void I8VectorGetter(DataSourceBlock* pdata, int col, long index, int* indices, long* values, bool inquire, out int size);
        // For getting numpy.int32 vectors from NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void I4VectorGetter(DataSourceBlock* pdata, int col, long index, int* indices, int* values, bool inquire, out int size);
        // For getting numpy.int16 vectors from NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void I2VectorGetter(DataSourceBlock* pdata, int col, long index, int* indices, short* values, bool inquire, out int size);
        // For getting numpy.int8 vectors from NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void I1VectorGetter(DataSourceBlock* pdata, int col, long index, int* indices, sbyte* values, bool inquire, out int size);
        // For getting numpy.uint64 vectors from NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void U8VectorGetter(DataSourceBlock* pdata, int col, long index, int* indices, ulong* values, bool inquire, out int size);
        // For getting numpy.uint32 vectors from NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void U4VectorGetter(DataSourceBlock* pdata, int col, long index, int* indices, uint* values, bool inquire, out int size);
        // For getting numpy.uint16 vectors from NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void U2VectorGetter(DataSourceBlock* pdata, int col, long index, int* indices, ushort* values, bool inquire, out int size);
        // For getting numpy.uint8 vectors from NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void U1VectorGetter(DataSourceBlock* pdata, int col, long index, int* indices, byte* values, bool inquire, out int size);
        // For getting numpy.bool vectors from NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void BLVectorGetter(DataSourceBlock* pdata, int col, long index, int* indices, bool* values, bool inquire, out int size);
        // For getting float vectors from NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void R4VectorGetter(DataSourceBlock* pdata, int col, long index, int* indices, float* values, bool inquire, out int size);
        // For getting double vectors from NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void R8VectorGetter(DataSourceBlock* pdata, int col, long index, int* indices, double* values, bool inquire, out int size);
        // For setting bool values to NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void BLSetter(EnvironmentBlock* penv, int col, long index, byte value);
        // For setting float values to NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void R4Setter(EnvironmentBlock* penv, int col, long index, float value);
        // For setting double values to NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void R8Setter(EnvironmentBlock* penv, int col, long index, double value);
        // For setting I1 values to NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void I1Setter(EnvironmentBlock* penv, int col, long index, sbyte value);
        // For setting I2 values to NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void I2Setter(EnvironmentBlock* penv, int col, long index, short value);
        // For setting I4 values to NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void I4Setter(EnvironmentBlock* penv, int col, long index, int value);
        // For setting I8 values to NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void I8Setter(EnvironmentBlock* penv, int col, long index, long value);
        // For setting U1 values to NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void U1Setter(EnvironmentBlock* penv, int col, long index, byte value);
        // For setting U2 values to NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void U2Setter(EnvironmentBlock* penv, int col, long index, ushort value);
        // For setting U4 values to NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void U4Setter(EnvironmentBlock* penv, int col, long index, uint value);
        // For setting U8 values to NativeBridge.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void U8Setter(EnvironmentBlock* penv, int col, long index, ulong value);
        // For setting string values, to a generic pointer and index.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void TXSetter(EnvironmentBlock* penv, int col, long index, sbyte* pch, int cch);
        // For setting string key values, to a generic pointer and index.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void KeyValueSetter(EnvironmentBlock* penv, int keyId, int keyCode, sbyte* pch, int cch);

        // Identifies which native-callable entry point GetFn should hand out.
        private enum FnId
        {
            HelloMlNet = 1,
            Generic = 2,
        }

#if !CORECLR
        // The hosting code invokes this to get a specific entry point.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private delegate IntPtr NativeFnGetter(FnId id);
#endif

        #region Callbacks to native
        // Call back to provide messages to native code.
        // REVIEW: Should we support embedded nulls? This API implies null termination.
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void MessageSink(EnvironmentBlock* penv, ChannelMessageKind kind, sbyte* sender, sbyte* message);
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void ModelSink(EnvironmentBlock* penv, byte* modelBytes, ulong modelSize);
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate void DataSink(EnvironmentBlock* penv, DataViewBlock* pdata, out void** setters, out void* keyValueSetter);
        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        public unsafe delegate bool CheckCancelled();
        #endregion Callbacks to native

        /// <summary>
        /// This is provided by the native code. It provides general call backs for the task.
        /// Data source specific call backs are in another structure.
        /// </summary>
        /// <remarks>
        /// The explicit offsets place consecutive pointer fields 8 bytes apart, so this layout assumes
        /// 8-byte pointers (a 64-bit process) and must match the native declaration exactly.
        /// </remarks>
        [StructLayout(LayoutKind.Explicit)]
        private struct EnvironmentBlock
        {
#pragma warning disable 649 // never assigned
            [FieldOffset(0x00)]
            public readonly int verbosity;
            [FieldOffset(0x04)]
            public readonly int seed;
            // Call back to provide messages to native code.
            [FieldOffset(0x08)]
            public readonly void* messageSink;
            // Call back to provide data to native code.
            [FieldOffset(0x10)]
            public readonly void* dataSink;
            // Call back to provide model to native code.
            [FieldOffset(0x18)]
            public readonly void* modelSink;
            [FieldOffset(0x20)]
            public readonly int maxThreadsAllowed;
            // Call back to provide cancel flag.
            [FieldOffset(0x28)]
            public readonly void* checkCancel;
#pragma warning restore 649 // never assigned
        }

        [UnmanagedFunctionPointer(CallingConvention.StdCall)]
        private unsafe delegate int NativeGeneric(EnvironmentBlock* penv, sbyte* psz, int cdata, DataSourceBlock** ppdata);

        // The delegates handed to native code are cached in static fields so that they stay reachable
        // (and their native thunks valid) while native code holds their function pointers.
#if !CORECLR
        private static NativeFnGetter FnGetter;
#endif
        private static NativeGeneric FnGeneric;

        /// <summary>
        /// Wraps a native function pointer in a managed delegate of type <typeparamref name="TDel"/>.
        /// </summary>
        /// <param name="pv">Native function pointer; must not be null.</param>
        private static TDel MarshalDelegate<TDel>(void* pv)
        {
            Contracts.Assert(typeof(TDel).IsSubclassOf(typeof(Delegate)));
            Contracts.Assert(pv != null);
#if CORECLR
            return Marshal.GetDelegateForFunctionPointer<TDel>((IntPtr)pv);
#else
            return (TDel)(object)Marshal.GetDelegateForFunctionPointer((IntPtr)pv, typeof(TDel));
#endif
        }

#if !CORECLR
        /// <summary>
        /// This is the bootstrapping entry point. It's labeled private but is actually invoked from the native
        /// code to poke the address of the FnGetter callback into the address encoded in the string parameter.
        /// This odd way of doing things is because the most convenient way to call an initial managed method
        /// imposes the signature of Func{string, int}, which doesn't allow us to return a function address.
        /// </summary>
        /// <param name="addr">Decimal text of the native address of an <see cref="IntPtr"/>-sized slot to fill.</param>
        /// <returns>Always 1.</returns>
        private static unsafe int GetFnGetterCallback(string addr)
        {
            if (FnGetter == null)
                Interlocked.CompareExchange(ref FnGetter, (NativeFnGetter)GetFn, null);

            // The address is machine-generated text; parse it independently of the current culture.
            long a = long.Parse(addr, NumberStyles.Integer, CultureInfo.InvariantCulture);

            // Convert the integer straight to a pointer. Writing a 64-bit value through the address of a
            // pointer-sized local would overrun that local in a 32-bit process.
            var p = (IntPtr*)a;
            *p = Marshal.GetFunctionPointerForDelegate(FnGetter);
            return 1;
        }
#endif

        /// <summary>
        /// This is the main FnGetter function. Given an FnId value, it returns a native-callable
        /// entry point address, or <see cref="IntPtr.Zero"/> for ids without an entry point.
        /// </summary>
        private static unsafe IntPtr GetFn(FnId id)
        {
            switch (id)
            {
                default:
                    return default(IntPtr);
                case FnId.Generic:
                    if (FnGeneric == null)
                        Interlocked.CompareExchange(ref FnGeneric, GenericExec, null);
                    return Marshal.GetFunctionPointerForDelegate(FnGeneric);
            }
        }

        /// <summary>
        /// The Generic entry point. The specific behavior is indicated in a string argument.
        /// </summary>
        /// <param name="penv">Native environment block holding settings and callbacks. Null is rejected with -1.</param>
        /// <param name="psz">Null-terminated UTF-8 graph text; must be non-null and non-empty.</param>
        /// <param name="cdata">Number of data source blocks in <paramref name="ppdata"/>; must be non-negative.</param>
        /// <param name="ppdata">Array of <paramref name="cdata"/> data source blocks; may be null only when cdata is 0.</param>
        /// <returns>0 on success, -1 on failure.</returns>
        private static unsafe int GenericExec(EnvironmentBlock* penv, sbyte* psz, int cdata, DataSourceBlock** ppdata)
        {
            // Constructing the environment below reads through penv and marshals its cancel callback, and no
            // channel exists yet to report a problem on. Fail fast instead of dereferencing null or letting a
            // marshaling exception escape into native code.
            if (penv == null || penv->checkCancel == null)
                return -1;

            using (var env = new RmlEnvironment(MarshalDelegate<CheckCancelled>(penv->checkCancel), penv->seed,
                verbose: penv->verbosity > 3, conc: penv->maxThreadsAllowed))
            {
                var host = env.Register("ML.NET_Execution");
                env.ComponentCatalog.RegisterAssembly(typeof(TextLoader).Assembly); // ML.Data
                env.ComponentCatalog.RegisterAssembly(typeof(LinearPredictor).Assembly); // ML.StandardLearners
                env.ComponentCatalog.RegisterAssembly(typeof(CategoricalTransform).Assembly); // ML.Transforms
                env.ComponentCatalog.RegisterAssembly(typeof(FastTreeBinaryPredictor).Assembly); // ML.FastTree
                env.ComponentCatalog.RegisterAssembly(typeof(KMeansPredictor).Assembly); // ML.KMeansClustering
                env.ComponentCatalog.RegisterAssembly(typeof(PcaPredictor).Assembly); // ML.PCA
                env.ComponentCatalog.RegisterAssembly(typeof(Experiment).Assembly); // ML.Legacy
                env.ComponentCatalog.RegisterAssembly(typeof(LightGbmBinaryPredictor).Assembly);
                env.ComponentCatalog.RegisterAssembly(typeof(TensorFlowTransform).Assembly);
                env.ComponentCatalog.RegisterAssembly(typeof(ImageLoaderTransform).Assembly);
                env.ComponentCatalog.RegisterAssembly(typeof(SymSgdClassificationTrainer).Assembly);
                env.ComponentCatalog.RegisterAssembly(typeof(AutoInference).Assembly);
                env.ComponentCatalog.RegisterAssembly(typeof(SaveOnnxCommand).Assembly);
                //env.ComponentCatalog.RegisterAssembly(typeof(EnsemblePredictor).Assembly); // ML.Ensemble
                using (var ch = host.Start("Executing"))
                {
                    var sw = new System.Diagnostics.Stopwatch();
                    sw.Start();
                    try
                    {
                        // penv was validated on entry; the remaining inputs are validated here so failures are
                        // reported through the channel.
                        ch.Trace("Checking parameters");
                        host.CheckParam(penv->messageSink != null, "penv->message");
                        host.CheckParam(psz != null, nameof(psz));

                        ch.Trace("Converting graph operands");
                        // BytesToString returns null for an empty string; ExecCore requires a non-empty graph.
                        var graph = BytesToString(psz);
                        host.CheckParam(!string.IsNullOrEmpty(graph), nameof(psz), "must be non-empty");

                        ch.Trace("Wiring message sink");
                        var message = MarshalDelegate<MessageSink>(penv->messageSink);
                        var messageValidator = new MessageValidator(host);
                        var lk = new object();
                        Action<IMessageSource, ChannelMessage> listener =
                            (sender, msg) =>
                            {
                                string m = messageValidator.Validate(msg);
                                if (string.IsNullOrEmpty(m))
                                    return;

                                // Only encode the sender once we know the message will be forwarded.
                                byte[] bs = StringToNullTerminatedBytes(sender.FullName);
                                byte[] bm = StringToNullTerminatedBytes(m);
                                // Serialize calls into the native sink (presumably listeners can fire from
                                // several threads).
                                lock (lk)
                                {
                                    fixed (byte* ps = bs)
                                    fixed (byte* pm = bm)
                                        message(penv, msg.Kind, (sbyte*)ps, (sbyte*)pm);
                                }
                            };
                        env.AddListener(listener);

                        host.CheckParam(cdata >= 0, nameof(cdata), "must be non-negative");
                        host.CheckParam(ppdata != null || cdata == 0, nameof(ppdata));
                        for (int i = 0; i < cdata; i++)
                        {
                            var pdata = ppdata[i];
                            host.CheckParam(pdata != null, "pdata");
                            host.CheckParam(0 <= pdata->ccol && pdata->ccol <= int.MaxValue, "ccol");
                            host.CheckParam(0 <= pdata->crow && pdata->crow <= long.MaxValue, "crow");
                            if (pdata->ccol > 0)
                            {
                                host.CheckParam(pdata->names != null, "names");
                                host.CheckParam(pdata->kinds != null, "kinds");
                                host.CheckParam(pdata->keyCards != null, "keyCards");
                                host.CheckParam(pdata->vecCards != null, "vecCards");
                                host.CheckParam(pdata->getters != null, "getters");
                            }
                        }
                        ch.Trace("Validating number of data sources");
                        // Wrap the data sets.
                        ch.Trace("Wrapping native data sources");
                        ch.Trace("Executing");
                        ExecCore(penv, host, ch, graph, cdata, ppdata);
                    }
                    catch (Exception e)
                    {
                        // Report only the innermost exception of the chain.
                        var ex = e;
                        while (ex.InnerException != null)
                            ex = ex.InnerException;
                        ch.Error("*** {1}: '{0}'", ex.Message, ex.GetType());
                        return -1;
                    }
                    finally
                    {
                        sw.Stop();
                        if (penv->verbosity > 0)
                            ch.Info("Elapsed time: {0}", sw.Elapsed);
                        else
                            ch.Trace("Elapsed time: {0}", sw.Elapsed);
                    }
                    ch.Done();
                }
            }
            return 0;
        }

        /// <summary>
        /// Checks that the i-th model buffer and its length are present and that the length is positive.
        /// The caller is responsible for <paramref name="i"/> being within the arrays' bounds.
        /// </summary>
        private static void CheckModel(IHost host, byte** ppModelBin, long* pllModelBinLen, int i)
        {
            host.CheckParam(
                ppModelBin != null && ppModelBin[i] != null
                    && pllModelBinLen != null && pllModelBinLen[i] > 0, "pModelBin", "Model is missing");
        }

        /// <summary>
        /// Asserts the already-validated arguments and runs the graph via RunGraphCore.
        /// </summary>
        private static void ExecCore(EnvironmentBlock* penv, IHost host, IChannel ch, string graph, int cdata, DataSourceBlock** ppdata)
        {
            Contracts.AssertValue(ch);
            ch.AssertValue(host);
            ch.AssertNonEmpty(graph);
            ch.Assert(cdata >= 0);
            ch.Assert(ppdata != null || cdata == 0);
            RunGraphCore(penv, host, graph, cdata, ppdata);
        }

        /// <summary>
        /// Convert UTF8 bytes with known length to ROM&lt;char&gt;. A zero length yields an empty value.
        /// When the length is positive, <paramref name="prgch"/> must be non-null.
        /// </summary>
        internal static void BytesToText(sbyte* prgch, ulong bch, ref ReadOnlyMemory<char> dst)
        {
            if (bch > 0)
                dst = BytesToString(prgch, bch).AsMemory();
            else
                dst = ReadOnlyMemory<char>.Empty;
        }

        /// <summary>
        /// Convert null-terminated UTF8 bytes to ROM&lt;char&gt;. A null pointer yields an empty value.
        /// </summary>
        internal static void BytesToText(sbyte* psz, ref ReadOnlyMemory<char> dst)
        {
            if (psz != null)
                dst = BytesToString(psz).AsMemory();
            else
                dst = ReadOnlyMemory<char>.Empty;
        }

        /// <summary>
        /// Convert UTF8 bytes with known positive length to a string.
        /// </summary>
        /// <exception cref="OverflowException">
        /// <paramref name="bch"/> exceeds <see cref="int.MaxValue"/>; such a buffer cannot be decoded in one call,
        /// and an unchecked narrowing would silently decode a truncated prefix instead.
        /// </exception>
        internal static string BytesToString(sbyte* prgch, ulong bch)
        {
            Contracts.Assert(prgch != null);
            Contracts.Assert(bch > 0);
            return Encoding.UTF8.GetString((byte*)prgch, checked((int)bch));
        }

        /// <summary>
        /// Convert null-terminated UTF8 bytes to a string.
        /// </summary>
        /// <returns>The decoded string, or null when the input is empty (first byte is the terminator).</returns>
        internal static string BytesToString(sbyte* psz)
        {
            Contracts.Assert(psz != null);
            // REVIEW: Ideally should make this safer by always knowing the length.
            int cch = 0;
            while (psz[cch] != 0)
                cch++;
            // Callers (e.g. BytesToText via AsMemory) rely on null for the empty case.
            if (cch == 0)
                return null;
#if CORECLR
            return Encoding.UTF8.GetString((byte*)psz, cch);
#else
            var decoder = Encoding.UTF8.GetDecoder();
            var chars = new char[decoder.GetCharCount((byte*)psz, cch, true)];
            int bytesUsed;
            int charsUsed;
            bool complete;
            fixed (char* pchars = chars)
                decoder.Convert((byte*)psz, cch, pchars, chars.Length, true, out bytesUsed, out charsUsed, out complete);
            Contracts.Assert(bytesUsed == cch);
            Contracts.Assert(charsUsed == chars.Length);
            Contracts.Assert(complete);
            return new string(chars);
#endif
        }

        /// <summary>
        /// Convert a string to null-terminated UTF8 bytes. A null string yields just the terminator.
        /// </summary>
        internal static byte[] StringToNullTerminatedBytes(string str)
        {
            str = str ?? string.Empty;
            // Allocate room for the encoded text plus one byte; the array is zero-initialized, so the
            // final byte is already the single UTF-8 null terminator.
            var bytes = new byte[Encoding.UTF8.GetByteCount(str) + 1];
            Encoding.UTF8.GetBytes(str, 0, str.Length, bytes, 0);
            return bytes;
        }
    }
}