Skip to content

Commit

Permalink
fixup! handle struct and minor fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
4vtomat committed Jan 6, 2025
1 parent 6b76dbc commit 68d2338
Show file tree
Hide file tree
Showing 4 changed files with 384 additions and 19 deletions.
27 changes: 11 additions & 16 deletions clang/lib/CodeGen/CGCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3239,24 +3239,19 @@ void CodeGenFunction::EmitFunctionProlog(const CGFunctionInfo &FI,
}
}

llvm::StructType *STy =
dyn_cast<llvm::StructType>(ArgI.getCoerceToType());
if (ArgI.isDirect() && !ArgI.getCanBeFlattened() && STy &&
STy->getNumElements() > 1) {
[[maybe_unused]] llvm::TypeSize StructSize =
CGM.getDataLayout().getTypeAllocSize(STy);
[[maybe_unused]] llvm::TypeSize PtrElementSize =
CGM.getDataLayout().getTypeAllocSize(ConvertTypeForMem(Ty));
if (STy->containsHomogeneousScalableVectorTypes()) {
assert(StructSize == PtrElementSize &&
"Only allow non-fractional movement of structure with"
"homogeneous scalable vector type");

ArgVals.push_back(ParamValue::forDirect(AI));
break;
}
// Struct of fixed-length vectors and struct of array of fixed-length
// vector in VLS calling convention are coerced to vector tuple
// type(represented as TargetExtType) and scalable vector type
// respectively, they're no longer handled as struct.
if (ArgI.isDirect() && isa<llvm::StructType>(ConvertType(Ty)) &&
(isa<llvm::TargetExtType>(ArgI.getCoerceToType()) ||
isa<llvm::ScalableVectorType>(ArgI.getCoerceToType()))) {
ArgVals.push_back(ParamValue::forDirect(AI));
break;
}

llvm::StructType *STy =
dyn_cast<llvm::StructType>(ArgI.getCoerceToType());
Address Alloca = CreateMemTemp(Ty, getContext().getDeclAlign(Arg),
Arg->getName());

Expand Down
160 changes: 157 additions & 3 deletions clang/lib/CodeGen/Targets/RISCV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class RISCVABIInfo : public DefaultABIInfo {
llvm::Type *&Field2Ty,
CharUnits &Field2Off) const;

bool detectVLSCCEligibleStruct(QualType Ty, unsigned ABIVLen,
llvm::Type *&VLSType) const;

public:
RISCVABIInfo(CodeGen::CodeGenTypes &CGT, unsigned XLen, unsigned FLen,
bool EABI)
Expand Down Expand Up @@ -361,6 +364,149 @@ ABIArgInfo RISCVABIInfo::coerceAndExpandFPCCEligibleStruct(
return ABIArgInfo::getCoerceAndExpand(CoerceToType, UnpaddedCoerceToType);
}

// Decide whether the struct type \p Ty is eligible to be passed directly
// under the RISC-V VLS (vector-length-specific) calling convention, and if
// so compute the LLVM type it must be coerced to.
//
// \param Ty      Clang type of the argument being classified.
// \param ABIVLen ABI_VLEN selected by the riscv_vls_cc attribute; 1 means
//                the attribute is absent (see early-out below).
// \param VLSType [out] On success, set to the scalable-vector or
//                riscv.vector.tuple type to coerce to; untouched on failure.
// \return true iff the struct can be passed directly under VLS CC.
bool RISCVABIInfo::detectVLSCCEligibleStruct(QualType Ty, unsigned ABIVLen,
                                             llvm::Type *&VLSType) const {
  // No riscv_vls_cc attribute.
  if (ABIVLen == 1)
    return false;

  // Legal struct for VLS calling convention should fulfill following rules:
  // 1. Struct element should be either "homogeneous fixed-length vectors" or "a
  // fixed-length vector array".
  // 2. Number of struct elements or array elements should be power of 2.
  // 3. Total number of vector registers needed should not exceed 8.
  //
  // Examples: Assume ABI_VLEN = 128.
  // These are legal structs:
  // a. Structs with 1, 2, 4 or 8 "same" fixed-length vectors, e.g.
  // struct {
  //   __attribute__((vector_size(16))) int a;
  //   __attribute__((vector_size(16))) int b;
  // }
  //
  // b. Structs with "single" fixed-length vector array with length 1, 2, 4
  // or 8, e.g.
  // struct {
  //   __attribute__((vector_size(16))) int a[2];
  // }
  // These are illegal structs:
  // a. Structs with 3 fixed-length vectors, e.g.
  // struct {
  //   __attribute__((vector_size(16))) int a;
  //   __attribute__((vector_size(16))) int b;
  //   __attribute__((vector_size(16))) int c;
  // }
  //
  // b. Structs with "multiple" fixed-length vector array, e.g.
  // struct {
  //   __attribute__((vector_size(16))) int a[2];
  //   __attribute__((vector_size(16))) int b[2];
  // }
  //
  // c. Vector registers needed exceeds 8, e.g.
  // struct {
  //   // Registers needed for single fixed-length element:
  //   // 64 * 8 / ABI_VLEN = 4
  //   __attribute__((vector_size(64))) int a;
  //   __attribute__((vector_size(64))) int b;
  //   __attribute__((vector_size(64))) int c;
  //   __attribute__((vector_size(64))) int d;
  // }
  //
  // Struct of 1 fixed-length vector is passed as a scalable vector.
  // Struct of >1 fixed-length vectors are passed as vector tuple.
  // Struct of 1 array of fixed-length vectors is passed as a scalable vector.
  // Otherwise, pass the struct indirectly.

  if (llvm::StructType *STy = dyn_cast<llvm::StructType>(CGT.ConvertType(Ty))) {
    // Rule 2: element count must be a power of 2 no larger than 8.
    int NumElts = STy->getStructNumElements();
    if (NumElts > 8 || !llvm::isPowerOf2_32(NumElts))
      return false;

    // Rule 1 (homogeneity): all elements must share the first element's type.
    auto *FirstEltTy = STy->getElementType(0);
    if (!STy->containsHomogeneousTypes())
      return false;

    // Check structure of fixed-length vectors and turn them into vector tuple
    // type if legal.
    if (auto *FixedVecTy = dyn_cast<llvm::FixedVectorType>(FirstEltTy)) {
      if (NumElts == 1) {
        // Handle single fixed-length vector: coerce to one scalable vector
        // whose minimum element count scales the fixed count by
        // RVVBitsPerBlock / ABIVLen.
        VLSType = llvm::ScalableVectorType::get(
            FixedVecTy->getElementType(),
            llvm::divideCeil(FixedVecTy->getNumElements() *
                                 llvm::RISCV::RVVBitsPerBlock,
                             ABIVLen));
        // Rule 3: check registers needed <= 8.
        return llvm::divideCeil(
                   FixedVecTy->getNumElements() *
                       FixedVecTy->getElementType()->getScalarSizeInBits(),
                   ABIVLen) <= 8;
      }
      // Multiple homogeneous fixed-length vectors: coerce to a
      // riscv.vector.tuple whose per-field storage is an i8 scalable vector.
      // LMUL
      // = fixed-length vector size / ABIVLen
      // = 8 * I8EltCount / RVVBitsPerBlock
      // =>
      // I8EltCount
      // = (fixed-length vector size * RVVBitsPerBlock) / (ABIVLen * 8)
      unsigned I8EltCount = llvm::divideCeil(
          FixedVecTy->getNumElements() *
              FixedVecTy->getElementType()->getScalarSizeInBits() *
              llvm::RISCV::RVVBitsPerBlock,
          ABIVLen * 8);
      VLSType = llvm::TargetExtType::get(
          getVMContext(), "riscv.vector.tuple",
          llvm::ScalableVectorType::get(llvm::Type::getInt8Ty(getVMContext()),
                                        I8EltCount),
          NumElts);
      // Rule 3: check registers needed <= 8 (registers per field, rounded up,
      // times the number of tuple fields).
      return NumElts *
                 llvm::divideCeil(
                     FixedVecTy->getNumElements() *
                         FixedVecTy->getElementType()->getScalarSizeInBits(),
                     ABIVLen) <=
             8;
    }

    // If elements are not fixed-length vectors, it should be an array.
    // Rule 1 only allows a "single" array element in that case.
    if (NumElts != 1)
      return false;

    // Check array of fixed-length vector and turn it into scalable vector type
    // if legal.
    if (auto *ArrTy = dyn_cast<llvm::ArrayType>(FirstEltTy)) {
      // Rule 2 applied to the array length.
      int NumArrElt = ArrTy->getNumElements();
      if (NumArrElt > 8 || !llvm::isPowerOf2_32(NumArrElt))
        return false;

      auto *ArrEltTy = dyn_cast<llvm::FixedVectorType>(ArrTy->getElementType());
      if (!ArrEltTy)
        return false;

      // The whole array is flattened into one scalable vector.
      // LMUL
      // = NumArrElt * fixed-length vector size / ABIVLen
      // = fixed-length vector elt size * ScalVecNumElts / RVVBitsPerBlock
      // =>
      // ScalVecNumElts
      // = (NumArrElt * fixed-length vector size * RVVBitsPerBlock) /
      //   (ABIVLen * fixed-length vector elt size)
      // = NumArrElt * num fixed-length vector elt * RVVBitsPerBlock /
      //   ABIVLen
      unsigned ScalVecNumElts = llvm::divideCeil(
          NumArrElt * ArrEltTy->getNumElements() * llvm::RISCV::RVVBitsPerBlock,
          ABIVLen);
      VLSType = llvm::ScalableVectorType::get(ArrEltTy->getElementType(),
                                              ScalVecNumElts);
      // Rule 3: check registers needed <= 8 (here expressed as LMUL of the
      // flattened scalable vector).
      return llvm::divideCeil(
                 ScalVecNumElts *
                     ArrEltTy->getElementType()->getScalarSizeInBits(),
                 llvm::RISCV::RVVBitsPerBlock) <= 8;
    }
  }
  // Not a struct type (or an unsupported element shape): not VLS-eligible.
  return false;
}

// Fixed-length RVV vectors are represented as scalable vectors in function
// args/return and must be coerced from fixed vectors.
ABIArgInfo RISCVABIInfo::coerceVLSVector(QualType Ty,
Expand Down Expand Up @@ -410,11 +556,13 @@ ABIArgInfo RISCVABIInfo::coerceVLSVector(QualType Ty,
(EltType->isBFloatTy() && !TI.hasFeature("zvfbfmin")) ||
(EltType->isFloatTy() && !TI.hasFeature("zve32f")) ||
(EltType->isDoubleTy() && !TI.hasFeature("zve64d")) ||
(EltType->isIntegerTy(64) && !TI.hasFeature("zve64x")) ||
EltType->isIntegerTy(128)) {
EltType->isIntegerTy(128))
EltType =
llvm::Type::getIntNTy(getVMContext(), EltType->getScalarSizeInBits());
}

// Check registers needed <= 8.
if ((EltType->getScalarSizeInBits() * NumElts / ABIVLen) > 8)
return getNaturalAlignIndirect(Ty, /*ByVal=*/false);

// Generic vector
// The number of elements needs to be at least 1.
Expand Down Expand Up @@ -485,6 +633,12 @@ ABIArgInfo RISCVABIInfo::classifyArgumentType(QualType Ty, bool IsFixed,
}
}

if (IsFixed && Ty->isStructureOrClassType()) {
llvm::Type *VLSType = nullptr;
if (detectVLSCCEligibleStruct(Ty, ABIVLen, VLSType))
return ABIArgInfo::getDirect(VLSType);
}

uint64_t NeededAlign = getContext().getTypeAlign(Ty);
// Determine the number of GPRs needed to pass the current argument
// according to the ABI. 2*XLen-aligned varargs are passed in "aligned"
Expand Down
108 changes: 108 additions & 0 deletions clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.c
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,111 @@ void __attribute__((riscv_vls_cc(1024))) test_vls_least_element(__attribute__((v

// CHECK-LLVM: define dso_local riscv_vls_cc void @test_vls_least_element_c23(<vscale x 1 x i32> noundef %arg.coerce)
[[riscv::vls_cc(1024)]] void test_vls_least_element_c23(__attribute__((vector_size(8))) int arg) {}


// Test structs for the VLS calling convention. Naming: st_i32x<N> is a struct
// holding vectors of N 32-bit ints; _arr<K> holds one array of K such
// vectors; x<M> suffix means M separate vector members.

// Single fixed-length vector member (expected: coerced to a scalable vector).
struct st_i32x4{
  __attribute__((vector_size(16))) int i32;
};

// Single array of length 1 (expected: treated like a single vector).
struct st_i32x4_arr1{
  __attribute__((vector_size(16))) int i32[1];
};

// Single array of length 4 (power of 2 — expected: flattened scalable vector).
struct st_i32x4_arr4{
  __attribute__((vector_size(16))) int i32[4];
};

// Single array of length 8 (maximum legal array length).
struct st_i32x4_arr8{
  __attribute__((vector_size(16))) int i32[8];
};


// Two homogeneous vector members (expected: riscv.vector.tuple of 2).
struct st_i32x4x2{
  __attribute__((vector_size(16))) int i32_1;
  __attribute__((vector_size(16))) int i32_2;
};

// Two wider (32-byte) vector members.
struct st_i32x8x2{
  __attribute__((vector_size(32))) int i32_1;
  __attribute__((vector_size(32))) int i32_2;
};

// Two 256-byte vectors: register demand exceeds 8, expected indirect.
struct st_i32x64x2{
  __attribute__((vector_size(256))) int i32_1;
  __attribute__((vector_size(256))) int i32_2;
};

// Eight homogeneous vector members (maximum legal member count).
struct st_i32x4x8{
  __attribute__((vector_size(16))) int i32_1;
  __attribute__((vector_size(16))) int i32_2;
  __attribute__((vector_size(16))) int i32_3;
  __attribute__((vector_size(16))) int i32_4;
  __attribute__((vector_size(16))) int i32_5;
  __attribute__((vector_size(16))) int i32_6;
  __attribute__((vector_size(16))) int i32_7;
  __attribute__((vector_size(16))) int i32_8;
};

// Nine members: not a power of 2, expected to fall back to indirect passing.
struct st_i32x4x9{
  __attribute__((vector_size(16))) int i32_1;
  __attribute__((vector_size(16))) int i32_2;
  __attribute__((vector_size(16))) int i32_3;
  __attribute__((vector_size(16))) int i32_4;
  __attribute__((vector_size(16))) int i32_5;
  __attribute__((vector_size(16))) int i32_6;
  __attribute__((vector_size(16))) int i32_7;
  __attribute__((vector_size(16))) int i32_8;
  __attribute__((vector_size(16))) int i32_9;
};

typedef int __attribute__((vector_size(256))) int32x64_t;

// Bare (non-struct) vector that is too large for the default ABI_VLEN.
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_too_large(ptr noundef %0)
void __attribute__((riscv_vls_cc)) test_too_large(int32x64_t arg) {}
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_too_large_256(<vscale x 16 x i32> noundef %arg.coerce)
void __attribute__((riscv_vls_cc(256))) test_too_large_256(int32x64_t arg) {}

// Single-vector struct: passed as a scalable vector.
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4(<vscale x 2 x i32> %arg)
void __attribute__((riscv_vls_cc)) test_st_i32x4(struct st_i32x4 arg) {}
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4_256(<vscale x 1 x i32> %arg)
void __attribute__((riscv_vls_cc(256))) test_st_i32x4_256(struct st_i32x4 arg) {}

// Length-1 array struct: same lowering as the single-vector struct above.
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4_arr1(<vscale x 2 x i32> %arg)
void __attribute__((riscv_vls_cc)) test_st_i32x4_arr1(struct st_i32x4_arr1 arg) {}
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4_arr1_256(<vscale x 1 x i32> %arg)
void __attribute__((riscv_vls_cc(256))) test_st_i32x4_arr1_256(struct st_i32x4_arr1 arg) {}

// Array structs: the array is flattened into one wider scalable vector.
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4_arr4(<vscale x 8 x i32> %arg)
void __attribute__((riscv_vls_cc)) test_st_i32x4_arr4(struct st_i32x4_arr4 arg) {}
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4_arr4_256(<vscale x 4 x i32> %arg)
void __attribute__((riscv_vls_cc(256))) test_st_i32x4_arr4_256(struct st_i32x4_arr4 arg) {}

// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4_arr8(<vscale x 16 x i32> %arg)
void __attribute__((riscv_vls_cc)) test_st_i32x4_arr8(struct st_i32x4_arr8 arg) {}
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4_arr8_256(<vscale x 8 x i32> %arg)
void __attribute__((riscv_vls_cc(256))) test_st_i32x4_arr8_256(struct st_i32x4_arr8 arg) {}

// Multi-member structs: passed as riscv.vector.tuple target extension types.
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4x2(target("riscv.vector.tuple", <vscale x 8 x i8>, 2) %arg)
void __attribute__((riscv_vls_cc)) test_st_i32x4x2(struct st_i32x4x2 arg) {}
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4x2_256(target("riscv.vector.tuple", <vscale x 4 x i8>, 2) %arg)
void __attribute__((riscv_vls_cc(256))) test_st_i32x4x2_256(struct st_i32x4x2 arg) {}

// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x8x2(target("riscv.vector.tuple", <vscale x 16 x i8>, 2) %arg)
void __attribute__((riscv_vls_cc)) test_st_i32x8x2(struct st_i32x8x2 arg) {}
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x8x2_256(target("riscv.vector.tuple", <vscale x 8 x i8>, 2) %arg)
void __attribute__((riscv_vls_cc(256))) test_st_i32x8x2_256(struct st_i32x8x2 arg) {}

// Register demand > 8: falls back to indirect (ptr) passing.
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x64x2(ptr noundef %arg)
void __attribute__((riscv_vls_cc)) test_st_i32x64x2(struct st_i32x64x2 arg) {}
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x64x2_256(ptr noundef %arg)
void __attribute__((riscv_vls_cc(256))) test_st_i32x64x2_256(struct st_i32x64x2 arg) {}

// Eight members: still legal, tuple with 8 fields.
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4x8(target("riscv.vector.tuple", <vscale x 8 x i8>, 8) %arg)
void __attribute__((riscv_vls_cc)) test_st_i32x4x8(struct st_i32x4x8 arg) {}
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4x8_256(target("riscv.vector.tuple", <vscale x 4 x i8>, 8) %arg)
void __attribute__((riscv_vls_cc(256))) test_st_i32x4x8_256(struct st_i32x4x8 arg) {}

// Nine members: not a power of 2, passed indirectly.
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4x9(ptr noundef %arg)
void __attribute__((riscv_vls_cc)) test_st_i32x4x9(struct st_i32x4x9 arg) {}
// CHECK-LLVM: define dso_local riscv_vls_cc void @test_st_i32x4x9_256(ptr noundef %arg)
void __attribute__((riscv_vls_cc(256))) test_st_i32x4x9_256(struct st_i32x4x9 arg) {}
Loading

0 comments on commit 68d2338

Please sign in to comment.