Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISCV][GISEL] Add support for lowerFormalArguments that contain scalable vector types #70882

Merged
merged 10 commits into from
Nov 14, 2023
3 changes: 2 additions & 1 deletion llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ static void buildCopyFromRegs(MachineIRBuilder &B, ArrayRef<Register> OrigRegs,
if (PartLLT.isVector() == LLTy.isVector() &&
PartLLT.getScalarSizeInBits() > LLTy.getScalarSizeInBits() &&
(!PartLLT.isVector() ||
PartLLT.getNumElements() == LLTy.getNumElements()) &&
PartLLT.getElementCount() == LLTy.getElementCount()) &&
OrigRegs.size() == 1 && Regs.size() == 1) {
Register SrcReg = Regs[0];

Expand Down Expand Up @@ -406,6 +406,7 @@ static void buildCopyFromRegs(MachineIRBuilder &B, ArrayRef<Register> OrigRegs,
// If PartLLT is a mismatched vector in both number of elements and element
// size, e.g. PartLLT == v2s64 and LLTy is v3s32, then first coerce it to
// have the same elt type, i.e. v4s32.
// TODO: Extend this coercion to element multiples other than just 2.
if (PartLLT.getSizeInBits() > LLTy.getSizeInBits() &&
PartLLT.getScalarSizeInBits() == LLTy.getScalarSizeInBits() * 2 &&
Regs.size() == 1) {
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1065,16 +1065,16 @@ void MachineIRBuilder::validateTruncExt(const LLT DstTy, const LLT SrcTy,
#ifndef NDEBUG
if (DstTy.isVector()) {
assert(SrcTy.isVector() && "mismatched cast between vector and non-vector");
assert(SrcTy.getNumElements() == DstTy.getNumElements() &&
assert(SrcTy.getElementCount() == DstTy.getElementCount() &&
"different number of elements in a trunc/ext");
} else
assert(DstTy.isScalar() && SrcTy.isScalar() && "invalid extend/trunc");

if (IsExtend)
assert(DstTy.getSizeInBits() > SrcTy.getSizeInBits() &&
assert(TypeSize::isKnownGT(DstTy.getSizeInBits(), SrcTy.getSizeInBits()) &&
"invalid narrowing extend");
else
assert(DstTy.getSizeInBits() < SrcTy.getSizeInBits() &&
assert(TypeSize::isKnownLT(DstTy.getSizeInBits(), SrcTy.getSizeInBits()) &&
"invalid widening trunc");
#endif
}
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/LowLevelType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ using namespace llvm;

LLT::LLT(MVT VT) {
if (VT.isVector()) {
bool asVector = VT.getVectorMinNumElements() > 1;
bool asVector = VT.getVectorMinNumElements() > 1 || VT.isScalableVector();
init(/*IsPointer=*/false, asVector, /*IsScalar=*/!asVector,
VT.getVectorElementCount(), VT.getVectorElementType().getSizeInBits(),
/*AddressSpace=*/0);
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/MachineVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ bool MachineVerifier::verifyVectorElementMatch(LLT Ty0, LLT Ty1,
return false;
}

if (Ty0.isVector() && Ty0.getNumElements() != Ty1.getNumElements()) {
if (Ty0.isVector() && Ty0.getElementCount() != Ty1.getElementCount()) {
report("operand types must preserve number of vector elements", MI);
return false;
}
Expand Down
37 changes: 35 additions & 2 deletions llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "RISCVCallLowering.h"
#include "RISCVISelLowering.h"
#include "RISCVMachineFunctionInfo.h"
#include "RISCVSubtarget.h"
#include "llvm/CodeGen/Analysis.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
Expand Down Expand Up @@ -185,6 +186,9 @@ struct RISCVIncomingValueAssigner : public CallLowering::IncomingValueAssigner {
const DataLayout &DL = MF.getDataLayout();
const RISCVSubtarget &Subtarget = MF.getSubtarget<RISCVSubtarget>();

if (LocVT.isScalableVector())
MF.getInfo<RISCVMachineFunctionInfo>()->setIsVectorCall();

if (RISCVAssignFn(DL, Subtarget.getTargetABI(), ValNo, ValVT, LocVT,
LocInfo, Flags, State, /*IsFixed=*/true, IsRet, Info.Ty,
*Subtarget.getTargetLowering(),
Expand Down Expand Up @@ -301,8 +305,31 @@ struct RISCVCallReturnHandler : public RISCVIncomingValueHandler {
RISCVCallLowering::RISCVCallLowering(const RISCVTargetLowering &TLI)
: CallLowering(&TLI) {}

/// Check whether \p EltTy may be used as the element type of a scalable
/// vector that is legal for lowering on \p Subtarget.
static bool isLegalElementTypeForRVV(Type *EltTy,
                                     const RISCVSubtarget &Subtarget) {
  // Pointer elements: a 64-bit subtarget must report hasVInstructionsI64();
  // any other subtarget accepts them unconditionally.
  if (EltTy->isPointerTy())
    return !Subtarget.is64Bit() || Subtarget.hasVInstructionsI64();

  // i1, i8, i16 and i32 elements are accepted without a further subtarget
  // query.
  for (unsigned Width : {1u, 8u, 16u, 32u})
    if (EltTy->isIntegerTy(Width))
      return true;

  // Each remaining element type is accepted only when the matching subtarget
  // query succeeds; every other type is rejected.
  return (EltTy->isIntegerTy(64) && Subtarget.hasVInstructionsI64()) ||
         (EltTy->isHalfTy() && Subtarget.hasVInstructionsF16()) ||
         (EltTy->isBFloatTy() && Subtarget.hasVInstructionsBF16()) ||
         (EltTy->isFloatTy() && Subtarget.hasVInstructionsF32()) ||
         (EltTy->isDoubleTy() && Subtarget.hasVInstructionsF64());
}

// TODO: Support all argument types.
static bool isSupportedArgumentType(Type *T, const RISCVSubtarget &Subtarget) {
// TODO: Remove IsLowerArgs argument by adding support for vectors in lowerCall.
static bool isSupportedArgumentType(Type *T, const RISCVSubtarget &Subtarget,
bool IsLowerArgs = false) {
// TODO: Integers larger than 2*XLen are passed indirectly which is not
// supported yet.
if (T->isIntegerTy())
Expand All @@ -311,6 +338,11 @@ static bool isSupportedArgumentType(Type *T, const RISCVSubtarget &Subtarget) {
return true;
if (T->isPointerTy())
return true;
// TODO: Support fixed vector types.
if (IsLowerArgs && T->isVectorTy() && Subtarget.hasVInstructions() &&
T->isScalableTy() &&
isLegalElementTypeForRVV(T->getScalarType(), Subtarget))
return true;
return false;
}

Expand Down Expand Up @@ -398,7 +430,8 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
const RISCVSubtarget &Subtarget =
MIRBuilder.getMF().getSubtarget<RISCVSubtarget>();
for (auto &Arg : F.args()) {
if (!isSupportedArgumentType(Arg.getType(), Subtarget))
if (!isSupportedArgumentType(Arg.getType(), Subtarget,
/*IsLowerArgs=*/true))
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ declare <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
<vscale x 1 x i8>,
i64)

; FALLBACK-WITH-REPORT-ERR: remark: <unknown>:0:0: unable to lower arguments{{.*}}scalable_arg
; FALLBACK_WITH_REPORT_ERR: <unknown>:0:0: unable to translate instruction: call:
; FALLBACK-WITH-REPORT-OUT-LABEL: scalable_arg
define <vscale x 1 x i8> @scalable_arg(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i64 %2) nounwind {
entry:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
; RUN: not --crash llc -mtriple=riscv32 -mattr=+v -global-isel -stop-after=irtranslator \
; RUN: -verify-machineinstrs < %s 2>&1 | FileCheck %s
; RUN: not --crash llc -mtriple=riscv64 -mattr=+v -global-isel -stop-after=irtranslator \
; RUN: -verify-machineinstrs < %s 2>&1 | FileCheck %s

; The purpose of this test is to show that the compiler throws an error when
; there is no support for bf16 vectors. If the compiler did not throw an error,
; it would try to scalarize the argument to an s32, which may drop elements.
define void @test_args_nxv1bf16(<vscale x 1 x bfloat> %a) {
entry:
ret void
}

; CHECK: LLVM ERROR: unable to lower arguments: ptr (in function: test_args_nxv1bf16)


Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
; RUN: not --crash llc -mtriple=riscv32 -mattr=+v -global-isel -stop-after=irtranslator \
; RUN: -verify-machineinstrs < %s 2>&1 | FileCheck %s
; RUN: not --crash llc -mtriple=riscv64 -mattr=+v -global-isel -stop-after=irtranslator \
; RUN: -verify-machineinstrs < %s 2>&1 | FileCheck %s

; The purpose of this test is to show that the compiler throws an error when
; there is no support for f16 vectors. If the compiler did not throw an error,
; it would try to scalarize the argument to an s32, which may drop elements.
define void @test_args_nxv1f16(<vscale x 1 x half> %a) {
entry:
ret void
}

; CHECK: LLVM ERROR: unable to lower arguments: ptr (in function: test_args_nxv1f16)


Loading