Skip to content

Commit

Permalink
RISC-V: Add tuple type vget/vset intrinsics
Browse files Browse the repository at this point in the history
gcc/ChangeLog:

	* config/riscv/genrvv-type-indexer.cc (valid_type): Adapt for
	tuple type support.
	(inttype): Ditto.
	(floattype): Ditto.
	(main): Ditto.
	* config/riscv/riscv-vector-builtins-bases.cc: Ditto.
	* config/riscv/riscv-vector-builtins-functions.def (vset): Add
	tuple type vset.
	(vget): Add tuple type vget.
	* config/riscv/riscv-vector-builtins-types.def
	(DEF_RVV_TUPLE_OPS): New macro.
	(vint8mf8x2_t): Ditto.
	(vuint8mf8x2_t): Ditto.
	(vint8mf8x3_t): Ditto.
	(vuint8mf8x3_t): Ditto.
	(vint8mf8x4_t): Ditto.
	(vuint8mf8x4_t): Ditto.
	(vint8mf8x5_t): Ditto.
	(vuint8mf8x5_t): Ditto.
	(vint8mf8x6_t): Ditto.
	(vuint8mf8x6_t): Ditto.
	(vint8mf8x7_t): Ditto.
	(vuint8mf8x7_t): Ditto.
	(vint8mf8x8_t): Ditto.
	(vuint8mf8x8_t): Ditto.
	(vint8mf4x2_t): Ditto.
	(vuint8mf4x2_t): Ditto.
	(vint8mf4x3_t): Ditto.
	(vuint8mf4x3_t): Ditto.
	(vint8mf4x4_t): Ditto.
	(vuint8mf4x4_t): Ditto.
	(vint8mf4x5_t): Ditto.
	(vuint8mf4x5_t): Ditto.
	(vint8mf4x6_t): Ditto.
	(vuint8mf4x6_t): Ditto.
	(vint8mf4x7_t): Ditto.
	(vuint8mf4x7_t): Ditto.
	(vint8mf4x8_t): Ditto.
	(vuint8mf4x8_t): Ditto.
	(vint8mf2x2_t): Ditto.
	(vuint8mf2x2_t): Ditto.
	(vint8mf2x3_t): Ditto.
	(vuint8mf2x3_t): Ditto.
	(vint8mf2x4_t): Ditto.
	(vuint8mf2x4_t): Ditto.
	(vint8mf2x5_t): Ditto.
	(vuint8mf2x5_t): Ditto.
	(vint8mf2x6_t): Ditto.
	(vuint8mf2x6_t): Ditto.
	(vint8mf2x7_t): Ditto.
	(vuint8mf2x7_t): Ditto.
	(vint8mf2x8_t): Ditto.
	(vuint8mf2x8_t): Ditto.
	(vint8m1x2_t): Ditto.
	(vuint8m1x2_t): Ditto.
	(vint8m1x3_t): Ditto.
	(vuint8m1x3_t): Ditto.
	(vint8m1x4_t): Ditto.
	(vuint8m1x4_t): Ditto.
	(vint8m1x5_t): Ditto.
	(vuint8m1x5_t): Ditto.
	(vint8m1x6_t): Ditto.
	(vuint8m1x6_t): Ditto.
	(vint8m1x7_t): Ditto.
	(vuint8m1x7_t): Ditto.
	(vint8m1x8_t): Ditto.
	(vuint8m1x8_t): Ditto.
	(vint8m2x2_t): Ditto.
	(vuint8m2x2_t): Ditto.
	(vint8m2x3_t): Ditto.
	(vuint8m2x3_t): Ditto.
	(vint8m2x4_t): Ditto.
	(vuint8m2x4_t): Ditto.
	(vint8m4x2_t): Ditto.
	(vuint8m4x2_t): Ditto.
	(vint16mf4x2_t): Ditto.
	(vuint16mf4x2_t): Ditto.
	(vint16mf4x3_t): Ditto.
	(vuint16mf4x3_t): Ditto.
	(vint16mf4x4_t): Ditto.
	(vuint16mf4x4_t): Ditto.
	(vint16mf4x5_t): Ditto.
	(vuint16mf4x5_t): Ditto.
	(vint16mf4x6_t): Ditto.
	(vuint16mf4x6_t): Ditto.
	(vint16mf4x7_t): Ditto.
	(vuint16mf4x7_t): Ditto.
	(vint16mf4x8_t): Ditto.
	(vuint16mf4x8_t): Ditto.
	(vint16mf2x2_t): Ditto.
	(vuint16mf2x2_t): Ditto.
	(vint16mf2x3_t): Ditto.
	(vuint16mf2x3_t): Ditto.
	(vint16mf2x4_t): Ditto.
	(vuint16mf2x4_t): Ditto.
	(vint16mf2x5_t): Ditto.
	(vuint16mf2x5_t): Ditto.
	(vint16mf2x6_t): Ditto.
	(vuint16mf2x6_t): Ditto.
	(vint16mf2x7_t): Ditto.
	(vuint16mf2x7_t): Ditto.
	(vint16mf2x8_t): Ditto.
	(vuint16mf2x8_t): Ditto.
	(vint16m1x2_t): Ditto.
	(vuint16m1x2_t): Ditto.
	(vint16m1x3_t): Ditto.
	(vuint16m1x3_t): Ditto.
	(vint16m1x4_t): Ditto.
	(vuint16m1x4_t): Ditto.
	(vint16m1x5_t): Ditto.
	(vuint16m1x5_t): Ditto.
	(vint16m1x6_t): Ditto.
	(vuint16m1x6_t): Ditto.
	(vint16m1x7_t): Ditto.
	(vuint16m1x7_t): Ditto.
	(vint16m1x8_t): Ditto.
	(vuint16m1x8_t): Ditto.
	(vint16m2x2_t): Ditto.
	(vuint16m2x2_t): Ditto.
	(vint16m2x3_t): Ditto.
	(vuint16m2x3_t): Ditto.
	(vint16m2x4_t): Ditto.
	(vuint16m2x4_t): Ditto.
	(vint16m4x2_t): Ditto.
	(vuint16m4x2_t): Ditto.
	(vint32mf2x2_t): Ditto.
	(vuint32mf2x2_t): Ditto.
	(vint32mf2x3_t): Ditto.
	(vuint32mf2x3_t): Ditto.
	(vint32mf2x4_t): Ditto.
	(vuint32mf2x4_t): Ditto.
	(vint32mf2x5_t): Ditto.
	(vuint32mf2x5_t): Ditto.
	(vint32mf2x6_t): Ditto.
	(vuint32mf2x6_t): Ditto.
	(vint32mf2x7_t): Ditto.
	(vuint32mf2x7_t): Ditto.
	(vint32mf2x8_t): Ditto.
	(vuint32mf2x8_t): Ditto.
	(vint32m1x2_t): Ditto.
	(vuint32m1x2_t): Ditto.
	(vint32m1x3_t): Ditto.
	(vuint32m1x3_t): Ditto.
	(vint32m1x4_t): Ditto.
	(vuint32m1x4_t): Ditto.
	(vint32m1x5_t): Ditto.
	(vuint32m1x5_t): Ditto.
	(vint32m1x6_t): Ditto.
	(vuint32m1x6_t): Ditto.
	(vint32m1x7_t): Ditto.
	(vuint32m1x7_t): Ditto.
	(vint32m1x8_t): Ditto.
	(vuint32m1x8_t): Ditto.
	(vint32m2x2_t): Ditto.
	(vuint32m2x2_t): Ditto.
	(vint32m2x3_t): Ditto.
	(vuint32m2x3_t): Ditto.
	(vint32m2x4_t): Ditto.
	(vuint32m2x4_t): Ditto.
	(vint32m4x2_t): Ditto.
	(vuint32m4x2_t): Ditto.
	(vint64m1x2_t): Ditto.
	(vuint64m1x2_t): Ditto.
	(vint64m1x3_t): Ditto.
	(vuint64m1x3_t): Ditto.
	(vint64m1x4_t): Ditto.
	(vuint64m1x4_t): Ditto.
	(vint64m1x5_t): Ditto.
	(vuint64m1x5_t): Ditto.
	(vint64m1x6_t): Ditto.
	(vuint64m1x6_t): Ditto.
	(vint64m1x7_t): Ditto.
	(vuint64m1x7_t): Ditto.
	(vint64m1x8_t): Ditto.
	(vuint64m1x8_t): Ditto.
	(vint64m2x2_t): Ditto.
	(vuint64m2x2_t): Ditto.
	(vint64m2x3_t): Ditto.
	(vuint64m2x3_t): Ditto.
	(vint64m2x4_t): Ditto.
	(vuint64m2x4_t): Ditto.
	(vint64m4x2_t): Ditto.
	(vuint64m4x2_t): Ditto.
	(vfloat32mf2x2_t): Ditto.
	(vfloat32mf2x3_t): Ditto.
	(vfloat32mf2x4_t): Ditto.
	(vfloat32mf2x5_t): Ditto.
	(vfloat32mf2x6_t): Ditto.
	(vfloat32mf2x7_t): Ditto.
	(vfloat32mf2x8_t): Ditto.
	(vfloat32m1x2_t): Ditto.
	(vfloat32m1x3_t): Ditto.
	(vfloat32m1x4_t): Ditto.
	(vfloat32m1x5_t): Ditto.
	(vfloat32m1x6_t): Ditto.
	(vfloat32m1x7_t): Ditto.
	(vfloat32m1x8_t): Ditto.
	(vfloat32m2x2_t): Ditto.
	(vfloat32m2x3_t): Ditto.
	(vfloat32m2x4_t): Ditto.
	(vfloat32m4x2_t): Ditto.
	(vfloat64m1x2_t): Ditto.
	(vfloat64m1x3_t): Ditto.
	(vfloat64m1x4_t): Ditto.
	(vfloat64m1x5_t): Ditto.
	(vfloat64m1x6_t): Ditto.
	(vfloat64m1x7_t): Ditto.
	(vfloat64m1x8_t): Ditto.
	(vfloat64m2x2_t): Ditto.
	(vfloat64m2x3_t): Ditto.
	(vfloat64m2x4_t): Ditto.
	(vfloat64m4x2_t): Ditto.
	* config/riscv/riscv-vector-builtins.cc (DEF_RVV_TUPLE_OPS):
	Ditto.
	(DEF_RVV_TYPE_INDEX): Ditto.
	(rvv_arg_type_info::get_tuple_subpart_type): New function.
	(DEF_RVV_TUPLE_TYPE): New macro.
	* config/riscv/riscv-vector-builtins.def (DEF_RVV_TYPE_INDEX):
	Adapt for tuple vget/vset support.
	(vint8mf4_t): Ditto.
	(vuint8mf4_t): Ditto.
	(vint8mf2_t): Ditto.
	(vuint8mf2_t): Ditto.
	(vint8m1_t): Ditto.
	(vuint8m1_t): Ditto.
	(vint8m2_t): Ditto.
	(vuint8m2_t): Ditto.
	(vint8m4_t): Ditto.
	(vuint8m4_t): Ditto.
	(vint8m8_t): Ditto.
	(vuint8m8_t): Ditto.
	(vint16mf4_t): Ditto.
	(vuint16mf4_t): Ditto.
	(vint16mf2_t): Ditto.
	(vuint16mf2_t): Ditto.
	(vint16m1_t): Ditto.
	(vuint16m1_t): Ditto.
	(vint16m2_t): Ditto.
	(vuint16m2_t): Ditto.
	(vint16m4_t): Ditto.
	(vuint16m4_t): Ditto.
	(vint16m8_t): Ditto.
	(vuint16m8_t): Ditto.
	(vint32mf2_t): Ditto.
	(vuint32mf2_t): Ditto.
	(vint32m1_t): Ditto.
	(vuint32m1_t): Ditto.
	(vint32m2_t): Ditto.
	(vuint32m2_t): Ditto.
	(vint32m4_t): Ditto.
	(vuint32m4_t): Ditto.
	(vint32m8_t): Ditto.
	(vuint32m8_t): Ditto.
	(vint64m1_t): Ditto.
	(vuint64m1_t): Ditto.
	(vint64m2_t): Ditto.
	(vuint64m2_t): Ditto.
	(vint64m4_t): Ditto.
	(vuint64m4_t): Ditto.
	(vint64m8_t): Ditto.
	(vuint64m8_t): Ditto.
	(vfloat32mf2_t): Ditto.
	(vfloat32m1_t): Ditto.
	(vfloat32m2_t): Ditto.
	(vfloat32m4_t): Ditto.
	(vfloat32m8_t): Ditto.
	(vfloat64m1_t): Ditto.
	(vfloat64m2_t): Ditto.
	(vfloat64m4_t): Ditto.
	(vfloat64m8_t): Ditto.
	(tuple_subpart): Add tuple subpart base type.
	* config/riscv/riscv-vector-builtins.h (struct
	rvv_arg_type_info): Ditto.
	(tuple_type_field): New function.

Signed-off-by: Ju-Zhe Zhong <juzhe.zhong@rivai.ai>
  • Loading branch information
zhongjuzhe authored and Liaoshihua committed Mar 12, 2024
1 parent 7fbb1bb commit 678d6cd
Show file tree
Hide file tree
Showing 7 changed files with 688 additions and 321 deletions.
255 changes: 156 additions & 99 deletions gcc/config/riscv/genrvv-type-indexer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,28 @@ valid_type (unsigned sew, int lmul_log2, bool float_p)
}
}

/* Return true if NF vectors of element width SEW and log2 (LMUL) of
   LMUL_LOG2 form a valid type.  FLOAT_P is forwarded unchanged to the
   three-argument overload.  */
bool
valid_type (unsigned sew, int lmul_log2, unsigned nf, bool float_p)
{
  /* The three-argument overload must accept SEW/LMUL first.  */
  if (!valid_type (sew, lmul_log2, float_p))
    return false;

  /* Only field counts from 1 to 8 are accepted.  */
  if (nf < 1 || nf > 8)
    return false;

  /* For LMUL_LOG2 of 1, 2 and 3 the field count is capped at 4, 2 and 1
     respectively, which is 8 >> LMUL_LOG2.  Every other LMUL_LOG2
     accepts any NF that passed the range check above.  */
  if (lmul_log2 >= 1 && lmul_log2 <= 3)
    return nf <= (8u >> lmul_log2);

  return true;
}

std::string
inttype (unsigned sew, int lmul_log2, bool unsigned_p)
{
Expand All @@ -74,6 +96,23 @@ inttype (unsigned sew, int lmul_log2, bool unsigned_p)
return mode.str ();
}

/* Return the name of the integer vector type with element width SEW,
   log2 (LMUL) of LMUL_LOG2 and NF fields, or "INVALID" if valid_type
   rejects the combination.  UNSIGNED_P selects the "vuint" spelling.
   An "x<NF>" suffix is added only when NF is greater than 1.  */
std::string
inttype (unsigned sew, int lmul_log2, unsigned nf, bool unsigned_p)
{
  if (!valid_type (sew, lmul_log2, nf, /*float_p*/ false))
    return "INVALID";

  std::stringstream name;
  name << "v" << (unsigned_p ? "u" : "") << "int" << sew
       << to_lmul (lmul_log2);
  if (nf > 1)
    name << "x" << nf;
  name << "_t";
  return name.str ();
}

std::string
floattype (unsigned sew, int lmul_log2)
{
Expand All @@ -85,6 +124,20 @@ floattype (unsigned sew, int lmul_log2)
return mode.str ();
}

/* Return the name of the floating-point vector type with element width
   SEW, log2 (LMUL) of LMUL_LOG2 and NF fields, or "INVALID" if
   valid_type rejects the combination.  An "x<NF>" suffix is added only
   when NF is greater than 1.  */
std::string
floattype (unsigned sew, int lmul_log2, unsigned nf)
{
  if (!valid_type (sew, lmul_log2, nf, /*float_p*/ true))
    return "INVALID";

  std::stringstream name;
  name << "vfloat" << sew << to_lmul (lmul_log2);
  if (nf > 1)
    name << "x" << nf;
  name << "_t";
  return name.str ();
}

std::string
maskmode (unsigned sew, int lmul_log2)
{
Expand Down Expand Up @@ -168,24 +221,104 @@ main (int argc, const char **argv)
for (unsigned lmul_log2_offset : {1, 2, 3, 4, 5, 6})
{
unsigned multiple_of_lmul = 1 << lmul_log2_offset;
const char *comma = lmul_log2_offset == 6 ? "" : ",";
fprintf (fp, " /*X%d_INTERPRET*/ INVALID%s\n", multiple_of_lmul,
comma);
fprintf (fp, " /*X%d_INTERPRET*/ INVALID,\n", multiple_of_lmul);
}
fprintf (fp, " /*TUPLE_SUBPART*/ INVALID\n");
fprintf (fp, ")\n");
}

// Build for vint and vuint
for (unsigned sew : {8, 16, 32, 64})
for (int lmul_log2 : {-3, -2, -1, 0, 1, 2, 3})
for (bool unsigned_p : {false, true})
for (unsigned nf : {1, 2, 3, 4, 5, 6, 7, 8})
for (bool unsigned_p : {false, true})
{
if (!valid_type (sew, lmul_log2, nf, /*float_t*/ false))
continue;

fprintf (fp, "DEF_RVV_TYPE_INDEX (\n");
fprintf (fp, " /*VECTOR*/ %s,\n",
inttype (sew, lmul_log2, nf, unsigned_p).c_str ());
fprintf (fp, " /*MASK*/ %s,\n",
maskmode (sew, lmul_log2).c_str ());
fprintf (fp, " /*SIGNED*/ %s,\n",
inttype (sew, lmul_log2, /*unsigned_p*/ false).c_str ());
fprintf (fp, " /*UNSIGNED*/ %s,\n",
inttype (sew, lmul_log2, /*unsigned_p*/ true).c_str ());
for (unsigned eew : {8, 16, 32, 64})
fprintf (fp, " /*EEW%d_INDEX*/ %s,\n", eew,
same_ratio_eew_type (sew, lmul_log2, eew,
/*unsigned_p*/ true, false)
.c_str ());
fprintf (fp, " /*SHIFT*/ %s,\n",
inttype (sew, lmul_log2, /*unsigned_p*/ true).c_str ());
fprintf (fp, " /*DOUBLE_TRUNC*/ %s,\n",
same_ratio_eew_type (sew, lmul_log2, sew / 2, unsigned_p,
false)
.c_str ());
fprintf (fp, " /*QUAD_TRUNC*/ %s,\n",
same_ratio_eew_type (sew, lmul_log2, sew / 4, unsigned_p,
false)
.c_str ());
fprintf (fp, " /*OCT_TRUNC*/ %s,\n",
same_ratio_eew_type (sew, lmul_log2, sew / 8, unsigned_p,
false)
.c_str ());
fprintf (fp, " /*DOUBLE_TRUNC_SCALAR*/ %s,\n",
same_ratio_eew_type (sew, lmul_log2, sew / 2, unsigned_p,
false)
.c_str ());
fprintf (fp, " /*DOUBLE_TRUNC_SIGNED*/ INVALID,\n");
fprintf (fp, " /*DOUBLE_TRUNC_UNSIGNED*/ %s,\n",
same_ratio_eew_type (sew, lmul_log2, sew / 2, true, false)
.c_str ());
if (unsigned_p)
fprintf (fp, " /*DOUBLE_TRUNC_UNSIGNED_SCALAR*/ INVALID,\n");
else
fprintf (fp, " /*DOUBLE_TRUNC_UNSIGNED_SCALAR*/ %s,\n",
same_ratio_eew_type (sew, lmul_log2, sew / 2, true,
false)
.c_str ());
fprintf (fp, " /*DOUBLE_TRUNC_FLOAT*/ %s,\n",
same_ratio_eew_type (sew, lmul_log2, sew / 2, false, true)
.c_str ());
fprintf (fp, " /*FLOAT*/ %s,\n",
floattype (sew, lmul_log2).c_str ());
fprintf (fp, " /*LMUL1*/ %s,\n",
inttype (sew, /*lmul_log2*/ 0, unsigned_p).c_str ());
fprintf (fp, " /*WLMUL1*/ %s,\n",
inttype (sew * 2, /*lmul_log2*/ 0, unsigned_p).c_str ());
for (unsigned eew : {8, 16, 32, 64})
{
if (eew == sew)
fprintf (fp, " /*EEW%d_INTERPRET*/ INVALID,\n", eew);
else
fprintf (fp, " /*EEW%d_INTERPRET*/ %s,\n", eew,
inttype (eew, lmul_log2, unsigned_p).c_str ());
}

for (unsigned lmul_log2_offset : {1, 2, 3, 4, 5, 6})
{
unsigned multiple_of_lmul = 1 << lmul_log2_offset;
fprintf (fp, " /*X%d_VLMUL_EXT*/ %s,\n", multiple_of_lmul,
inttype (sew, lmul_log2 + lmul_log2_offset, unsigned_p)
.c_str ());
}
fprintf (fp, " /*TUPLE_SUBPART*/ %s\n",
inttype (sew, lmul_log2, 1, unsigned_p).c_str ());
fprintf (fp, ")\n");
}
// Build for vfloat
for (unsigned sew : {32, 64})
for (int lmul_log2 : {-3, -2, -1, 0, 1, 2, 3})
for (unsigned nf : {1, 2, 3, 4, 5, 6, 7, 8})
{
if (!valid_type (sew, lmul_log2, /*float_t*/ false))
if (!valid_type (sew, lmul_log2, nf, /*float_t*/ true))
continue;

fprintf (fp, "DEF_RVV_TYPE_INDEX (\n");
fprintf (fp, " /*VECTOR*/ %s,\n",
inttype (sew, lmul_log2, unsigned_p).c_str ());
floattype (sew, lmul_log2, nf).c_str ());
fprintf (fp, " /*MASK*/ %s,\n", maskmode (sew, lmul_log2).c_str ());
fprintf (fp, " /*SIGNED*/ %s,\n",
inttype (sew, lmul_log2, /*unsigned_p*/ false).c_str ());
Expand All @@ -196,118 +329,42 @@ main (int argc, const char **argv)
same_ratio_eew_type (sew, lmul_log2, eew,
/*unsigned_p*/ true, false)
.c_str ());
fprintf (fp, " /*SHIFT*/ %s,\n",
inttype (sew, lmul_log2, /*unsigned_p*/ true).c_str ());
fprintf (fp, " /*SHIFT*/ INVALID,\n");
fprintf (fp, " /*DOUBLE_TRUNC*/ %s,\n",
same_ratio_eew_type (sew, lmul_log2, sew / 2, unsigned_p,
false)
.c_str ());
fprintf (fp, " /*QUAD_TRUNC*/ %s,\n",
same_ratio_eew_type (sew, lmul_log2, sew / 4, unsigned_p,
false)
.c_str ());
fprintf (fp, " /*OCT_TRUNC*/ %s,\n",
same_ratio_eew_type (sew, lmul_log2, sew / 8, unsigned_p,
false)
same_ratio_eew_type (sew, lmul_log2, sew / 2, false, true)
.c_str ());
fprintf (fp, " /*QUAD_TRUNC*/ INVALID,\n");
fprintf (fp, " /*OCT_TRUNC*/ INVALID,\n");
fprintf (fp, " /*DOUBLE_TRUNC_SCALAR*/ %s,\n",
same_ratio_eew_type (sew, lmul_log2, sew / 2, unsigned_p,
false)
same_ratio_eew_type (sew, lmul_log2, sew / 2, false, true)
.c_str ());
fprintf (fp, " /*DOUBLE_TRUNC_SIGNED*/ %s,\n",
same_ratio_eew_type (sew, lmul_log2, sew / 2, false, false)
.c_str ());
fprintf (fp, " /*DOUBLE_TRUNC_SIGNED*/ INVALID,\n");
fprintf (fp, " /*DOUBLE_TRUNC_UNSIGNED*/ %s,\n",
same_ratio_eew_type (sew, lmul_log2, sew / 2, true, false)
.c_str ());
if (unsigned_p)
fprintf (fp, " /*DOUBLE_TRUNC_UNSIGNED_SCALAR*/ INVALID,\n");
else
fprintf (fp, " /*DOUBLE_TRUNC_UNSIGNED_SCALAR*/ %s,\n",
same_ratio_eew_type (sew, lmul_log2, sew / 2, true, false)
.c_str ());
fprintf (fp, " /*DOUBLE_TRUNC_UNSIGNED_SCALAR*/ INVALID,\n");
fprintf (fp, " /*DOUBLE_TRUNC_FLOAT*/ %s,\n",
same_ratio_eew_type (sew, lmul_log2, sew / 2, false, true)
.c_str ());
fprintf (fp, " /*FLOAT*/ %s,\n",
floattype (sew, lmul_log2).c_str ());
fprintf (fp, " /*FLOAT*/ INVALID,\n");
fprintf (fp, " /*LMUL1*/ %s,\n",
inttype (sew, /*lmul_log2*/ 0, unsigned_p).c_str ());
floattype (sew, /*lmul_log2*/ 0).c_str ());
fprintf (fp, " /*WLMUL1*/ %s,\n",
inttype (sew * 2, /*lmul_log2*/ 0, unsigned_p).c_str ());
floattype (sew * 2, /*lmul_log2*/ 0).c_str ());
for (unsigned eew : {8, 16, 32, 64})
{
if (eew == sew)
fprintf (fp, " /*EEW%d_INTERPRET*/ INVALID,\n", eew);
else
fprintf (fp, " /*EEW%d_INTERPRET*/ %s,\n", eew,
inttype (eew, lmul_log2, unsigned_p).c_str ());
}

fprintf (fp, " /*EEW%d_INTERPRET*/ INVALID,\n", eew);
for (unsigned lmul_log2_offset : {1, 2, 3, 4, 5, 6})
{
unsigned multiple_of_lmul = 1 << lmul_log2_offset;
const char *comma = lmul_log2_offset == 6 ? "" : ",";
fprintf (fp, " /*X%d_VLMUL_EXT*/ %s%s\n", multiple_of_lmul,
inttype (sew, lmul_log2 + lmul_log2_offset, unsigned_p)
.c_str (),
comma);
fprintf (fp, " /*X%d_VLMUL_EXT*/ %s,\n", multiple_of_lmul,
floattype (sew, lmul_log2 + lmul_log2_offset).c_str ());
}
fprintf (fp, " /*TUPLE_SUBPART*/ %s\n",
floattype (sew, lmul_log2, 1).c_str ());
fprintf (fp, ")\n");
}
// Build for vfloat
for (unsigned sew : {32, 64})
for (int lmul_log2 : {-3, -2, -1, 0, 1, 2, 3})
{
if (!valid_type (sew, lmul_log2, /*float_t*/ true))
continue;

fprintf (fp, "DEF_RVV_TYPE_INDEX (\n");
fprintf (fp, " /*VECTOR*/ %s,\n", floattype (sew, lmul_log2).c_str ());
fprintf (fp, " /*MASK*/ %s,\n", maskmode (sew, lmul_log2).c_str ());
fprintf (fp, " /*SIGNED*/ %s,\n",
inttype (sew, lmul_log2, /*unsigned_p*/ false).c_str ());
fprintf (fp, " /*UNSIGNED*/ %s,\n",
inttype (sew, lmul_log2, /*unsigned_p*/ true).c_str ());
for (unsigned eew : {8, 16, 32, 64})
fprintf (fp, " /*EEW%d_INDEX*/ %s,\n", eew,
same_ratio_eew_type (sew, lmul_log2, eew,
/*unsigned_p*/ true, false)
.c_str ());
fprintf (fp, " /*SHIFT*/ INVALID,\n");
fprintf (
fp, " /*DOUBLE_TRUNC*/ %s,\n",
same_ratio_eew_type (sew, lmul_log2, sew / 2, false, true).c_str ());
fprintf (fp, " /*QUAD_TRUNC*/ INVALID,\n");
fprintf (fp, " /*OCT_TRUNC*/ INVALID,\n");
fprintf (
fp, " /*DOUBLE_TRUNC_SCALAR*/ %s,\n",
same_ratio_eew_type (sew, lmul_log2, sew / 2, false, true).c_str ());
fprintf (
fp, " /*DOUBLE_TRUNC_SIGNED*/ %s,\n",
same_ratio_eew_type (sew, lmul_log2, sew / 2, false, false).c_str ());
fprintf (
fp, " /*DOUBLE_TRUNC_UNSIGNED*/ %s,\n",
same_ratio_eew_type (sew, lmul_log2, sew / 2, true, false).c_str ());
fprintf (fp, " /*DOUBLE_TRUNC_UNSIGNED_SCALAR*/ INVALID,\n");
fprintf (
fp, " /*DOUBLE_TRUNC_FLOAT*/ %s,\n",
same_ratio_eew_type (sew, lmul_log2, sew / 2, false, true).c_str ());
fprintf (fp, " /*FLOAT*/ INVALID,\n");
fprintf (fp, " /*LMUL1*/ %s,\n",
floattype (sew, /*lmul_log2*/ 0).c_str ());
fprintf (fp, " /*WLMUL1*/ %s,\n",
floattype (sew * 2, /*lmul_log2*/ 0).c_str ());
for (unsigned eew : {8, 16, 32, 64})
fprintf (fp, " /*EEW%d_INTERPRET*/ INVALID,\n", eew);
for (unsigned lmul_log2_offset : {1, 2, 3, 4, 5, 6})
{
unsigned multiple_of_lmul = 1 << lmul_log2_offset;
const char *comma = lmul_log2_offset == 6 ? "" : ",";
fprintf (fp, " /*X%d_VLMUL_EXT*/ %s%s\n", multiple_of_lmul,
floattype (sew, lmul_log2 + lmul_log2_offset).c_str (),
comma);
}
fprintf (fp, ")\n");
}

return 0;
}
49 changes: 49 additions & 0 deletions gcc/config/riscv/riscv-vector-builtins-bases.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1548,9 +1548,40 @@ class vset : public function_base
public:
bool apply_vl_p () const override { return false; }

gimple *fold (gimple_folder &f) const override
{
  /* Only fold when the first call argument has a tuple mode.  LMUL > 1
     non-tuple vector types are not structures, so __val[index] cannot
     be used to set the subpart; return NULL to leave the call alone.  */
  tree src_tuple = gimple_call_arg (f.call, 0);
  if (!riscv_v_ext_tuple_mode_p (TYPE_MODE (TREE_TYPE (src_tuple))))
    return NULL;

  tree elt_index = gimple_call_arg (f.call, 1);
  tree new_elt = gimple_call_arg (f.call, 2);

  /* The call becomes two statements: a copy of the whole tuple into the
     call result, then a store of the single replaced vector.  The fold
     routines expect the replacement statement to have the same lhs as
     the original call, so the copy is the statement handed back and the
     element store is inserted after it.  */
  gassign *tuple_copy = gimple_build_assign (unshare_expr (f.lhs), src_tuple);

  /* Build the reference to the individual vector inside the result.  */
  tree val_field = tuple_type_field (TREE_TYPE (f.lhs));
  tree val_array = build3 (COMPONENT_REF, TREE_TYPE (val_field), f.lhs,
			   val_field, NULL_TREE);
  tree elt_ref = build4 (ARRAY_REF, TREE_TYPE (new_elt), val_array,
			 elt_index, NULL_TREE, NULL_TREE);
  gassign *elt_store = gimple_build_assign (elt_ref, new_elt);
  gsi_insert_after (f.gsi, elt_store, GSI_SAME_STMT);

  return tuple_copy;
}

rtx expand (function_expander &e) const override
{
rtx dest = expand_normal (CALL_EXPR_ARG (e.exp, 0));
gcc_assert (riscv_v_ext_vector_mode_p (GET_MODE (dest)));
rtx index = expand_normal (CALL_EXPR_ARG (e.exp, 1));
rtx src = expand_normal (CALL_EXPR_ARG (e.exp, 2));
poly_int64 offset = INTVAL (index) * GET_MODE_SIZE (GET_MODE (src));
Expand All @@ -1567,9 +1598,27 @@ class vget : public function_base
public:
bool apply_vl_p () const override { return false; }

gimple *fold (gimple_folder &f) const override
{
  /* Fold into a normal gimple component access, but only when the first
     call argument has a tuple mode.  LMUL > 1 non-tuple vector types are
     not structures, so __val[index] cannot be used to get the subpart;
     return NULL to leave the call alone.  */
  tree src_tuple = gimple_call_arg (f.call, 0);
  tree src_type = TREE_TYPE (src_tuple);
  if (!riscv_v_ext_tuple_mode_p (TYPE_MODE (src_type)))
    return NULL;

  tree elt_index = gimple_call_arg (f.call, 1);

  /* Build the element reference and assign it to the call result.  */
  tree val_field = tuple_type_field (src_type);
  tree val_array = build3 (COMPONENT_REF, TREE_TYPE (val_field), src_tuple,
			   val_field, NULL_TREE);
  tree elt_ref = build4 (ARRAY_REF, TREE_TYPE (f.lhs), val_array, elt_index,
			 NULL_TREE, NULL_TREE);
  return gimple_build_assign (f.lhs, elt_ref);
}

rtx expand (function_expander &e) const override
{
rtx src = expand_normal (CALL_EXPR_ARG (e.exp, 0));
gcc_assert (riscv_v_ext_vector_mode_p (GET_MODE (src)));
rtx index = expand_normal (CALL_EXPR_ARG (e.exp, 1));
poly_int64 offset = INTVAL (index) * GET_MODE_SIZE (GET_MODE (e.target));
rtx subreg
Expand Down
4 changes: 4 additions & 0 deletions gcc/config/riscv/riscv-vector-builtins-functions.def
Original file line number Diff line number Diff line change
Expand Up @@ -533,4 +533,8 @@ DEF_RVV_FUNCTION (vget, vget, none_preds, all_v_vget_lmul2_x2_ops)
DEF_RVV_FUNCTION (vget, vget, none_preds, all_v_vget_lmul2_x4_ops)
DEF_RVV_FUNCTION (vget, vget, none_preds, all_v_vget_lmul4_x2_ops)

// Tuple types
DEF_RVV_FUNCTION (vset, vset, none_preds, all_v_vset_tuple_ops)
DEF_RVV_FUNCTION (vget, vget, none_preds, all_v_vget_tuple_ops)

#undef DEF_RVV_FUNCTION
Loading

0 comments on commit 678d6cd

Please sign in to comment.