Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CpuMath Enhancement: Make bound checking of loops in hardware intrinsics more efficient #2939

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 40 additions & 20 deletions src/Microsoft.ML.CpuMath/AvxIntrinsics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -422,10 +422,11 @@ public static unsafe void AddScalarU(float scalar, Span<float> dst)
float* pDstEnd = pdst + dst.Length;
float* pDstCurrent = pdst;
float* pVectorizationEnd = pDstEnd - 4;
float* pAvxVectorizationEnd = pDstEnd - 8;

Vector256<float> scalarVector256 = Vector256.Create(scalar);

while (pDstCurrent + 8 <= pDstEnd)
while (pDstCurrent <= pAvxVectorizationEnd)
{
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tannergooding, do you know if there's any instruction with corresponds roughly to addfloats xmm ptr [rax], xmm0 and if this trio of instructions collapses to that? It seems like it'd be more efficient and would avoid the xmm register spill.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, there is only load-forms, so this will compile down to two instructions (ideally):

addps tmp, scalarVector128, [pDstCurrent]
movps [pDstCurrent], tmp

dstVector = Avx.Add(dstVector, scalarVector256);
Expand Down Expand Up @@ -574,10 +575,11 @@ public static unsafe void ScaleSrcU(float scale, ReadOnlySpan<float> src, Span<f
float* pSrcCurrent = psrc;
float* pDstCurrent = pdst;
float* pVectorizationEnd = pDstEnd - 4;
float* pAvxVectorizationEnd = pDstEnd - 8;

Vector256<float> scaleVector256 = Vector256.Create(scale);

while (pDstCurrent + 8 <= pDstEnd)
while (pDstCurrent <= pAvxVectorizationEnd)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
srcVector = Avx.Multiply(srcVector, scaleVector256);
Expand Down Expand Up @@ -619,11 +621,12 @@ public static unsafe void ScaleAddU(float a, float b, Span<float> dst)
float* pDstEnd = pdst + dst.Length;
float* pDstCurrent = pdst;
float* pVectorizationEnd = pDstEnd - 4;
float* pAvxVectorizationEnd = pDstEnd - 8;

Vector256<float> a256 = Vector256.Create(a);
Vector256<float> b256 = Vector256.Create(b);

while (pDstCurrent + 8 <= pDstEnd)
while (pDstCurrent <= pAvxVectorizationEnd)
{
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
dstVector = Avx.Add(dstVector, b256);
Expand Down Expand Up @@ -668,10 +671,11 @@ public static unsafe void AddScaleU(float scale, ReadOnlySpan<float> src, Span<f
float* pSrcCurrent = psrc;
float* pDstCurrent = pdst;
float* pEnd = pdst + count;
float* pVectorizationEnd = pEnd - 8;

Vector256<float> scaleVector256 = Vector256.Create(scale);

while (pDstCurrent + 8 <= pEnd)
while (pDstCurrent <= pVectorizationEnd)
{
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);

Expand Down Expand Up @@ -722,13 +726,14 @@ public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan<float> src, Re
fixed (float* pres = &MemoryMarshal.GetReference(result))
{
float* pResEnd = pres + count;
float* pVectorizationEnd = pResEnd - 8;
float* pSrcCurrent = psrc;
float* pDstCurrent = pdst;
float* pResCurrent = pres;

Vector256<float> scaleVector256 = Vector256.Create(scale);

while (pResCurrent + 8 <= pResEnd)
while (pResCurrent <= pVectorizationEnd)
{
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
dstVector = MultiplyAdd(pSrcCurrent, scaleVector256, dstVector);
Expand Down Expand Up @@ -782,10 +787,11 @@ public static unsafe void AddScaleSU(float scale, ReadOnlySpan<float> src, ReadO
int* pIdxCurrent = pidx;
float* pDstCurrent = pdst;
int* pEnd = pidx + count;
int* pVectorizationEnd = pEnd - 8;

Vector256<float> scaleVector256 = Vector256.Create(scale);

while (pIdxCurrent + 8 <= pEnd)
while (pIdxCurrent <= pVectorizationEnd)
{
Vector256<float> dstVector = Load8(pDstCurrent, pIdxCurrent);
dstVector = MultiplyAdd(pSrcCurrent, scaleVector256, dstVector);
Expand Down Expand Up @@ -830,8 +836,9 @@ public static unsafe void AddU(ReadOnlySpan<float> src, Span<float> dst, int cou
float* pSrcCurrent = psrc;
float* pDstCurrent = pdst;
float* pEnd = psrc + count;
float* pVectorizationEnd = pEnd - 8;

while (pSrcCurrent + 8 <= pEnd)
while (pSrcCurrent <= pVectorizationEnd)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
Expand Down Expand Up @@ -882,8 +889,9 @@ public static unsafe void AddSU(ReadOnlySpan<float> src, ReadOnlySpan<int> idx,
int* pIdxCurrent = pidx;
float* pDstCurrent = pdst;
int* pEnd = pidx + count;
int* pVectorizationEnd = pEnd - 8;

while (pIdxCurrent + 8 <= pEnd)
while (pIdxCurrent <= pVectorizationEnd)
{
Vector256<float> dstVector = Load8(pDstCurrent, pIdxCurrent);
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
Expand Down Expand Up @@ -930,8 +938,9 @@ public static unsafe void MulElementWiseU(ReadOnlySpan<float> src1, ReadOnlySpan
float* pSrc2Current = psrc2;
float* pDstCurrent = pdst;
float* pEnd = pdst + count;
float* pVectorizationEnd = pEnd - 8;

while (pDstCurrent + 8 <= pEnd)
while (pDstCurrent <= pVectorizationEnd)
{
Vector256<float> src1Vector = Avx.LoadVector256(pSrc1Current);
Vector256<float> src2Vector = Avx.LoadVector256(pSrc2Current);
Expand Down Expand Up @@ -1062,11 +1071,12 @@ public static unsafe float SumSqU(ReadOnlySpan<float> src)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
{
float* pSrcEnd = psrc + src.Length;
float* pVectorizationEnd = pSrcEnd - 8;
float* pSrcCurrent = psrc;

Vector256<float> result256 = Vector256<float>.Zero;

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pVectorizationEnd)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
result256 = MultiplyAdd(srcVector, srcVector, result256);
Expand Down Expand Up @@ -1106,12 +1116,13 @@ public static unsafe float SumSqDiffU(float mean, ReadOnlySpan<float> src)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
{
float* pSrcEnd = psrc + src.Length;
float* pVectorizationEnd = pSrcEnd - 8;
float* pSrcCurrent = psrc;

Vector256<float> result256 = Vector256<float>.Zero;
Vector256<float> meanVector256 = Vector256.Create(mean);

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pVectorizationEnd)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
srcVector = Avx.Subtract(srcVector, meanVector256);
Expand Down Expand Up @@ -1154,11 +1165,12 @@ public static unsafe float SumAbsU(ReadOnlySpan<float> src)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
{
float* pSrcEnd = psrc + src.Length;
float* pVectorizationEnd = pSrcEnd - 8;
float* pSrcCurrent = psrc;

Vector256<float> result256 = Vector256<float>.Zero;

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pVectorizationEnd)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
result256 = Avx.Add(result256, Avx.And(srcVector, _absMask256));
Expand Down Expand Up @@ -1198,12 +1210,13 @@ public static unsafe float SumAbsDiffU(float mean, ReadOnlySpan<float> src)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
{
float* pSrcEnd = psrc + src.Length;
float* pVectorizationEnd = pSrcEnd - 8;
float* pSrcCurrent = psrc;

Vector256<float> result256 = Vector256<float>.Zero;
Vector256<float> meanVector256 = Vector256.Create(mean);

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pVectorizationEnd)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
srcVector = Avx.Subtract(srcVector, meanVector256);
Expand Down Expand Up @@ -1247,11 +1260,12 @@ public static unsafe float MaxAbsU(ReadOnlySpan<float> src)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
{
float* pSrcEnd = psrc + src.Length;
float* pVectorizationEnd = pSrcEnd - 8;
float* pSrcCurrent = psrc;

Vector256<float> result256 = Vector256<float>.Zero;

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pVectorizationEnd)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
result256 = Avx.Max(result256, Avx.And(srcVector, _absMask256));
Expand Down Expand Up @@ -1291,12 +1305,13 @@ public static unsafe float MaxAbsDiffU(float mean, ReadOnlySpan<float> src)
fixed (float* psrc = &MemoryMarshal.GetReference(src))
{
float* pSrcEnd = psrc + src.Length;
float* pVectorizationEnd = pSrcEnd - 8;
float* pSrcCurrent = psrc;

Vector256<float> result256 = Vector256<float>.Zero;
Vector256<float> meanVector256 = Vector256.Create(mean);

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pVectorizationEnd)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
srcVector = Avx.Subtract(srcVector, meanVector256);
Expand Down Expand Up @@ -1345,10 +1360,11 @@ public static unsafe float DotU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst
float* pSrcCurrent = psrc;
float* pDstCurrent = pdst;
float* pSrcEnd = psrc + count;
float* pVectorizationEnd = pSrcEnd - 8;

Vector256<float> result256 = Vector256<float>.Zero;

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pVectorizationEnd)
{
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
result256 = MultiplyAdd(pSrcCurrent, dstVector, result256);
Expand Down Expand Up @@ -1402,10 +1418,11 @@ public static unsafe float DotSU(ReadOnlySpan<float> src, ReadOnlySpan<float> ds
float* pDstCurrent = pdst;
int* pIdxCurrent = pidx;
int* pIdxEnd = pidx + count;
int* pVectorizationEnd = pIdxEnd - 8;

Vector256<float> result256 = Vector256<float>.Zero;

while (pIdxCurrent + 8 <= pIdxEnd)
while (pIdxCurrent <= pVectorizationEnd)
{
Vector256<float> srcVector = Load8(pSrcCurrent, pIdxCurrent);
result256 = MultiplyAdd(pDstCurrent, srcVector, result256);
Expand Down Expand Up @@ -1456,10 +1473,11 @@ public static unsafe float Dist2(ReadOnlySpan<float> src, ReadOnlySpan<float> ds
float* pSrcCurrent = psrc;
float* pDstCurrent = pdst;
float* pSrcEnd = psrc + count;
float* pVectorizationEnd = pSrcEnd - 8;

Vector256<float> sqDistanceVector256 = Vector256<float>.Zero;

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pVectorizationEnd)
{
Vector256<float> distanceVector = Avx.Subtract(Avx.LoadVector256(pSrcCurrent),
Avx.LoadVector256(pDstCurrent));
Expand Down Expand Up @@ -1507,14 +1525,15 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, int count, ReadOnlyS
fixed (float* pdst2 = &MemoryMarshal.GetReference(w))
{
float* pSrcEnd = psrc + count;
float* pVectorizationEnd = pSrcEnd - 8;
float* pSrcCurrent = psrc;
float* pDst1Current = pdst1;
float* pDst2Current = pdst2;

Vector256<float> xPrimal256 = Vector256.Create(primalUpdate);
Vector256<float> xThreshold256 = Vector256.Create(threshold);

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pVectorizationEnd)
{
Vector256<float> xDst1 = Avx.LoadVector256(pDst1Current);
xDst1 = MultiplyAdd(pSrcCurrent, xPrimal256, xDst1);
Expand Down Expand Up @@ -1568,13 +1587,14 @@ public static unsafe void SdcaL1UpdateSU(float primalUpdate, int count, ReadOnly
fixed (float* pdst2 = &MemoryMarshal.GetReference(w))
{
int* pIdxEnd = pidx + count;
int* pVectorizationEnd = pIdxEnd - 8;
float* pSrcCurrent = psrc;
int* pIdxCurrent = pidx;

Vector256<float> xPrimal256 = Vector256.Create(primalUpdate);
Vector256<float> xThreshold = Vector256.Create(threshold);

while (pIdxCurrent + 8 <= pIdxEnd)
while (pIdxCurrent <= pVectorizationEnd)
{
Vector256<float> xDst1 = Load8(pdst1, pIdxCurrent);
xDst1 = MultiplyAdd(pSrcCurrent, xPrimal256, xDst1);
Expand Down
Loading