Skip to content

Commit

Permalink
Provide proper calling conversion for x86 framework (#1833)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ivanidzo4ka authored Dec 7, 2018
1 parent 284e02c commit 2c87b19
Show file tree
Hide file tree
Showing 16 changed files with 161 additions and 138 deletions.
72 changes: 48 additions & 24 deletions src/Microsoft.ML.FastTree/Dataset/DenseIntArray.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Security;

namespace Microsoft.ML.Trainers.FastTree.Internal
{
Expand Down Expand Up @@ -70,13 +71,14 @@ public override IntArray[] Split(int[][] assignment)
}

#if USE_FASTTREENATIVE
[DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall)]
internal const string NativePath = "FastTreeNative";
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
private static extern unsafe int C_Sumup_float(
int numBits, byte* pData, int* pIndices, float* pSampleOutputs, double* pSampleOutputWeights,
FloatType* pSumTargetsByBin, double* pSumTargets2ByBin, int* pCountByBin,
int totalCount, double totalSampleOutputs, double totalSampleOutputWeights);

[DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall)]
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
private static extern unsafe int C_Sumup_double(
int numBits, byte* pData, int* pIndices, double* pSampleOutputs, double* pSampleOutputWeights,
FloatType* pSumTargetsByBin, double* pSumTargets2ByBin, int* pCountByBin,
Expand Down Expand Up @@ -154,7 +156,8 @@ public Dense0BitIntArray(byte[] buffer, ref int position)
{
}

public override MD5Hash MD5Hash {
public override MD5Hash MD5Hash
{
get { return MD5Hasher.Hash(Length); }
}

Expand All @@ -178,13 +181,16 @@ public override void ToByteArray(byte[] buffer, ref int position)
Length.ToByteArray(buffer, ref position);
}

public override int this[int index] {
get {
public override int this[int index]
{
get
{
Contracts.Assert(0 <= index && index < Length);
return 0;
}

set {
set
{
Contracts.Assert(0 <= index && index < Length);
Contracts.Assert(value == 0);
}
Expand Down Expand Up @@ -266,7 +272,8 @@ private void Set(long offset, uint mask, int value)
_data[major + 1] = (_data[major + 1] & ~major1Mask) | (uint)(val >> 32);
}

public override MD5Hash MD5Hash {
public override MD5Hash MD5Hash
{
get { return MD5Hasher.Hash(_data); }
}

Expand All @@ -291,8 +298,10 @@ public override void ToByteArray(byte[] buffer, ref int position)
_data.ToByteArray(buffer, ref position);
}

public sealed override unsafe int this[int index] {
get {
public sealed override unsafe int this[int index]
{
get
{
long offset = index;
offset = (offset << 3) + (offset << 1);
int minor = (int)(offset & 0x1f);
Expand All @@ -301,7 +310,8 @@ public sealed override unsafe int this[int index] {
return (int)(((*(ulong*)(pData + major)) >> minor) & _mask);
}

set {
set
{
Contracts.Assert(0 <= value && value < (1 << 10));
Set(((long)index) * 10, _mask, value);
}
Expand Down Expand Up @@ -436,10 +446,12 @@ public override unsafe void Callback(Action<IntPtr> callback)
}
}

public override unsafe int this[int index] {
public override unsafe int this[int index]
{
get { return _data[index]; }

set {
set
{
Contracts.Assert(0 <= value && value <= byte.MaxValue);
_data[index] = (byte)value;
}
Expand Down Expand Up @@ -471,7 +483,8 @@ internal sealed class Dense4BitIntArray : DenseIntArray

public override IntArrayBits BitsPerItem { get { return IntArrayBits.Bits4; } }

public override MD5Hash MD5Hash {
public override MD5Hash MD5Hash
{
get { return MD5Hasher.Hash(_data); }
}

Expand Down Expand Up @@ -532,8 +545,10 @@ public override void ToByteArray(byte[] buffer, ref int position)
_data.ToByteArray(buffer, ref position);
}

public override unsafe int this[int index] {
get {
public override unsafe int this[int index]
{
get
{
int dataIndex = index / 2;
bool highBits = (index % 2 == 0);

Expand All @@ -546,7 +561,8 @@ public override unsafe int this[int index] {
return v;
}

set {
set
{
Contracts.Assert(0 <= value && value < (1 << 4));
byte v;
v = (byte)value;
Expand Down Expand Up @@ -607,7 +623,8 @@ public Dense16BitIntArray(byte[] buffer, ref int position)
_data = buffer.ToUShortArray(ref position);
}

public override MD5Hash MD5Hash {
public override MD5Hash MD5Hash
{
get { return MD5Hasher.Hash(_data); }
}

Expand Down Expand Up @@ -640,12 +657,15 @@ public override void ToByteArray(byte[] buffer, ref int position)
_data.ToByteArray(buffer, ref position);
}

public override unsafe int this[int index] {
get {
public override unsafe int this[int index]
{
get
{
return _data[index];
}

set {
set
{
Contracts.Assert(0 <= value && value <= ushort.MaxValue);
_data[index] = (ushort)value;
}
Expand Down Expand Up @@ -700,7 +720,8 @@ public override unsafe void Callback(Action<IntPtr> callback)
}
}

public override MD5Hash MD5Hash {
public override MD5Hash MD5Hash
{
get { return MD5Hasher.Hash(_data); }
}

Expand All @@ -725,12 +746,15 @@ public override void ToByteArray(byte[] buffer, ref int position)
_data.ToByteArray(buffer, ref position);
}

public override int this[int index] {
get {
public override int this[int index]
{
get
{
return _data[index];
}

set {
set
{
Contracts.Assert(value >= 0);
_data[index] = value;
}
Expand Down
17 changes: 9 additions & 8 deletions src/Microsoft.ML.FastTree/Dataset/SegmentIntArray.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Security;

namespace Microsoft.ML.Trainers.FastTree.Internal
{
Expand Down Expand Up @@ -489,31 +490,31 @@ public static unsafe void SegmentFindOptimalCost31(uint[] array, int len, out lo
}
bits = b;
}

internal const string NativePath = "FastTreeNative";
#pragma warning disable TLC_GeneralName // Externs follow their own rules.
[DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall, CharSet = CharSet.Ansi)]
[DllImport(NativePath, CharSet = CharSet.Ansi), SuppressUnmanagedCodeSecurity]
private static extern unsafe void C_SegmentFindOptimalPath21(uint* valv, int valc, long* pBits, int* pTransitions);

[DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall, CharSet = CharSet.Ansi)]
[DllImport(NativePath, CharSet = CharSet.Ansi), SuppressUnmanagedCodeSecurity]
private static extern unsafe void C_SegmentFindOptimalPath15(uint* valv, int valc, long* pBits, int* pTransitions);

[DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall, CharSet = CharSet.Ansi)]
[DllImport(NativePath, CharSet = CharSet.Ansi), SuppressUnmanagedCodeSecurity]
private static extern unsafe void C_SegmentFindOptimalPath7(uint* valv, int valc, long* pBits, int* pTransitions);

[DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall, CharSet = CharSet.Ansi)]
[DllImport(NativePath, CharSet = CharSet.Ansi), SuppressUnmanagedCodeSecurity]
private static extern unsafe void C_SegmentFindOptimalCost15(uint* valv, int valc, long* pBits);

[DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall, CharSet = CharSet.Ansi)]
[DllImport(NativePath, CharSet = CharSet.Ansi), SuppressUnmanagedCodeSecurity]
private static extern unsafe void C_SegmentFindOptimalCost31(uint* valv, int valc, long* pBits);

[DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall)]
[DllImport(NativePath)]
private static extern unsafe int C_SumupSegment_float(
uint* pData, byte* pSegType, int* pSegLength, int* pIndices,
float* pSampleOutputs, double* pSampleOutputWeights,
float* pSumTargetsByBin, double* pSumWeightsByBin,
int* pCountByBin, int totalCount, double totalSampleOutputs);

[DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall)]
[DllImport(NativePath)]
private static extern unsafe int C_SumupSegment_double(
uint* pData, byte* pSegType, int* pSegLength, int* pIndices,
double* pSampleOutputs, double* pSampleOutputWeights,
Expand Down
6 changes: 4 additions & 2 deletions src/Microsoft.ML.FastTree/Dataset/SparseIntArray.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Security;

namespace Microsoft.ML.Trainers.FastTree.Internal
{
Expand Down Expand Up @@ -490,12 +491,13 @@ public override void Sumup(SumupInputData input, FeatureHistogram histogram)
}

#if USE_FASTTREENATIVE
[DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall)]
internal const string NativePath = "FastTreeNative";
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
private static extern unsafe int C_SumupDeltaSparse_float(int numBits, byte* pValues, byte* pDeltas, int numDeltas, int* pIndices, float* pSampleOutputs, double* pSampleOutputWeights,
float* pSumTargetsByBin, double* pSumTargets2ByBin, int* pCountByBin,
int totalCount, double totalSampleOutputs, double totalSampleOutputWeights);

[DllImport("FastTreeNative", CallingConvention = CallingConvention.StdCall)]
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
private static extern unsafe int C_SumupDeltaSparse_double(int numBits, byte* pValues, byte* pDeltas, int numDeltas, int* pIndices, double* pSampleOutputs, double* pSampleOutputWeights,
double* pSumTargetsByBin, double* pSumTargets2ByBin, int* pCountByBin,
int totalCount, double totalSampleOutputs, double totalSampleOutputWeights);
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.FastTree/FastTreeRanking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Security;
using System.Text;

// REVIEW: Do we really need all these names?
Expand Down Expand Up @@ -1090,7 +1091,7 @@ private static void PermutationSort(int[] permutation, double[] scores, short[]
}));
}

[DllImport("FastTreeNative", EntryPoint = "C_GetDerivatives", CallingConvention = CallingConvention.StdCall, CharSet = CharSet.Ansi)]
[DllImport("FastTreeNative", EntryPoint = "C_GetDerivatives", CharSet = CharSet.Ansi), SuppressUnmanagedCodeSecurity]
private static extern unsafe void GetDerivatives(
int numDocuments, int begin, int* pPermutation, short* pLabels,
double* pScores, double* pLambdas, double* pWeights, double* pDiscount,
Expand Down
9 changes: 5 additions & 4 deletions src/Microsoft.ML.HalLearners/OlsLinearRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
using System.Collections.Generic;
using System.IO;
using System.Runtime.InteropServices;
using System.Security;

[assembly: LoadableClass(OlsLinearRegressionTrainer.Summary, typeof(OlsLinearRegressionTrainer), typeof(OlsLinearRegressionTrainer.Arguments),
new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
Expand Down Expand Up @@ -380,7 +381,7 @@ private OlsLinearRegressionPredictor TrainCore(IChannel ch, FloatLabelCursor.Fac

internal static class Mkl
{
private const string DllName = "MklImports";
private const string MklPath = "MklImports";

public enum Layout
{
Expand All @@ -394,7 +395,7 @@ public enum UpLo : byte
Lo = (byte)'L'
}

[DllImport(DllName, EntryPoint = "LAPACKE_dpptrf")]
[DllImport(MklPath, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LAPACKE_dpptrf"), SuppressUnmanagedCodeSecurity]
private static extern int PptrfInternal(Layout layout, UpLo uplo, int n, Double[] ap);

/// <summary>
Expand Down Expand Up @@ -429,7 +430,7 @@ public static void Pptrf(Layout layout, UpLo uplo, int n, Double[] ap)
}
}

[DllImport(DllName, EntryPoint = "LAPACKE_dpptrs")]
[DllImport(MklPath, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LAPACKE_dpptrs"), SuppressUnmanagedCodeSecurity]
private static extern int PptrsInternal(Layout layout, UpLo uplo, int n, int nrhs, Double[] ap, Double[] b, int ldb);

/// <summary>
Expand Down Expand Up @@ -476,7 +477,7 @@ public static void Pptrs(Layout layout, UpLo uplo, int n, int nrhs, Double[] ap,

}

[DllImport(DllName, EntryPoint = "LAPACKE_dpptri")]
[DllImport(MklPath, CallingConvention = CallingConvention.Cdecl, EntryPoint = "LAPACKE_dpptri"), SuppressUnmanagedCodeSecurity]
private static extern int PptriInternal(Layout layout, UpLo uplo, int n, Double[] ap);

/// <summary>
Expand Down
13 changes: 6 additions & 7 deletions src/Microsoft.ML.HalLearners/SymSgdClassificationTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -760,9 +760,9 @@ private static unsafe class Native
//To triger the loading of MKL library since SymSGD native library depends on it.
static Native() => ErrorMessage(0);

internal const string DllName = "SymSgdNative";

[DllImport(DllName), SuppressUnmanagedCodeSecurity]
internal const string NativePath = "SymSgdNative";
internal const string MklPath = "MklImports";
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
private static extern void LearnAll(int totalNumInstances, int* instSizes, int** instIndices,
float** instValues, float* labels, bool tuneLR, ref float lr, float l2Const, float piw, float* weightVector, ref float bias,
int numFeatres, int numPasses, int numThreads, bool tuneNumLocIter, ref int numLocIter, float tolerance, bool needShuffle, bool shouldInitialize, State* state);
Expand Down Expand Up @@ -833,7 +833,7 @@ public static void LearnAll(InputDataManager inputDataManager, bool tuneLR,
}
}

[DllImport(DllName), SuppressUnmanagedCodeSecurity]
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
private static extern void MapBackWeightVector(float* weightVector, State* state);

/// <summary>
Expand All @@ -847,7 +847,7 @@ public static void MapBackWeightVector(Span<float> weightVector, GCHandle stateG
MapBackWeightVector(pweightVector, (State*)stateGCHandle.AddrOfPinnedObject());
}

[DllImport(DllName), SuppressUnmanagedCodeSecurity]
[DllImport(NativePath), SuppressUnmanagedCodeSecurity]
private static extern void DeallocateSequentially(State* state);

public static void DeallocateSequentially(GCHandle stateGCHandle)
Expand All @@ -856,8 +856,7 @@ public static void DeallocateSequentially(GCHandle stateGCHandle)
}

// See: https://software.intel.com/en-us/node/521990
[System.Security.SuppressUnmanagedCodeSecurity]
[DllImport("MklImports", EntryPoint = "DftiErrorMessage", CallingConvention = CallingConvention.Cdecl, CharSet = CharSet.Auto)]
[DllImport(MklPath, EntryPoint = "DftiErrorMessage", CallingConvention = CallingConvention.Cdecl, CharSet = CharSet.Auto), SuppressUnmanagedCodeSecurity]
private static extern IntPtr ErrorMessage(int status);
}

Expand Down
Loading

0 comments on commit 2c87b19

Please sign in to comment.