[wasm] Optimize Vector128<float>/<double>.Equals in interp/jiterp #88064

Merged (3 commits), Jul 12, 2023
@@ -388,6 +388,16 @@ public static Vector128<T> operator >>>(Vector128<T> value, int shiftCount)
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public override bool Equals([NotNullWhen(true)] object? obj) => (obj is Vector128<T> other) && Equals(other);

// Account for floating-point equality around NaN
// This is in a separate method so it can be optimized by the mono interpreter/jiterpreter
[Intrinsic]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static bool EqualsFloatingPoint (Vector128<T> lhs, Vector128<T> rhs)
{
Vector128<T> result = Vector128.Equals(lhs, rhs) | ~(Vector128.Equals(lhs, lhs) | Vector128.Equals(rhs, rhs));
return result.AsInt32() == Vector128<int>.AllBitsSet;
}
Comment on lines +395 to +399
Member

Why not just handle Equals directly? There are several APIs on Vector2/3/4 and Vector<T> that are [Intrinsic] instance methods, so I would expect Mono has the support for handling that already?

Member Author

Right now we already vectorize the underlying Equals operation; the problem is that this one desugars (?) to a bunch of individual SIMD operations that each get their own interp opcode.

Member Author

That is, to be clear, if you call Vector128<float>.Equals(...) the interp generates something like this right now:

v128_equals_r4 (lhs, rhs)
v128_equals_r4 (lhs, lhs)
v128_equals_r4 (rhs, rhs)
v128_or stack
v128_not stack
v128_or stack
v128_load_allbitsset
v128_equals_i4 stack
v128_all_true
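
For reference, the logic this opcode sequence implements can be modeled lane by lane. The following is a minimal sketch, not runtime code: `lanesEqualNaNAware` is a hypothetical name, and plain `Float32Array` lanes stand in for `Vector128<float>`.

```typescript
// Hypothetical scalar model of the NaN-aware comparison; not part of the runtime.
// A lane passes if the two values compare equal, or if both of them are NaN.
function lanesEqualNaNAware(lhs: Float32Array, rhs: Float32Array): boolean {
    if (lhs.length !== rhs.length)
        throw new Error("lane count mismatch");
    for (let i = 0; i < lhs.length; i++) {
        const eq = lhs[i] === rhs[i];      // Vector128.Equals(lhs, rhs)
        const lhsSelf = lhs[i] === lhs[i]; // false only when lhs[i] is NaN
        const rhsSelf = rhs[i] === rhs[i]; // false only when rhs[i] is NaN
        // eq | ~(lhsSelf | rhsSelf)
        if (!(eq || !(lhsSelf || rhsSelf)))
            return false;
    }
    return true;
}
```

Under this model, two vectors holding NaN in the same lane compare equal, while NaN against any ordinary number does not.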

Member

@tannergooding, Jul 11, 2023

Sorry, I meant: why not mark the instance Equals method as [Intrinsic] and have Mono treat it the same as operator == for integers, and as a single opcode representing this sequence for float/double?

Member Author

If that's what you prefer I can figure out how to do it. I wanted to keep this change as narrow as possible since the existing Equals method is fine for ints as-is. I'll see how much work it is.

Member

That is, why split off a separate EqualsFloatingPoint at all, rather than simply handling its only caller directly?

Member

No worries if it's overly complex to do. I just figured it would be better overall for both RyuJIT and Mono.

We actually used to do just what I've proposed in RyuJIT, but dropped that support a while back since the inliner was able to handle it and the instance equals calls were much rarer to encounter.


/// <summary>Determines whether the specified <see cref="Vector128{T}" /> is equal to the current instance.</summary>
/// <param name="other">The <see cref="Vector128{T}" /> to compare with the current instance.</param>
/// <returns><c>true</c> if <paramref name="other" /> is equal to the current instance; otherwise, <c>false</c>.</returns>
@@ -401,8 +411,7 @@ public bool Equals(Vector128<T> other)
{
if ((typeof(T) == typeof(double)) || (typeof(T) == typeof(float)))
{
-Vector128<T> result = Vector128.Equals(this, other) | ~(Vector128.Equals(this, this) | Vector128.Equals(other, other));
-return result.AsInt32() == Vector128<int>.AllBitsSet;
+return EqualsFloatingPoint(this, other);
}
else
{
3 changes: 3 additions & 0 deletions src/mono/mono/mini/interp/interp-simd-intrins.def
@@ -58,6 +58,9 @@ INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_BITWISE_OR, interp_v128_o
INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_BITWISE_EQUALITY, interp_v128_op_bitwise_equality, -1)
INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_BITWISE_INEQUALITY, interp_v128_op_bitwise_inequality, -1)

INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_R4_FLOAT_EQUALITY, interp_v128_r4_float_equality, -1)
INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_R8_FLOAT_EQUALITY, interp_v128_r8_float_equality, -1)

INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_EXCLUSIVE_OR, interp_v128_op_exclusive_or, 81)

INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_I1_MULTIPLY, interp_v128_i1_op_multiply, -1)
27 changes: 26 additions & 1 deletion src/mono/mono/mini/interp/interp-simd.c
@@ -17,6 +17,7 @@ typedef guint16 v128_u2 __attribute__ ((vector_size (SIZEOF_V128)));
typedef gint8 v128_i1 __attribute__ ((vector_size (SIZEOF_V128)));
typedef guint8 v128_u1 __attribute__ ((vector_size (SIZEOF_V128)));
typedef float v128_r4 __attribute__ ((vector_size (SIZEOF_V128)));
typedef double v128_r8 __attribute__ ((vector_size (SIZEOF_V128)));

// get_AllBitsSet
static void
@@ -122,7 +123,30 @@ interp_v128_op_bitwise_inequality (gpointer res, gpointer v1, gpointer v2)
*(gint32*)res = 1;
}

-// op_Addition
+// Vector128<float>EqualsFloatingPoint
static void
interp_v128_r4_float_equality (gpointer res, gpointer v1, gpointer v2)
{
v128_r4 v1_cast = *(v128_r4*)v1;
v128_r4 v2_cast = *(v128_r4*)v2;
v128_r4 result = (v1_cast == v2_cast) | ~((v1_cast == v1_cast) | (v2_cast == v2_cast));
memset (&v1_cast, 0xff, SIZEOF_V128);

*(gint32*)res = memcmp (&v1_cast, &result, SIZEOF_V128) == 0;
}

static void
interp_v128_r8_float_equality (gpointer res, gpointer v1, gpointer v2)
{
v128_r8 v1_cast = *(v128_r8*)v1;
v128_r8 v2_cast = *(v128_r8*)v2;
v128_r8 result = (v1_cast == v2_cast) | ~((v1_cast == v1_cast) | (v2_cast == v2_cast));
memset (&v1_cast, 0xff, SIZEOF_V128);

*(gint32*)res = memcmp (&v1_cast, &result, SIZEOF_V128) == 0;
}

// op_Multiply
static void
interp_v128_i1_op_multiply (gpointer res, gpointer v1, gpointer v2)
{
@@ -147,6 +171,7 @@ interp_v128_r4_op_multiply (gpointer res, gpointer v1, gpointer v2)
*(v128_r4*)res = *(v128_r4*)v1 * *(v128_r4*)v2;
}

// op_Division
static void
interp_v128_r4_op_division (gpointer res, gpointer v1, gpointer v2)
{
1 change: 1 addition & 0 deletions src/mono/mono/mini/interp/simd-methods.def
@@ -29,6 +29,7 @@ SIMD_METHOD(CreateScalar)
SIMD_METHOD(CreateScalarUnsafe)

SIMD_METHOD(Equals)
SIMD_METHOD(EqualsFloatingPoint)
SIMD_METHOD(ExtractMostSignificantBits)
SIMD_METHOD(GreaterThan)
SIMD_METHOD(LessThan)
8 changes: 8 additions & 0 deletions src/mono/mono/mini/interp/transform-simd.c
@@ -75,6 +75,7 @@ static guint16 sri_vector128_methods [] = {
};

static guint16 sri_vector128_t_methods [] = {
SN_EqualsFloatingPoint,
SN_get_AllBitsSet,
SN_get_Count,
SN_get_One,
@@ -196,6 +197,13 @@ emit_common_simd_operations (TransformData *td, int id, int atype, int vector_si
*simd_intrins = INTERP_SIMD_INTRINSIC_V128_BITWISE_EQUALITY;
}
break;
case SN_EqualsFloatingPoint:
*simd_opcode = MINT_SIMD_INTRINS_P_PP;
if (atype == MONO_TYPE_R4)
*simd_intrins = INTERP_SIMD_INTRINSIC_V128_R4_FLOAT_EQUALITY;
else if (atype == MONO_TYPE_R8)
*simd_intrins = INTERP_SIMD_INTRINSIC_V128_R8_FLOAT_EQUALITY;
break;
case SN_op_ExclusiveOr:
*simd_opcode = MINT_SIMD_INTRINS_P_PP;
*simd_intrins = INTERP_SIMD_INTRINSIC_V128_EXCLUSIVE_OR;
38 changes: 38 additions & 0 deletions src/mono/sample/wasm/browser-bench/Vector.cs
@@ -32,6 +32,8 @@ public VectorTask()
new MinDouble(),
new MaxDouble(),
new Normalize(),
new EqualsInt32(),
new EqualsFloat(),
};
}

@@ -344,5 +346,41 @@ public override void RunStep() {
result = vector / (float)Math.Sqrt(Vector128.Dot(vector, vector));
}
}

class EqualsInt32 : VectorMeasurement
{
Vector128<Int32> vector1, vector2;
bool result;

public override string Name => "Equals Int32";

public EqualsInt32()
{
vector1 = Vector128.Create(1, 2, 3, 4);
vector2 = Vector128.Create(4, 3, 2, 1);
}

public override void RunStep() {
result = vector1.Equals(vector2);
}
}

class EqualsFloat : VectorMeasurement
{
Vector128<float> vector1, vector2;
bool result;

public override string Name => "Equals Float";

public EqualsFloat()
{
vector1 = Vector128.Create(1f, 2f, 3f, 4f);
vector2 = Vector128.Create(4f, 3f, 2f, 1f);
}

public override void RunStep() {
result = vector1.Equals(vector2);
}
}
}
}
7 changes: 5 additions & 2 deletions src/mono/wasm/runtime/jiterpreter-support.ts
@@ -297,9 +297,10 @@ export class WasmBuilder {
return this.current.appendU8(value);
}

-appendSimd(value: WasmSimdOpcode) {
+appendSimd(value: WasmSimdOpcode, allowLoad?: boolean) {
this.current.appendU8(WasmOpcode.PREFIX_simd);
// Yes that's right. We're using LEB128 to encode 8-bit opcodes. Why? I don't know
mono_assert(((value | 0) !== 0) || ((value === WasmSimdOpcode.v128_load) && (allowLoad === true)), "Expected non-v128_load simd opcode or allowLoad==true");
return this.current.appendULeb(value);
}
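
One plausible reading of the new assertion in `appendSimd` is that it guards against a missing opcode: `undefined | 0` evaluates to `0`, the same value the assertion associates with `v128_load`. The sketch below isolates that condition. `isAcceptableSimdOpcode` and `V128_LOAD` are hypothetical names, and the value `0` for `v128_load` is an assumption inferred from the assertion itself.

```typescript
// Assumption: 0 is the opcode value of v128_load, as the original assertion implies.
const V128_LOAD = 0;

// Hypothetical stand-in for the condition passed to mono_assert in appendSimd.
function isAcceptableSimdOpcode(value: number | undefined, allowLoad?: boolean): boolean {
    // (undefined | 0) evaluates to 0, so a missing opcode would otherwise be
    // emitted as if it were v128_load. Callers must opt in to a real load.
    const coerced = (value as number) | 0;
    return coerced !== 0 || (value === V128_LOAD && allowLoad === true);
}
```

With this condition, a genuine `v128_load` passes only when the caller sets `allowLoad`, which matches the two call sites updated in this PR.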

@@ -993,6 +994,7 @@ export class BlobBuilder {
}

appendULeb(value: number) {
mono_assert(typeof (value) === "number", () => `appendULeb expected number but got ${value}`);
mono_assert(value >= 0, "cannot pass negative value to appendULeb");
if (value < 0x7F) {
if (this.size + 1 >= this.capacity)
@@ -1013,6 +1015,7 @@ export class BlobBuilder {
}

appendLeb(value: number) {
mono_assert(typeof (value) === "number", () => `appendLeb expected number but got ${value}`);
if (this.size + 8 >= this.capacity)
throw new Error("Buffer full");
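
For context on why SIMD opcodes go through `appendULeb`, here is a minimal sketch of unsigned LEB128 encoding. It is not the `BlobBuilder` implementation: `encodeULeb` is a hypothetical standalone helper that returns the bytes instead of writing into a buffer, and it assumes the input is a non-negative integer that fits in 32 bits.

```typescript
// Hypothetical standalone encoder; assumes a non-negative 32-bit integer input.
function encodeULeb(value: number): number[] {
    if (value < 0)
        throw new Error("cannot pass negative value to encodeULeb");
    const bytes: number[] = [];
    do {
        let byte = value & 0x7f; // take the low seven bits
        value >>>= 7;
        if (value !== 0)
            byte |= 0x80;        // continuation bit: more bytes follow
        bytes.push(byte);
    } while (value !== 0);
    return bytes;
}
```

Values below `0x80` encode to a single byte equal to the value itself, so small SIMD opcodes look like plain 8-bit opcodes in the output; values of `0x80` and above take two or more bytes.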

@@ -1721,7 +1724,7 @@ export function try_append_memmove_fast(
while (count >= sizeofV128) {
builder.local(destLocal);
builder.local(srcLocal);
-builder.appendSimd(WasmSimdOpcode.v128_load);
+builder.appendSimd(WasmSimdOpcode.v128_load, true);
builder.appendMemarg(srcOffset, 0);
builder.appendSimd(WasmSimdOpcode.v128_store);
builder.appendMemarg(destOffset, 0);
31 changes: 30 additions & 1 deletion src/mono/wasm/runtime/jiterpreter-trace-generator.ts
@@ -1789,6 +1789,8 @@ function append_ldloc(builder: WasmBuilder, offset: number, opcodeOrPrefix: Wasm
if (simdOpcode !== undefined) {
// This looks wrong but I assure you it's correct.
builder.appendULeb(simdOpcode);
} else if (opcodeOrPrefix === WasmOpcode.PREFIX_simd) {
throw new Error("PREFIX_simd ldloc without a simdOpcode");
}
const alignment = computeMemoryAlignment(offset, opcodeOrPrefix, simdOpcode);
builder.appendMemarg(offset, alignment);
@@ -3493,7 +3495,7 @@ function emit_simd_2(builder: WasmBuilder, ip: MintOpcodePtr, index: SimdIntrins
// Indirect load, so v1 is T** and res is Vector128*
builder.local("pLocals");
append_ldloc(builder, getArgU16(ip, 2), WasmOpcode.i32_load);
-builder.appendSimd(simple);
+builder.appendSimd(simple, true);
builder.appendMemarg(0, 0);
append_simd_store(builder, ip);
} else {
@@ -3609,6 +3611,33 @@ function emit_simd_3(builder: WasmBuilder, ip: MintOpcodePtr, index: SimdIntrins
builder.appendU8(WasmOpcode.i32_eqz);
append_stloc_tail(builder, getArgU16(ip, 1), WasmOpcode.i32_store);
return true;
case SimdIntrinsic3.V128_R4_FLOAT_EQUALITY:
case SimdIntrinsic3.V128_R8_FLOAT_EQUALITY: {
/*
Vector128<T> result = Vector128.Equals(lhs, rhs) | ~(Vector128.Equals(lhs, lhs) | Vector128.Equals(rhs, rhs));
return result.AsInt32() == Vector128<int>.AllBitsSet;
*/
const isR8 = index === SimdIntrinsic3.V128_R8_FLOAT_EQUALITY,
eqOpcode = isR8 ? WasmSimdOpcode.f64x2_eq : WasmSimdOpcode.f32x4_eq;
builder.local("pLocals");
append_ldloc(builder, getArgU16(ip, 2), WasmOpcode.PREFIX_simd, WasmSimdOpcode.v128_load);
builder.local("math_lhs128", WasmOpcode.tee_local);
append_ldloc(builder, getArgU16(ip, 3), WasmOpcode.PREFIX_simd, WasmSimdOpcode.v128_load);
builder.local("math_rhs128", WasmOpcode.tee_local);
builder.appendSimd(eqOpcode);
builder.local("math_lhs128");
builder.local("math_lhs128");
builder.appendSimd(eqOpcode);
builder.local("math_rhs128");
builder.local("math_rhs128");
builder.appendSimd(eqOpcode);
builder.appendSimd(WasmSimdOpcode.v128_or);
builder.appendSimd(WasmSimdOpcode.v128_not);
builder.appendSimd(WasmSimdOpcode.v128_or);
builder.appendSimd(isR8 ? WasmSimdOpcode.i64x2_all_true : WasmSimdOpcode.i32x4_all_true);
append_stloc_tail(builder, getArgU16(ip, 1), WasmOpcode.i32_store);
return true;
}
case SimdIntrinsic3.V128_I1_SHUFFLE: {
// Detect a constant indices vector and turn it into a const. This allows
// v8 to use a more optimized implementation of the swizzle opcode
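
The operand order of the sequence emitted for `V128_R4_FLOAT_EQUALITY` and `V128_R8_FLOAT_EQUALITY` can be traced with a small stack model. This is a sketch under stated assumptions, not jiterpreter code: `traceFloatEquality` and `Mask` are hypothetical, each comparison result is modeled as one boolean per lane rather than an all-ones or all-zero bit pattern, and the `tee_local` stores are represented simply by reusing the `lhs` and `rhs` arguments.

```typescript
// Hypothetical model: one flag per lane stands in for an all-ones or all-zero lane.
type Mask = boolean[];

function traceFloatEquality(lhs: number[], rhs: number[]): number {
    if (lhs.length !== rhs.length)
        throw new Error("lane count mismatch");
    const stack: Mask[] = [];
    const pop = (): Mask => {
        const top = stack.pop();
        if (top === undefined)
            throw new Error("stack underflow");
        return top;
    };
    const eq = (a: number[], b: number[]): Mask => a.map((x, i) => x === b[i]);
    const or = (): void => {
        const b = pop(), a = pop();
        stack.push(a.map((x, i) => x || b[i] === true));
    };

    stack.push(eq(lhs, rhs));       // eq on the two loaded operands
    stack.push(eq(lhs, lhs));       // math_lhs128 twice, then eq: false only in NaN lanes
    stack.push(eq(rhs, rhs));       // math_rhs128 twice, then eq
    or();                           // v128_or: lanes where at least one side is not NaN
    stack.push(pop().map(x => !x)); // v128_not: lanes where both sides are NaN
    or();                           // v128_or with the first comparison
    return pop().every(x => x) ? 1 : 0; // all_true reduces the mask to an i32
}
```

The first comparison stays at the bottom of the stack until the final `v128_or`, which is why the emitted code can run the two self-comparisons before combining anything.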
5 changes: 4 additions & 1 deletion src/mono/wasm/runtime/jiterpreter.ts
@@ -795,8 +795,11 @@ function generate_wasm(
"temp_f64": WasmValtype.f64,
"backbranched": WasmValtype.i32,
};
-if (builder.options.enableSimd)
+if (builder.options.enableSimd) {
traceLocals["v128_zero"] = WasmValtype.v128;
traceLocals["math_lhs128"] = WasmValtype.v128;
traceLocals["math_rhs128"] = WasmValtype.v128;
}

let keep = true,
traceValue = 0;