Skip to content

Commit

Permalink
CodeGen: Rewrite dot product lowering using a dedicated IR instruction (
Browse files Browse the repository at this point in the history
#1512)

Instead of doing the dot product related math in scalar IR, we lift the
computation into a dedicated IR instruction.

On x64, we can use VDPPS which was more or less tailor made for this
purpose. This is better than manual scalar lowering that requires
reloading components from memory; it's not always a strict improvement
over the shuffle+add version (which we never had), but this can now be
adjusted in the IR lowering in an optimal fashion (maybe even based on
CPU vendor, although that'd create issues for offline compilation).

On A64, we can either use naive adds or paired adds, as there is no
dedicated vector-wide horizontal instruction until SVE. Both run at
about the same performance on M2, but paired adds require fewer
instructions and temporaries.

I've measured this using mesh-normal-vector benchmark, changing the
benchmark to just report the time of the second loop inside
`calculate_normals`, testing master vs #1504 vs this PR, also increasing
the grid size to 400 for more stable timings.

On Zen 4 (7950X), this PR is comfortably ~8% faster vs master, while I
see neutral to negative results in #1504.
On M2 (base), this PR is ~28% faster vs master, while #1504 is only
about ~10% faster.

If I measure the second loop in `calculate_tangent_space` instead, I
get:

On Zen 4 (7950X), this PR is ~12% faster vs master, while #1504 is ~3%
faster
On M2 (base), this PR is ~24% faster vs master, while #1504 is only
about ~13% faster.

Note that the loops in question are not quite optimal, as they store and
reload various vectors to dictionary values due to inappropriate use of
locals. The underlying gains in individual functions are thus larger
than the numbers above; for example, changing the `calculate_normals`
loop to use a local variable to store the normalized vector (but still
saving the result to dictionary value), I get a ~24% performance
increase from this PR on Zen4 vs master instead of just 8% (#1504 is
~15% slower in this setup).
  • Loading branch information
zeux authored Nov 9, 2024
1 parent a36a3c4 commit e6bf718
Show file tree
Hide file tree
Showing 14 changed files with 135 additions and 32 deletions.
1 change: 1 addition & 0 deletions CodeGen/include/Luau/AssemblyBuilderA64.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class AssemblyBuilderA64
void fneg(RegisterA64 dst, RegisterA64 src);
void fsqrt(RegisterA64 dst, RegisterA64 src);
void fsub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void faddp(RegisterA64 dst, RegisterA64 src);

// Vector component manipulation
void ins_4s(RegisterA64 dst, RegisterA64 src, uint8_t index);
Expand Down
2 changes: 2 additions & 0 deletions CodeGen/include/Luau/AssemblyBuilderX64.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ class AssemblyBuilderX64
void vpshufps(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t shuffle);
void vpinsrd(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t offset);

void vdpps(OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t mask);

// Run final checks
bool finalize();

Expand Down
4 changes: 4 additions & 0 deletions CodeGen/include/Luau/IrData.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,10 @@ enum class IrCmd : uint8_t
// A: TValue
UNM_VEC,

// Compute dot product between two vectors
// A, B: TValue
DOT_VEC,

// Compute Luau 'not' operation on destructured TValue
// A: tag
// B: int (value)
Expand Down
1 change: 1 addition & 0 deletions CodeGen/include/Luau/IrUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ inline bool hasResult(IrCmd cmd)
case IrCmd::SUB_VEC:
case IrCmd::MUL_VEC:
case IrCmd::DIV_VEC:
case IrCmd::DOT_VEC:
case IrCmd::UNM_VEC:
case IrCmd::NOT_ANY:
case IrCmd::CMP_ANY:
Expand Down
8 changes: 8 additions & 0 deletions CodeGen/src/AssemblyBuilderA64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,14 @@ void AssemblyBuilderA64::fabs(RegisterA64 dst, RegisterA64 src)
placeR1("fabs", dst, src, 0b000'11110'01'1'0000'01'10000);
}

void AssemblyBuilderA64::faddp(RegisterA64 dst, RegisterA64 src)
{
CODEGEN_ASSERT(dst.kind == KindA64::d || dst.kind == KindA64::s);
CODEGEN_ASSERT(dst.kind == src.kind);

placeR1("faddp", dst, src, 0b011'11110'0'0'11000'01101'10 | ((dst.kind == KindA64::d) << 12));
}

void AssemblyBuilderA64::fadd(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
if (dst.kind == KindA64::d)
Expand Down
5 changes: 5 additions & 0 deletions CodeGen/src/AssemblyBuilderX64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -946,6 +946,11 @@ void AssemblyBuilderX64::vpinsrd(RegisterX64 dst, RegisterX64 src1, OperandX64 s
placeAvx("vpinsrd", dst, src1, src2, offset, 0x22, false, AVX_0F3A, AVX_66);
}

void AssemblyBuilderX64::vdpps(OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t mask)
{
placeAvx("vdpps", dst, src1, src2, mask, 0x40, false, AVX_0F3A, AVX_66);
}

bool AssemblyBuilderX64::finalize()
{
code.resize(codePos - code.data());
Expand Down
2 changes: 2 additions & 0 deletions CodeGen/src/IrDump.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ const char* getCmdName(IrCmd cmd)
return "DIV_VEC";
case IrCmd::UNM_VEC:
return "UNM_VEC";
case IrCmd::DOT_VEC:
return "DOT_VEC";
case IrCmd::NOT_ANY:
return "NOT_ANY";
case IrCmd::CMP_ANY:
Expand Down
15 changes: 15 additions & 0 deletions CodeGen/src/IrLoweringA64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,21 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
build.fneg(inst.regA64, regOp(inst.a));
break;
}
case IrCmd::DOT_VEC:
{
inst.regA64 = regs.allocReg(KindA64::d, index);

RegisterA64 temp = regs.allocTemp(KindA64::q);
RegisterA64 temps = castReg(KindA64::s, temp);
RegisterA64 regs = castReg(KindA64::s, inst.regA64);

build.fmul(temp, regOp(inst.a), regOp(inst.b));
build.faddp(regs, temps); // x+y
build.dup_4s(temp, temp, 2);
build.fadd(regs, regs, temps); // +z
build.fcvt(inst.regA64, regs);
break;
}
case IrCmd::NOT_ANY:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b});
Expand Down
14 changes: 14 additions & 0 deletions CodeGen/src/IrLoweringX64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,20 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
build.vxorpd(inst.regX64, regOp(inst.a), build.f32x4(-0.0, -0.0, -0.0, -0.0));
break;
}
case IrCmd::DOT_VEC:
{
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b});

ScopedRegX64 tmp1{regs};
ScopedRegX64 tmp2{regs};

RegisterX64 tmpa = vecOp(inst.a, tmp1);
RegisterX64 tmpb = (inst.a == inst.b) ? tmpa : vecOp(inst.b, tmp2);

build.vdpps(inst.regX64, tmpa, tmpb, 0x71); // 7 = 0b0111, sum first 3 products into first float
build.vcvtss2sd(inst.regX64, inst.regX64, inst.regX64);
break;
}
case IrCmd::NOT_ANY:
{
// TODO: if we have a single user which is a STORE_INT, we are missing the opportunity to write directly to target
Expand Down
104 changes: 73 additions & 31 deletions CodeGen/src/IrTranslateBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ static const int kMinMaxUnrolledParams = 5;
static const int kBit32BinaryOpUnrolledParams = 5;

LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeCodegen);
LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeDot);

namespace Luau
{
Expand Down Expand Up @@ -907,15 +908,26 @@ static BuiltinImplResult translateBuiltinVectorMagnitude(

build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos));

IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
IrOp sum;

if (FFlag::LuauVectorLibNativeDot)
{
IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0));

IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x);
IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z);
sum = build.inst(IrCmd::DOT_VEC, a, a);
}
else
{
IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));

IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2);
IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x);
IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z);

sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2);
}

IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);

Expand Down Expand Up @@ -945,25 +957,43 @@ static BuiltinImplResult translateBuiltinVectorNormalize(

build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos));

IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
if (FFlag::LuauVectorLibNativeDot)
{
IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0));
IrOp sum = build.inst(IrCmd::DOT_VEC, a, a);

IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x);
IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z);
IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag);
IrOp invvec = build.inst(IrCmd::NUM_TO_VEC, inv);

IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2);
IrOp result = build.inst(IrCmd::MUL_VEC, a, invvec);

IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag);
result = build.inst(IrCmd::TAG_VECTOR, result);

IrOp xr = build.inst(IrCmd::MUL_NUM, x, inv);
IrOp yr = build.inst(IrCmd::MUL_NUM, y, inv);
IrOp zr = build.inst(IrCmd::MUL_NUM, z, inv);
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), result);
}
else
{
IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));

build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), xr, yr, zr);
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR));
IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x);
IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z);

IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2);

IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag);

IrOp xr = build.inst(IrCmd::MUL_NUM, x, inv);
IrOp yr = build.inst(IrCmd::MUL_NUM, y, inv);
IrOp zr = build.inst(IrCmd::MUL_NUM, z, inv);

build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), xr, yr, zr);
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR));
}

return {BuiltinImplType::Full, 1};
}
Expand Down Expand Up @@ -1019,19 +1049,31 @@ static BuiltinImplResult translateBuiltinVectorDot(IrBuilder& build, int nparams
build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos));
build.loadAndCheckTag(args, LUA_TVECTOR, build.vmExit(pcpos));

IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(0));
IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2);
IrOp sum;

IrOp y1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
IrOp y2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(4));
IrOp yy = build.inst(IrCmd::MUL_NUM, y1, y2);
if (FFlag::LuauVectorLibNativeDot)
{
IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0));
IrOp b = build.inst(IrCmd::LOAD_TVALUE, args, build.constInt(0));

IrOp z1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(8));
IrOp zz = build.inst(IrCmd::MUL_NUM, z1, z2);
sum = build.inst(IrCmd::DOT_VEC, a, b);
}
else
{
IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(0));
IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2);

IrOp y1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
IrOp y2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(4));
IrOp yy = build.inst(IrCmd::MUL_NUM, y1, y2);

IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, xx, yy), zz);
IrOp z1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(8));
IrOp zz = build.inst(IrCmd::MUL_NUM, z1, z2);

sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, xx, yy), zz);
}

build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), sum);
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER));
Expand Down
2 changes: 2 additions & 0 deletions CodeGen/src/IrUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ IrValueKind getCmdValueKind(IrCmd cmd)
case IrCmd::DIV_VEC:
case IrCmd::UNM_VEC:
return IrValueKind::Tvalue;
case IrCmd::DOT_VEC:
return IrValueKind::Double;
case IrCmd::NOT_ANY:
case IrCmd::CMP_ANY:
return IrValueKind::Int;
Expand Down
4 changes: 3 additions & 1 deletion CodeGen/src/OptimizeConstProp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,8 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
if (tag == LUA_TBOOLEAN &&
(value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Int)))
canSplitTvalueStore = true;
else if (tag == LUA_TNUMBER && (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Double)))
else if (tag == LUA_TNUMBER &&
(value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Double)))
canSplitTvalueStore = true;
else if (tag != 0xff && isGCO(tag) && value.kind == IrOpKind::Inst)
canSplitTvalueStore = true;
Expand Down Expand Up @@ -1342,6 +1343,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
case IrCmd::SUB_VEC:
case IrCmd::MUL_VEC:
case IrCmd::DIV_VEC:
case IrCmd::DOT_VEC:
if (IrInst* a = function.asInstOp(inst.a); a && a->cmd == IrCmd::TAG_VECTOR)
replace(function, inst.a, a->a);

Expand Down
3 changes: 3 additions & 0 deletions tests/AssemblyBuilderA64.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,9 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPMath")
SINGLE_COMPARE(fsub(d1, d2, d3), 0x1E633841);
SINGLE_COMPARE(fsub(s29, s29, s28), 0x1E3C3BBD);

SINGLE_COMPARE(faddp(s29, s28), 0x7E30DB9D);
SINGLE_COMPARE(faddp(d29, d28), 0x7E70DB9D);

SINGLE_COMPARE(frinta(d1, d2), 0x1E664041);
SINGLE_COMPARE(frintm(d1, d2), 0x1E654041);
SINGLE_COMPARE(frintp(d1, d2), 0x1E64C041);
Expand Down
2 changes: 2 additions & 0 deletions tests/AssemblyBuilderX64.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,8 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXTernaryInstructionForms")

SINGLE_COMPARE(vpshufps(xmm7, xmm12, xmmword[rcx + r10], 0b11010100), 0xc4, 0xa1, 0x18, 0xc6, 0x3c, 0x11, 0xd4);
SINGLE_COMPARE(vpinsrd(xmm7, xmm12, xmmword[rcx + r10], 2), 0xc4, 0xa3, 0x19, 0x22, 0x3c, 0x11, 0x02);

SINGLE_COMPARE(vdpps(xmm7, xmm12, xmmword[rcx + r10], 2), 0xc4, 0xa3, 0x19, 0x40, 0x3c, 0x11, 0x02);
}

TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "MiscInstructions")
Expand Down

0 comments on commit e6bf718

Please sign in to comment.