Skip to content

Commit

Permalink
[wasm] Optimize Vector128<float>/<double>.Equals in interp/jiterp (#8…
Browse files Browse the repository at this point in the history
…8064)

* Add browser-bench measurement for int32 and float equals
* Add interp intrinsics for Vector128 float and double Equals methods
* Implement Vector128 float and double Equals methods in jiterp
* Add jiterp validation to make sure we never appendSimd(0) by accident
  • Loading branch information
kg authored Jul 12, 2023
1 parent 6969e7e commit d59af2c
Show file tree
Hide file tree
Showing 9 changed files with 126 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,16 @@ public static Vector128<T> operator >>>(Vector128<T> value, int shiftCount)
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public override bool Equals([NotNullWhen(true)] object? obj) => (obj is Vector128<T> other) && Equals(other);

// Floating-point lanes need NaN-aware equality: a lane matches when the two values compare
// equal, or when neither value compares equal to itself (i.e. both are NaN).
// Kept as its own method so the mono interpreter/jiterpreter can optimize it.
[Intrinsic]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static bool EqualsFloatingPoint (Vector128<T> lhs, Vector128<T> rhs)
{
    // Lanes where lhs and rhs compare equal.
    Vector128<T> sameValue = Vector128.Equals(lhs, rhs);

    // Lanes where at least one operand compares equal to itself; the complement of this
    // mask therefore selects lanes in which both operands fail self-comparison.
    Vector128<T> anySelfEqual = Vector128.Equals(lhs, lhs) | Vector128.Equals(rhs, rhs);

    Vector128<T> laneMatches = sameValue | ~anySelfEqual;

    // The vectors are equal only when every bit of the combined mask is set.
    return laneMatches.AsInt32() == Vector128<int>.AllBitsSet;
}

/// <summary>Determines whether the specified <see cref="Vector128{T}" /> is equal to the current instance.</summary>
/// <param name="other">The <see cref="Vector128{T}" /> to compare with the current instance.</param>
/// <returns><c>true</c> if <paramref name="other" /> is equal to the current instance; otherwise, <c>false</c>.</returns>
Expand All @@ -401,8 +411,7 @@ public bool Equals(Vector128<T> other)
{
if ((typeof(T) == typeof(double)) || (typeof(T) == typeof(float)))
{
Vector128<T> result = Vector128.Equals(this, other) | ~(Vector128.Equals(this, this) | Vector128.Equals(other, other));
return result.AsInt32() == Vector128<int>.AllBitsSet;
return EqualsFloatingPoint(this, other);
}
else
{
Expand Down
3 changes: 3 additions & 0 deletions src/mono/mono/mini/interp/interp-simd-intrins.def
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_BITWISE_OR, interp_v128_o
INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_BITWISE_EQUALITY, interp_v128_op_bitwise_equality, -1)
INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_BITWISE_INEQUALITY, interp_v128_op_bitwise_inequality, -1)

INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_R4_FLOAT_EQUALITY, interp_v128_r4_float_equality, -1)
INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_R8_FLOAT_EQUALITY, interp_v128_r8_float_equality, -1)

INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_EXCLUSIVE_OR, interp_v128_op_exclusive_or, 81)

INTERP_SIMD_INTRINSIC_P_PP (INTERP_SIMD_INTRINSIC_V128_I1_MULTIPLY, interp_v128_i1_op_multiply, -1)
Expand Down
27 changes: 26 additions & 1 deletion src/mono/mono/mini/interp/interp-simd.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ typedef guint16 v128_u2 __attribute__ ((vector_size (SIZEOF_V128)));
typedef gint8 v128_i1 __attribute__ ((vector_size (SIZEOF_V128)));
typedef guint8 v128_u1 __attribute__ ((vector_size (SIZEOF_V128)));
typedef float v128_r4 __attribute__ ((vector_size (SIZEOF_V128)));
typedef double v128_r8 __attribute__ ((vector_size (SIZEOF_V128)));

// get_AllBitsSet
static void
Expand Down Expand Up @@ -122,7 +123,30 @@ interp_v128_op_bitwise_inequality (gpointer res, gpointer v1, gpointer v2)
*(gint32*)res = 1;
}

// op_Addition
// Vector128<float>.EqualsFloatingPoint
// Writes 1 to res when every float lane of v1 and v2 either compares equal or fails
// self-comparison in both operands (NaN on both sides); writes 0 otherwise.
static void
interp_v128_r4_float_equality (gpointer res, gpointer v1, gpointer v2)
{
	v128_r4 lhs = *(v128_r4*)v1;
	v128_r4 rhs = *(v128_r4*)v2;
	// Per-lane mask: set where the lanes are equal, or where neither lane equals itself
	v128_r4 lane_mask = (lhs == rhs) | ~((lhs == lhs) | (rhs == rhs));

	// Reference pattern with every bit set; the vectors match only if the mask is identical to it
	guint8 all_ones [SIZEOF_V128];
	memset (all_ones, 0xff, SIZEOF_V128);

	*(gint32*)res = (memcmp (all_ones, &lane_mask, SIZEOF_V128) == 0);
}

// Vector128<double>.EqualsFloatingPoint
// Writes 1 to res when every double lane of v1 and v2 either compares equal or fails
// self-comparison in both operands (NaN on both sides); writes 0 otherwise.
static void
interp_v128_r8_float_equality (gpointer res, gpointer v1, gpointer v2)
{
	v128_r8 lhs = *(v128_r8*)v1;
	v128_r8 rhs = *(v128_r8*)v2;
	// Per-lane mask: set where the lanes are equal, or where neither lane equals itself
	v128_r8 lane_mask = (lhs == rhs) | ~((lhs == lhs) | (rhs == rhs));

	// Reference pattern with every bit set; the vectors match only if the mask is identical to it
	guint8 all_ones [SIZEOF_V128];
	memset (all_ones, 0xff, SIZEOF_V128);

	*(gint32*)res = (memcmp (all_ones, &lane_mask, SIZEOF_V128) == 0);
}

// op_Multiply
static void
interp_v128_i1_op_multiply (gpointer res, gpointer v1, gpointer v2)
{
Expand All @@ -147,6 +171,7 @@ interp_v128_r4_op_multiply (gpointer res, gpointer v1, gpointer v2)
*(v128_r4*)res = *(v128_r4*)v1 * *(v128_r4*)v2;
}

// op_Division
static void
interp_v128_r4_op_division (gpointer res, gpointer v1, gpointer v2)
{
Expand Down
1 change: 1 addition & 0 deletions src/mono/mono/mini/interp/simd-methods.def
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ SIMD_METHOD(CreateScalar)
SIMD_METHOD(CreateScalarUnsafe)

SIMD_METHOD(Equals)
SIMD_METHOD(EqualsFloatingPoint)
SIMD_METHOD(ExtractMostSignificantBits)
SIMD_METHOD(GreaterThan)
SIMD_METHOD(LessThan)
Expand Down
8 changes: 8 additions & 0 deletions src/mono/mono/mini/interp/transform-simd.c
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ static guint16 sri_vector128_methods [] = {
};

static guint16 sri_vector128_t_methods [] = {
SN_EqualsFloatingPoint,
SN_get_AllBitsSet,
SN_get_Count,
SN_get_One,
Expand Down Expand Up @@ -196,6 +197,13 @@ emit_common_simd_operations (TransformData *td, int id, int atype, int vector_si
*simd_intrins = INTERP_SIMD_INTRINSIC_V128_BITWISE_EQUALITY;
}
break;
case SN_EqualsFloatingPoint:
*simd_opcode = MINT_SIMD_INTRINS_P_PP;
if (atype == MONO_TYPE_R4)
*simd_intrins = INTERP_SIMD_INTRINSIC_V128_R4_FLOAT_EQUALITY;
else if (atype == MONO_TYPE_R8)
*simd_intrins = INTERP_SIMD_INTRINSIC_V128_R8_FLOAT_EQUALITY;
break;
case SN_op_ExclusiveOr:
*simd_opcode = MINT_SIMD_INTRINS_P_PP;
*simd_intrins = INTERP_SIMD_INTRINSIC_V128_EXCLUSIVE_OR;
Expand Down
38 changes: 38 additions & 0 deletions src/mono/sample/wasm/browser-bench/Vector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ public VectorTask()
new MinDouble(),
new MaxDouble(),
new Normalize(),
new EqualsInt32(),
new EqualsFloat(),
};
}

Expand Down Expand Up @@ -344,5 +346,41 @@ public override void RunStep() {
result = vector / (float)Math.Sqrt(Vector128.Dot(vector, vector));
}
}

// Benchmark step that calls Vector128<int>.Equals on two vectors with differing contents.
class EqualsInt32 : VectorMeasurement
{
    public override string Name => "Equals Int32";

    Vector128<int> left, right;

    // Holds the outcome of the most recent comparison.
    bool areEqual;

    public EqualsInt32()
    {
        left = Vector128.Create(1, 2, 3, 4);
        right = Vector128.Create(4, 3, 2, 1);
    }

    public override void RunStep() => areEqual = left.Equals(right);
}

// Benchmark step that calls Vector128<float>.Equals on two vectors with differing contents.
class EqualsFloat : VectorMeasurement
{
    public override string Name => "Equals Float";

    Vector128<float> left, right;

    // Holds the outcome of the most recent comparison.
    bool areEqual;

    public EqualsFloat()
    {
        left = Vector128.Create(1f, 2f, 3f, 4f);
        right = Vector128.Create(4f, 3f, 2f, 1f);
    }

    public override void RunStep() => areEqual = left.Equals(right);
}
}
}
7 changes: 5 additions & 2 deletions src/mono/wasm/runtime/jiterpreter-support.ts
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,10 @@ export class WasmBuilder {
return this.current.appendU8(value);
}

appendSimd(value: WasmSimdOpcode) {
/**
 * Emits the SIMD prefix byte followed by the given SIMD opcode.
 * @param value the SIMD opcode to append after the prefix
 * @param allowLoad must be true for v128_load to be accepted; any other opcode that
 *  coerces to 0 (for example an undefined table entry) always fails the assertion
 * @returns whatever appendULeb returns for the opcode
 */
appendSimd(value: WasmSimdOpcode, allowLoad?: boolean) {
    this.current.appendU8(WasmOpcode.PREFIX_simd);

    // Guard against accidentally emitting opcode 0: only an explicitly permitted v128_load may pass
    const isNonZeroOpcode = (value | 0) !== 0,
        isPermittedLoad = (value === WasmSimdOpcode.v128_load) && (allowLoad === true);
    mono_assert(isNonZeroOpcode || isPermittedLoad, "Expected non-v128_load simd opcode or allowLoad==true");

    // Yes that's right. We're using LEB128 to encode 8-bit opcodes. Why? I don't know
    return this.current.appendULeb(value);
}

Expand Down Expand Up @@ -993,6 +994,7 @@ export class BlobBuilder {
}

appendULeb(value: number) {
mono_assert(typeof (value) === "number", () => `appendULeb expected number but got ${value}`);
mono_assert(value >= 0, "cannot pass negative value to appendULeb");
if (value < 0x7F) {
if (this.size + 1 >= this.capacity)
Expand All @@ -1013,6 +1015,7 @@ export class BlobBuilder {
}

appendLeb(value: number) {
mono_assert(typeof (value) === "number", () => `appendLeb expected number but got ${value}`);
if (this.size + 8 >= this.capacity)
throw new Error("Buffer full");

Expand Down Expand Up @@ -1721,7 +1724,7 @@ export function try_append_memmove_fast(
while (count >= sizeofV128) {
builder.local(destLocal);
builder.local(srcLocal);
builder.appendSimd(WasmSimdOpcode.v128_load);
builder.appendSimd(WasmSimdOpcode.v128_load, true);
builder.appendMemarg(srcOffset, 0);
builder.appendSimd(WasmSimdOpcode.v128_store);
builder.appendMemarg(destOffset, 0);
Expand Down
31 changes: 30 additions & 1 deletion src/mono/wasm/runtime/jiterpreter-trace-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1789,6 +1789,8 @@ function append_ldloc(builder: WasmBuilder, offset: number, opcodeOrPrefix: Wasm
if (simdOpcode !== undefined) {
// This looks wrong but I assure you it's correct.
builder.appendULeb(simdOpcode);
} else if (opcodeOrPrefix === WasmOpcode.PREFIX_simd) {
throw new Error("PREFIX_simd ldloc without a simdOpcode");
}
const alignment = computeMemoryAlignment(offset, opcodeOrPrefix, simdOpcode);
builder.appendMemarg(offset, alignment);
Expand Down Expand Up @@ -3493,7 +3495,7 @@ function emit_simd_2(builder: WasmBuilder, ip: MintOpcodePtr, index: SimdIntrins
// Indirect load, so v1 is T** and res is Vector128*
builder.local("pLocals");
append_ldloc(builder, getArgU16(ip, 2), WasmOpcode.i32_load);
builder.appendSimd(simple);
builder.appendSimd(simple, true);
builder.appendMemarg(0, 0);
append_simd_store(builder, ip);
} else {
Expand Down Expand Up @@ -3609,6 +3611,33 @@ function emit_simd_3(builder: WasmBuilder, ip: MintOpcodePtr, index: SimdIntrins
builder.appendU8(WasmOpcode.i32_eqz);
append_stloc_tail(builder, getArgU16(ip, 1), WasmOpcode.i32_store);
return true;
case SimdIntrinsic3.V128_R4_FLOAT_EQUALITY:
case SimdIntrinsic3.V128_R8_FLOAT_EQUALITY: {
/*
Vector128<T> result = Vector128.Equals(lhs, rhs) | ~(Vector128.Equals(lhs, lhs) | Vector128.Equals(rhs, rhs));
return result.AsInt32() == Vector128<int>.AllBitsSet;
*/
const isR8 = index === SimdIntrinsic3.V128_R8_FLOAT_EQUALITY,
eqOpcode = isR8 ? WasmSimdOpcode.f64x2_eq : WasmSimdOpcode.f32x4_eq;
builder.local("pLocals");
append_ldloc(builder, getArgU16(ip, 2), WasmOpcode.PREFIX_simd, WasmSimdOpcode.v128_load);
builder.local("math_lhs128", WasmOpcode.tee_local);
append_ldloc(builder, getArgU16(ip, 3), WasmOpcode.PREFIX_simd, WasmSimdOpcode.v128_load);
builder.local("math_rhs128", WasmOpcode.tee_local);
builder.appendSimd(eqOpcode);
builder.local("math_lhs128");
builder.local("math_lhs128");
builder.appendSimd(eqOpcode);
builder.local("math_rhs128");
builder.local("math_rhs128");
builder.appendSimd(eqOpcode);
builder.appendSimd(WasmSimdOpcode.v128_or);
builder.appendSimd(WasmSimdOpcode.v128_not);
builder.appendSimd(WasmSimdOpcode.v128_or);
builder.appendSimd(isR8 ? WasmSimdOpcode.i64x2_all_true : WasmSimdOpcode.i32x4_all_true);
append_stloc_tail(builder, getArgU16(ip, 1), WasmOpcode.i32_store);
return true;
}
case SimdIntrinsic3.V128_I1_SHUFFLE: {
// Detect a constant indices vector and turn it into a const. This allows
// v8 to use a more optimized implementation of the swizzle opcode
Expand Down
5 changes: 4 additions & 1 deletion src/mono/wasm/runtime/jiterpreter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -795,8 +795,11 @@ function generate_wasm(
"temp_f64": WasmValtype.f64,
"backbranched": WasmValtype.i32,
};
if (builder.options.enableSimd)
if (builder.options.enableSimd) {
traceLocals["v128_zero"] = WasmValtype.v128;
traceLocals["math_lhs128"] = WasmValtype.v128;
traceLocals["math_rhs128"] = WasmValtype.v128;
}

let keep = true,
traceValue = 0;
Expand Down

0 comments on commit d59af2c

Please sign in to comment.