Skip to content

Commit

Permalink
Implement dot operation
Browse files Browse the repository at this point in the history
Signed-off-by: Zoltan Herczeg zherczeg.u-szeged@partner.samsung.com
  • Loading branch information
zherczeg authored and clover2123 committed Feb 11, 2025
1 parent 1416ec9 commit 5f28544
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 14 deletions.
18 changes: 9 additions & 9 deletions src/jit/ByteCodeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,14 +328,15 @@ static bool isFloatGlobal(uint32_t globalIndex, Module* module)

#elif (defined SLJIT_CONFIG_RISCV && SLJIT_CONFIG_RISCV)

#define OPERAND_TYPE_LIST_SIMD_ARCH \
OL2(OTOp1V128CB, /* SD */ V128 | NOTMP, V128 | NOTMP) \
OL3(OTOp2V128, /* SSD */ V128 | TMP, V128 | TMP, V128 | TMP | S0 | S1) \
OL3(OTOp1V128Tmp, /* SDT */ V128 | NOTMP, V128 | TMP | S0, V128) \
OL2(OTExtractLaneF32, /* SD */ V128 | TMP, F32) \
OL2(OTExtractLaneF64, /* SD */ V128 | TMP, F64) \
OL3(OTSwizzleV128, /* SSD */ V128 | TMP, V128 | NOTMP, V128 | TMP | S1) \
OL3(OTShuffleV128, /* SSD */ V128 | TMP, V128 | TMP, V128 | TMP) \
#define OPERAND_TYPE_LIST_SIMD_ARCH \
OL2(OTOp1V128CB, /* SD */ V128 | NOTMP, V128 | NOTMP) \
OL3(OTOp2V128, /* SSD */ V128 | TMP, V128 | TMP, V128 | TMP | S0 | S1) \
OL3(OTOp1V128Tmp, /* SDT */ V128 | NOTMP, V128 | TMP | S0, V128) \
OL5(OTOp3DotAddV128, /* SSSDT */ V128 | TMP, V128 | TMP, V128 | NOTMP, V128 | TMP | S2, V128) \
OL2(OTExtractLaneF32, /* SD */ V128 | TMP, F32) \
OL2(OTExtractLaneF64, /* SD */ V128 | TMP, F64) \
OL3(OTSwizzleV128, /* SSD */ V128 | TMP, V128 | NOTMP, V128 | TMP | S1) \
OL3(OTShuffleV128, /* SSD */ V128 | TMP, V128 | TMP, V128 | TMP) \
OL3(OTShiftV128, /* SSD */ V128 | NOTMP, I32, V128 | TMP | S0)

// List of aliases.
Expand All @@ -345,7 +346,6 @@ static bool isFloatGlobal(uint32_t globalIndex, Module* module)
#define OTPMinMaxV128 OTOp2V128
#define OTPopcntV128 OTOp1V128Tmp
#define OTShiftV128Tmp OTShiftV128
#define OTOp3DotAddV128 OTOp3V128

#endif /* SLJIT_CONFIG_ARM */

Expand Down
49 changes: 44 additions & 5 deletions src/jit/SimdRiscvInl.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ enum TypeOpcode : uint32_t {
vfwcvt_f_f_v = InstructionType::opfvv | OPCODE(0x12) | (0xc << 15),
vfwcvt_f_x_v = InstructionType::opfvv | OPCODE(0x12) | (0xb << 15),
vfwcvt_f_xu_v = InstructionType::opfvv | OPCODE(0x12) | (0xa << 15),
vmacc_vv = InstructionType::opmvv | OPCODE(0x2d),
vmax_vv = InstructionType::opivv | OPCODE(0x7),
vmax_vx = InstructionType::opivx | OPCODE(0x7),
vmaxu_vv = InstructionType::opivv | OPCODE(0x6),
Expand Down Expand Up @@ -297,11 +298,23 @@ static void simdEmitNarrowZero(sljit_compiler* compiler, sljit_s32 type, uint32_
simdEmitTypedOp(compiler, type, opcode, rd, rn, 0, SimdOp::rmIsImm, SimdOp::vlMulF2);
}

// NOTE(review): this span comes from a rendered diff without +/- markers.
// The simdEmitI32x4DotI16x8 signature directly below and the three body lines
// that use `tmp` appear to be the pre-commit implementation, while the
// simdEmitDot signature and the tmp1/tmp2 sequence appear to be its
// replacement. Confirm against the repository before relying on this text.
static void simdEmitI32x4DotI16x8(sljit_compiler* compiler, sljit_s32 type, sljit_s32 rd, sljit_s32 rn, sljit_s32 rm)
// Emits a dot-product instruction sequence that combines rn and rm into rd.
// The visible call site passes the destination element type as `type` for both
// I32X4DotI16X8S and I16X8DotI8X16I7X16S.
static void simdEmitDot(sljit_compiler* compiler, sljit_s32 type, sljit_s32 rd, sljit_s32 rn, sljit_s32 rm)
{
// Presumably the old implementation (multiply, then vredsum) - see the note above.
sljit_s32 tmp = SLJIT_TMP_DEST_VREG;
simdEmitTypedOp(compiler, SLJIT_SIMD_ELEM_32, SimdOp::vmul_vv, tmp, rn, rm);
simdEmitTypedOp(compiler, SLJIT_SIMD_ELEM_16, SimdOp::vredsum_vs, rd, tmp, rn);
// Two fixed scratch vector registers are used in addition to rd.
sljit_s32 tmp1 = SLJIT_TMP_DEST_VREG;
sljit_s32 tmp2 = SLJIT_VR0;
// Shift amount is 16 when type is SLJIT_SIMD_ELEM_32 and 8 for any other type.
sljit_s32 shift = (type == SLJIT_SIMD_ELEM_32) ? 16 : 8;

// Shift both inputs left and then arithmetically right by `shift`, and
// multiply the results into tmp1. Presumably this sign-extends the low half
// of every `type`-sized element; it assumes the untyped simdEmitOp calls reuse
// the element width selected by the preceding simdEmitTypedOp - TODO confirm.
simdEmitTypedOp(compiler, type, SimdOp::vsll_vi, tmp1, rn, shift, SimdOp::rmIsImm);
simdEmitOp(compiler, SimdOp::vsll_vi, tmp2, rm, shift, SimdOp::rmIsImm);
simdEmitOp(compiler, SimdOp::vsra_vi, tmp1, tmp1, shift, SimdOp::rmIsImm);
simdEmitOp(compiler, SimdOp::vsra_vi, tmp2, tmp2, shift, SimdOp::rmIsImm);
simdEmitOp(compiler, SimdOp::vmul_vv, tmp1, tmp1, tmp2);

// Arithmetic right shift of the unmodified inputs (presumably extracting the
// sign-extended high halves), multiplied into rd. rn and rm are read here
// before rd is used as a multiplication source.
simdEmitOp(compiler, SimdOp::vsra_vi, tmp2, rn, shift, SimdOp::rmIsImm);
simdEmitOp(compiler, SimdOp::vsra_vi, rd, rm, shift, SimdOp::rmIsImm);
simdEmitOp(compiler, SimdOp::vmul_vv, rd, rd, tmp2);

// Final result: the sum of the two partial products.
simdEmitOp(compiler, SimdOp::vadd_vv, rd, rd, tmp1);
}

static void simdEmitFCeil(sljit_compiler* compiler, sljit_s32 type, sljit_s32 rd, sljit_s32 rn)
Expand Down Expand Up @@ -879,6 +892,7 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
case ByteCode::I16X8ExtmulHighI8X16SOpcode:
case ByteCode::I16X8ExtmulLowI8X16UOpcode:
case ByteCode::I16X8ExtmulHighI8X16UOpcode:
case ByteCode::I16X8DotI8X16I7X16SOpcode:
srcType = SLJIT_SIMD_ELEM_8;
dstType = SLJIT_SIMD_ELEM_16;
break;
Expand Down Expand Up @@ -1129,7 +1143,8 @@ static void emitBinarySIMD(sljit_compiler* compiler, Instruction* instr)
case ByteCode::I16X8Q15mulrSatSOpcode:
break;
case ByteCode::I32X4DotI16X8SOpcode:
simdEmitI32x4DotI16x8(compiler, srcType, dst, args[0].arg, args[1].arg);
case ByteCode::I16X8DotI8X16I7X16SOpcode:
simdEmitDot(compiler, dstType, dst, args[0].arg, args[1].arg);
break;
case ByteCode::F32X4AddOpcode:
case ByteCode::F64X2AddOpcode:
Expand Down Expand Up @@ -1245,6 +1260,29 @@ static void simdEmitMadd(sljit_compiler* compiler, sljit_s32 type, bool isAdd, s
simdEmitOp(compiler, isAdd ? SimdOp::vfmacc_vv : SimdOp::vfnmsac_vv, rd, rn, rm);
}

// Emits the instruction sequence used for I32X4DotI8X16I7X16AddS.
//
//   rd   - destination vector register
//   rn   - first multiplicand
//   rm   - second multiplicand
//   ro   - vector added to the dot product in the last step
//   tmp3 - caller-provided scratch vector register holding the running sum
//
// SLJIT_TMP_DEST_VREG and SLJIT_VR0 are used as two further scratch registers.
// NOTE(review): the untyped simdEmitOp calls presumably reuse the element
// width selected by the closest preceding simdEmitTypedOp - confirm.
static void simdEmitDotAdd(sljit_compiler* compiler, sljit_s32 rd, sljit_s32 rn, sljit_s32 rm, sljit_s32 ro, sljit_s32 tmp3)
{
    const sljit_s32 byteShift = 8;
    const sljit_s32 halfShift = 16;
    const sljit_s32 lhs = SLJIT_TMP_DEST_VREG;
    const sljit_s32 rhs = SLJIT_VR0;
    const sljit_s32 sum = tmp3;

    // 16-bit stage, low bytes: shift left then arithmetic-right by 8 on both
    // inputs, and store their product in the running sum.
    simdEmitTypedOp(compiler, SLJIT_SIMD_ELEM_16, SimdOp::vsll_vi, lhs, rn, byteShift, SimdOp::rmIsImm);
    simdEmitOp(compiler, SimdOp::vsll_vi, rhs, rm, byteShift, SimdOp::rmIsImm);
    simdEmitOp(compiler, SimdOp::vsra_vi, lhs, lhs, byteShift, SimdOp::rmIsImm);
    simdEmitOp(compiler, SimdOp::vsra_vi, rhs, rhs, byteShift, SimdOp::rmIsImm);
    simdEmitOp(compiler, SimdOp::vmul_vv, sum, lhs, rhs);

    // 16-bit stage, high bytes: arithmetic-right shift of the original inputs,
    // then multiply-accumulate into the running sum.
    simdEmitOp(compiler, SimdOp::vsra_vi, lhs, rn, byteShift, SimdOp::rmIsImm);
    simdEmitOp(compiler, SimdOp::vsra_vi, rhs, rm, byteShift, SimdOp::rmIsImm);
    simdEmitOp(compiler, SimdOp::vmacc_vv, sum, lhs, rhs);

    // 32-bit stage: split the running sum into its two 16-bit halves (again via
    // shift left / arithmetic shift right) and add them together.
    simdEmitTypedOp(compiler, SLJIT_SIMD_ELEM_32, SimdOp::vsll_vi, lhs, sum, halfShift, SimdOp::rmIsImm);
    simdEmitOp(compiler, SimdOp::vsra_vi, sum, sum, halfShift, SimdOp::rmIsImm);
    simdEmitOp(compiler, SimdOp::vsra_vi, lhs, lhs, halfShift, SimdOp::rmIsImm);
    simdEmitOp(compiler, SimdOp::vadd_vv, sum, sum, lhs);

    // Add the combined products to ro and write the result to rd.
    simdEmitOp(compiler, SimdOp::vadd_vv, rd, ro, sum);
}

static void emitTernarySIMD(sljit_compiler* compiler, Instruction* instr)
{
Operand* operands = instr->operands();
Expand Down Expand Up @@ -1312,6 +1350,7 @@ static void emitTernarySIMD(sljit_compiler* compiler, Instruction* instr)
simdEmitOp(compiler, SimdOp::vmerge_vv, dst, args[0].arg, args[1].arg);
break;
case ByteCode::I32X4DotI8X16I7X16AddSOpcode:
simdEmitDotAdd(compiler, dst, args[0].arg, args[1].arg, args[2].arg, instr->requiredReg(3));
break;
case ByteCode::F32X4RelaxedMaddOpcode:
case ByteCode::F64X2RelaxedMaddOpcode:
Expand Down

0 comments on commit 5f28544

Please sign in to comment.