Skip to content

Commit

Permalink
JIT: Avoid mask<->vector optimization for masks used in unhandled ways (dotnet#110307)
Browse files Browse the repository at this point in the history

When a local is used as a return buffer it is not address exposed, so
the address-exposure check was not sufficient. Add checks for
`LCL_ADDR`, `LCL_FLD` and `STORE_LCL_FLD` to make sure any use of a mask
local that is not converted disqualifies it from participating in the
optimization.

Also avoid doing some work for locals that are not SIMD/mask typed
(common case). Previously we would do some unnecessary hash table
lookups and other things in these cases.
  • Loading branch information
jakobbotsch authored and eduardo-vp committed Dec 4, 2024
1 parent fdab5df commit c240068
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 35 deletions.
83 changes: 48 additions & 35 deletions src/coreclr/jit/optimizemaskconversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,9 @@ class MaskConversionsCheckVisitor final : public GenTreeVisitor<MaskConversionsC
public:
enum
{
DoPostOrder = true,
UseExecutionOrder = true
DoPreOrder = true,
UseExecutionOrder = true,
DoLclVarsOnly = true,
};

MaskConversionsCheckVisitor(Compiler* compiler, weight_t bbWeight, MaskConversionsWeightTable* weightsTable)
Expand All @@ -129,16 +130,29 @@ class MaskConversionsCheckVisitor final : public GenTreeVisitor<MaskConversionsC
{
}

Compiler::fgWalkResult PostOrderVisit(GenTree** use, GenTree* user)
Compiler::fgWalkResult PreOrderVisit(GenTree** use, GenTree* user)
{
GenTreeLclVarCommon* lclOp = (*use)->AsLclVarCommon();
LclVarDsc* varDsc = m_compiler->lvaGetDesc(lclOp);

if (!varTypeIsSIMDOrMask(varDsc))
{
return fgWalkResult::WALK_CONTINUE;
}

// Get the existing weighting (if any).
MaskConversionsWeight* weight = weightsTable->LookupPointerOrAdd(lclOp->GetLclNum(), MaskConversionsWeight());

JITDUMP("%s V%02d at [%06u] ", GenTree::OpName(lclOp->gtOper), lclOp->GetLclNum(),
m_compiler->dspTreeID(lclOp));

GenTreeHWIntrinsic* convertOp = nullptr;

bool isLocalStore = false;
bool isLocalUse = false;
bool isInvalid = false;
bool hasConversion = false;

switch ((*use)->OperGet())
switch (lclOp->OperGet())
{
case GT_STORE_LCL_VAR:
{
Expand All @@ -147,9 +161,9 @@ class MaskConversionsCheckVisitor final : public GenTreeVisitor<MaskConversionsC
// Look for:
// use:STORE_LCL_VAR(ConvertMaskToVector(x))

if ((*use)->AsLclVar()->Data()->OperIsConvertMaskToVector())
if (lclOp->Data()->OperIsConvertMaskToVector())
{
convertOp = (*use)->AsLclVar()->Data()->AsHWIntrinsic();
convertOp = lclOp->Data()->AsHWIntrinsic();
hasConversion = true;
}
break;
Expand All @@ -164,7 +178,7 @@ class MaskConversionsCheckVisitor final : public GenTreeVisitor<MaskConversionsC
// -or-
// user: ConditionalSelect(use:LCL_VAR(x), y, z)

if (user->OperIsHWIntrinsic())
if ((user != nullptr) && user->OperIsHWIntrinsic())
{
GenTreeHWIntrinsic* hwintrin = user->AsHWIntrinsic();
NamedIntrinsic ni = hwintrin->GetHWIntrinsicId();
Expand All @@ -186,7 +200,7 @@ class MaskConversionsCheckVisitor final : public GenTreeVisitor<MaskConversionsC
// emit `vblendmps zmm1 {k1}, zmm2, zmm3` instead of containing the CndSel
// as part of something like `vaddps zmm1 {k1}, zmm2, zmm3`

if (hwintrin->Op(1) == (*use))
if (hwintrin->Op(1) == lclOp)
{
convertOp = user->AsHWIntrinsic();
hasConversion = true;
Expand All @@ -197,25 +211,19 @@ class MaskConversionsCheckVisitor final : public GenTreeVisitor<MaskConversionsC
}

default:
break;
// LCL_ADDR (can show up unexposed due to retbufs), or partial
// use/store. We do not handle these.
weight->InvalidateWeight();
JITDUMP("is unhandled. ");
return fgWalkResult::WALK_CONTINUE;
}

if (isLocalStore || isLocalUse)
{
GenTreeLclVarCommon* lclOp = (*use)->AsLclVarCommon();
LclVarDsc* varDsc = m_compiler->lvaGetDesc(lclOp->GetLclNum());

// Get the existing weighting (if any).
MaskConversionsWeight defaultWeight;
MaskConversionsWeight* weight = weightsTable->LookupPointerOrAdd(lclOp->GetLclNum(), defaultWeight);

JITDUMP("Local %s V%02d at [%06u] ", isLocalStore ? "store" : "use", lclOp->GetLclNum(),
m_compiler->dspTreeID(lclOp));

// Cannot convert any locals with an exposed address.
if (varDsc->IsAddressExposed())
{
JITDUMP("is address exposed elsewhere. ");
JITDUMP("is address exposed. ");
weight->InvalidateWeight();
return fgWalkResult::WALK_CONTINUE;
}
Expand Down Expand Up @@ -345,29 +353,34 @@ class MaskConversionsUpdateVisitor final : public GenTreeVisitor<MaskConversions
assert(lclOp != nullptr);

// Get the existing weighting.
MaskConversionsWeight weight;
bool found = weightsTable->Lookup(lclOp->GetLclNum(), &weight);
assert(found);
MaskConversionsWeight* weight = weightsTable->LookupPointer(lclOp->GetLclNum());

if (weight == nullptr)
{
return fgWalkResult::WALK_CONTINUE;
}

// Quit if the cost of changing is higher or is invalid.
if (weight.currentCost <= weight.switchCost || weight.invalid)
if (weight->currentCost <= weight->switchCost || weight->invalid)
{
JITDUMP("Local %s V%02d at [%06u] will not be converted. ", isLocalStore ? "store" : "use",
lclOp->GetLclNum(), Compiler::dspTreeID(lclOp));
weight.DumpTotalWeight();
weight->DumpTotalWeight();
return fgWalkResult::WALK_CONTINUE;
}

JITDUMP("Local %s V%02d at [%06u] will be converted. ", isLocalStore ? "store" : "use", lclOp->GetLclNum(),
Compiler::dspTreeID(lclOp));
weight.DumpTotalWeight();
weight->DumpTotalWeight();

// Fix up the type of the lcl and the lclvar.
assert(lclOp->gtType != TYP_MASK);
var_types lclOrigType = lclOp->gtType;
lclOp->gtType = TYP_MASK;
LclVarDsc* varDsc = m_compiler->lvaGetDesc(lclOp->GetLclNum());
varDsc->lvType = TYP_MASK;

LclVarDsc* varDsc = m_compiler->lvaGetDesc(lclOp->GetLclNum());
assert(varTypeIsSIMDOrMask(varDsc));
varDsc->lvType = TYP_MASK;

// Add or remove a conversion

Expand All @@ -390,9 +403,9 @@ class MaskConversionsUpdateVisitor final : public GenTreeVisitor<MaskConversions

// There is not enough information in the lcl to get simd types. Instead reuse the cached
// simd types from the removed convert nodes.
assert(weight.simdBaseJitType != CORINFO_TYPE_UNDEF);
lclOp->Data() = m_compiler->gtNewSimdCvtVectorToMaskNode(TYP_MASK, lclOp->Data(), weight.simdBaseJitType,
weight.simdSize);
assert(weight->simdBaseJitType != CORINFO_TYPE_UNDEF);
lclOp->Data() = m_compiler->gtNewSimdCvtVectorToMaskNode(TYP_MASK, lclOp->Data(), weight->simdBaseJitType,
weight->simdSize);
}

else if (isLocalUse && removeConversion)
Expand All @@ -414,9 +427,9 @@ class MaskConversionsUpdateVisitor final : public GenTreeVisitor<MaskConversions

// There is not enough information in the lcl to get simd types. Instead reuse the cached simd
// types from the removed convert nodes.
assert(weight.simdBaseJitType != CORINFO_TYPE_UNDEF);
assert(weight->simdBaseJitType != CORINFO_TYPE_UNDEF);
*use =
m_compiler->gtNewSimdCvtMaskToVectorNode(lclOrigType, lclOp, weight.simdBaseJitType, weight.simdSize);
m_compiler->gtNewSimdCvtMaskToVectorNode(lclOrigType, lclOp, weight->simdBaseJitType, weight->simdSize);
}

JITDUMP("Updated %s V%02d at [%06u] to mask (%s conversion)\n", isLocalStore ? "store" : "use",
Expand Down Expand Up @@ -521,7 +534,7 @@ PhaseStatus Compiler::fgOptimizeMaskConversions()
// Only check statements where there is a local of type TYP_SIMD/TYP_MASK.
for (GenTreeLclVarCommon* lcl : stmt->LocalsTreeList())
{
if (varTypeIsSIMDOrMask(lcl))
if (varTypeIsSIMDOrMask(lvaGetDesc(lcl)))
{
// Parse the entire statement.
MaskConversionsCheckVisitor ev(this, block->getBBWeight(this), &weightsTable);
Expand Down
74 changes: 74 additions & 0 deletions src/tests/JIT/Regression/JitBlue/Runtime_110306/Runtime_110306.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

// Generated by Fuzzlyn v2.4 on 2024-12-01 16:32:26
// Run on X64 Linux
// Seed: 7861295224295601455-vectort,vector128,vector256,x86aes,x86avx,x86avx2,x86avx512bw,x86avx512bwvl,x86avx512cd,x86avx512cdvl,x86avx512dq,x86avx512dqvl,x86avx512f,x86avx512fvl,x86avx512fx64,x86bmi1,x86bmi1x64,x86bmi2,x86bmi2x64,x86fma,x86lzcnt,x86lzcntx64,x86pclmulqdq,x86popcnt,x86popcntx64,x86sse,x86ssex64,x86sse2,x86sse2x64,x86sse3,x86sse41,x86sse41x64,x86sse42,x86sse42x64,x86ssse3,x86x86base
// Reduced from 115.8 KiB to 0.9 KiB in 00:02:27
// Hits JIT assert in Release:
// Assertion failed 'newLclValue.BothDefined()' in 'Program:Main(Fuzzlyn.ExecutionServer.IRuntime)' during 'Do value numbering' (IL size 61; hash 0xade6b36b; FullOpts)
//
// File: /__w/1/s/src/coreclr/jit/valuenum.cpp Line: 6138
//
using System;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
using Xunit;

// Helper reference type with a single uint field. TestMain reads this field
// through the static s_3 (s_3.F2.F0).
public class C0
{
public uint F0;
}

// Helper struct holding a C0 reference. Runtime_110306.s_3 is a static field
// of this type.
public struct S0
{
public C0 F2;
}

// Helper class with a single byte field. TestMain allocates one instance
// (vr10) and does not touch it afterwards.
public class C3
{
public byte F0;
}

// Regression test for the JIT assert quoted in the header comment of this file.
// The body was reduced by Fuzzlyn; NOTE(review): presumably the exact sequence
// of statements (including the seemingly redundant ones) is what reproduces the
// issue, so avoid "cleaning up" this code.
public class Runtime_110306
{
// Never assigned anywhere in this file, so s_3.F2 keeps its default (null) value here.
public static S0 s_3;

// xUnit entry point. Does nothing when Avx512F.VL is not supported; otherwise
// runs TestMain and ignores any exception it throws.
[Fact]
public static void TestEntryPoint()
{
if (!Avx512F.VL.IsSupported)
{
return;
}

try
{
TestMain();
}
catch
{
// Deliberately empty. NOTE(review): TestMain dereferences s_3.F2, which is
// not assigned in this file, so an exception is expected at run time;
// presumably the test only cares that the method gets compiled without
// hitting the JIT assert - confirm.
}
}

// The reduced repro body.
private static void TestMain()
{
var vr5 = Vector256.Create(1, 0, 0, 0);
Vector256<long> vr15 = Vector256.Create<long>(0);
// vr8 first receives the result of an AVX-512 comparison ...
Vector256<long> vr8 = Avx512F.VL.CompareNotEqual(vr5, vr15);
long vr9 = 0;
var vr10 = new C3();
// ... and is then overwritten with the result of a non-inlined call.
// NOTE(review): presumably this is the "local used as a return buffer" case
// mentioned in the associated JIT change - confirm.
vr8 = M3();
long vr11 = vr9;
// s_3.F2 is not assigned in this file, so this read goes through a null reference.
var vr12 = s_3.F2.F0;
// Self-assignment kept from the reduced repro.
vr8 = vr8;
}

// Returns default(Vector256<long>). Marked NoInlining so that the call in
// TestMain is not inlined.
[MethodImpl(MethodImplOptions.NoInlining)]
public static Vector256<long> M3()
{
return default;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<!-- Build the test with compiler optimizations enabled. -->
<Optimize>True</Optimize>
</PropertyGroup>
<ItemGroup>
<!-- Compile the single source file that shares this project's name. -->
<Compile Include="$(MSBuildProjectName).cs" />
</ItemGroup>
</Project>

0 comments on commit c240068

Please sign in to comment.