Skip to content

Commit

Permalink
Ensure that Arm64 correctly handles multiplication of simd by a 64-bi…
Browse files Browse the repository at this point in the history
…t scalar
  • Loading branch information
tannergooding committed Aug 23, 2024
1 parent 0a31fd7 commit de8b16a
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 22 deletions.
42 changes: 20 additions & 22 deletions src/coreclr/jit/gentree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20830,21 +20830,14 @@ GenTree* Compiler::gtNewSimdBinOpNode(
{
GenTree** broadcastOp = nullptr;

#if defined(TARGET_ARM64)
if (varTypeIsLong(simdBaseType))
{
break;
}
#endif // TARGET_ARM64

if (varTypeIsArithmetic(op1))
{
broadcastOp = &op1;

#if defined(TARGET_ARM64)
if (!varTypeIsByte(simdBaseType))
{
// MultiplyByScalar requires the scalar op to be op2fGetHWIntrinsicIdForBinOp
// MultiplyByScalar requires the scalar op to be op2 for GetHWIntrinsicIdForBinOp
needsReverseOps = true;
}
#endif // TARGET_ARM64
Expand All @@ -20857,7 +20850,12 @@ GenTree* Compiler::gtNewSimdBinOpNode(
if (broadcastOp != nullptr)
{
#if defined(TARGET_ARM64)
if (!varTypeIsByte(simdBaseType))
if (varTypeIsLong(simdBaseType))
{
// This is handled via emulation and the scalar is consumed directly
break;
}
else if (!varTypeIsByte(simdBaseType))
{
op2ForLookup = *broadcastOp;
*broadcastOp = gtNewSimdCreateScalarUnsafeNode(TYP_SIMD8, *broadcastOp, simdBaseJitType, 8);
Expand Down Expand Up @@ -21261,24 +21259,26 @@ GenTree* Compiler::gtNewSimdBinOpNode(
#elif defined(TARGET_ARM64)
if (varTypeIsLong(simdBaseType))
{
GenTree** op1ToDup = &op1;
GenTree** op2ToDup = &op2;
GenTree** op2ToDup = nullptr;

if (!varTypeIsArithmetic(op1))
{
op1 = gtNewSimdToScalarNode(TYP_LONG, op1, simdBaseJitType, simdSize);
op1ToDup = &op1->AsHWIntrinsic()->Op(1);
}
assert(varTypeIsSIMD(op1));
op1 = gtNewSimdToScalarNode(TYP_LONG, op1, simdBaseJitType, simdSize);
GenTree** op1ToDup = &op1->AsHWIntrinsic()->Op(1);

if (!varTypeIsArithmetic(op2))
if (varTypeIsSIMD(op2))
{
op2 = gtNewSimdToScalarNode(TYP_LONG, op2, simdBaseJitType, simdSize);
op2ToDup = &op2->AsHWIntrinsic()->Op(1);
}

// lower = op1.GetElement(0) * op2.GetElement(0)
GenTree* lower = gtNewOperNode(GT_MUL, TYP_LONG, op1, op2);
lower = gtNewSimdCreateScalarUnsafeNode(type, lower, simdBaseJitType, simdSize);

if (op2ToDup == nullptr)
{
op2ToDup = &lower->AsOp()->gtOp2;
}
lower = gtNewSimdCreateScalarUnsafeNode(type, lower, simdBaseJitType, simdSize);

if (simdSize == 8)
{
Expand All @@ -21290,10 +21290,8 @@ GenTree* Compiler::gtNewSimdBinOpNode(
GenTree* op1Dup = fgMakeMultiUse(op1ToDup);
GenTree* op2Dup = fgMakeMultiUse(op2ToDup);

if (!varTypeIsArithmetic(op1Dup))
{
op1Dup = gtNewSimdGetElementNode(TYP_LONG, op1Dup, gtNewIconNode(1), simdBaseJitType, simdSize);
}
assert(!varTypeIsArithmetic(op1Dup));
op1Dup = gtNewSimdGetElementNode(TYP_LONG, op1Dup, gtNewIconNode(1), simdBaseJitType, simdSize);

if (!varTypeIsArithmetic(op2Dup))
{
Expand Down
20 changes: 20 additions & 0 deletions src/tests/JIT/Regression/JitBlue/Runtime_106838/Runtime_106838.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
using Xunit;

public class Runtime_106838
{
[MethodImpl(MethodImplOptions.NoInlining)]
private static Vector128<ulong> Problem(Vector128<ulong> vector) => vector * 5UL;

[Fact]
public static void TestEntryPoint()
{
Vector128<ulong> result = Problem(Vector128.Create<ulong>(5));
Assert.Equal(Vector128.Create<ulong>(25), result);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<Optimize>True</Optimize>
</PropertyGroup>
<ItemGroup>
<Compile Include="$(MSBuildProjectName).cs" />
</ItemGroup>
</Project>

0 comments on commit de8b16a

Please sign in to comment.