diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h index 44aee4e4a5b46e..ec141c249fb21e 100644 --- a/llvm/include/llvm/SandboxIR/Type.h +++ b/llvm/include/llvm/SandboxIR/Type.h @@ -25,6 +25,7 @@ class Context; // Forward declare friend classes for MSVC. class PointerType; class VectorType; +class FixedVectorType; class IntegerType; class FunctionType; class ArrayType; @@ -41,6 +42,7 @@ class Type { friend class ArrayType; // For LLVMTy. friend class StructType; // For LLVMTy. friend class VectorType; // For LLVMTy. + friend class FixedVectorType; // For LLVMTy. friend class PointerType; // For LLVMTy. friend class FunctionType; // For LLVMTy. friend class IntegerType; // For LLVMTy. @@ -344,6 +346,50 @@ class VectorType : public Type { } }; +class FixedVectorType : public VectorType { +public: + static FixedVectorType *get(Type *ElementType, unsigned NumElts); + + static FixedVectorType *get(Type *ElementType, const FixedVectorType *FVTy) { + return get(ElementType, FVTy->getNumElements()); + } + + static FixedVectorType *getInteger(FixedVectorType *VTy) { + return cast(VectorType::getInteger(VTy)); + } + + static FixedVectorType *getExtendedElementVectorType(FixedVectorType *VTy) { + return cast(VectorType::getExtendedElementVectorType(VTy)); + } + + static FixedVectorType *getTruncatedElementVectorType(FixedVectorType *VTy) { + return cast( + VectorType::getTruncatedElementVectorType(VTy)); + } + + static FixedVectorType *getSubdividedVectorType(FixedVectorType *VTy, + int NumSubdivs) { + return cast( + VectorType::getSubdividedVectorType(VTy, NumSubdivs)); + } + + static FixedVectorType *getHalfElementsVectorType(FixedVectorType *VTy) { + return cast(VectorType::getHalfElementsVectorType(VTy)); + } + + static FixedVectorType *getDoubleElementsVectorType(FixedVectorType *VTy) { + return cast(VectorType::getDoubleElementsVectorType(VTy)); + } + + static bool classof(const Type *T) { + return isa(T->LLVMTy); + } + + unsigned getNumElements() const { + return cast(LLVMTy)->getNumElements(); + } +}; + class FunctionType : public Type { public: // TODO: add missing functions diff --git a/llvm/lib/SandboxIR/Type.cpp b/llvm/lib/SandboxIR/Type.cpp index bf9f02e2ba3111..26aa8b3743084c 100644 --- a/llvm/lib/SandboxIR/Type.cpp +++ b/llvm/lib/SandboxIR/Type.cpp @@ -103,6 +103,11 @@ bool VectorType::isValidElementType(Type *ElemTy) { return llvm::VectorType::isValidElementType(ElemTy->LLVMTy); } +FixedVectorType *FixedVectorType::get(Type *ElementType, unsigned NumElts) { + return cast(ElementType->getContext().getType( + llvm::FixedVectorType::get(ElementType->LLVMTy, NumElts))); +} + IntegerType *IntegerType::get(Context &Ctx, unsigned NumBits) { return cast( Ctx.getType(llvm::IntegerType::get(Ctx.LLVMCtx, NumBits))); diff --git a/llvm/unittests/SandboxIR/TypesTest.cpp b/llvm/unittests/SandboxIR/TypesTest.cpp index e4f9235c1ef3ca..3564ae66830147 100644 --- a/llvm/unittests/SandboxIR/TypesTest.cpp +++ b/llvm/unittests/SandboxIR/TypesTest.cpp @@ -323,6 +323,64 @@ define void @foo(<4 x i16> %vi0, <4 x float> %vf1, i8 %i0) { EXPECT_FALSE(sandboxir::VectorType::isValidElementType(FVecTy)); } +TEST_F(SandboxTypeTest, FixedVectorType) { + parseIR(C, R"IR( +define void @foo(<4 x i16> %vi0, <4 x float> %vf1, i8 %i0) { + ret void +} +)IR"); + llvm::Function *LLVMF = &*M->getFunction("foo"); + sandboxir::Context Ctx(C); + auto *F = Ctx.createFunction(LLVMF); + // Check classof(), creation, accessors + auto *Vec4i16Ty = cast(F->getArg(0)->getType()); + EXPECT_TRUE(Vec4i16Ty->getElementType()->isIntegerTy(16)); + EXPECT_EQ(Vec4i16Ty->getElementCount(), ElementCount::getFixed(4)); + + // get(ElementType, NumElements) + EXPECT_EQ( + sandboxir::FixedVectorType::get(sandboxir::Type::getInt16Ty(Ctx), 4), + F->getArg(0)->getType()); + // get(ElementType, Other) + EXPECT_EQ(sandboxir::FixedVectorType::get( + sandboxir::Type::getInt16Ty(Ctx), + cast(F->getArg(0)->getType())), + F->getArg(0)->getType()); + auto *Vec4FTy = cast(F->getArg(1)->getType()); + EXPECT_TRUE(Vec4FTy->getElementType()->isFloatTy()); + // getInteger + auto *Vec4i32Ty = sandboxir::FixedVectorType::getInteger(Vec4FTy); + EXPECT_TRUE(Vec4i32Ty->getElementType()->isIntegerTy(32)); + EXPECT_EQ(Vec4i32Ty->getElementCount(), Vec4FTy->getElementCount()); + // getExtendedElementCountVectorType + auto *Vec4i64Ty = + sandboxir::FixedVectorType::getExtendedElementVectorType(Vec4i16Ty); + EXPECT_TRUE(Vec4i64Ty->getElementType()->isIntegerTy(32)); + EXPECT_EQ(Vec4i64Ty->getElementCount(), Vec4i16Ty->getElementCount()); + // getTruncatedElementVectorType + auto *Vec4i8Ty = + sandboxir::FixedVectorType::getTruncatedElementVectorType(Vec4i16Ty); + EXPECT_TRUE(Vec4i8Ty->getElementType()->isIntegerTy(8)); + EXPECT_EQ(Vec4i8Ty->getElementCount(), Vec4i8Ty->getElementCount()); + // getSubdividedVectorType + auto *Vec8i8Ty = + sandboxir::FixedVectorType::getSubdividedVectorType(Vec4i16Ty, 1); + EXPECT_TRUE(Vec8i8Ty->getElementType()->isIntegerTy(8)); + EXPECT_EQ(Vec8i8Ty->getElementCount(), ElementCount::getFixed(8)); + // getNumElements + EXPECT_EQ(Vec8i8Ty->getNumElements(), 8u); + // getHalfElementsVectorType + auto *Vec2i16Ty = + sandboxir::FixedVectorType::getHalfElementsVectorType(Vec4i16Ty); + EXPECT_TRUE(Vec2i16Ty->getElementType()->isIntegerTy(16)); + EXPECT_EQ(Vec2i16Ty->getElementCount(), ElementCount::getFixed(2)); + // getDoubleElementsVectorType + auto *Vec8i16Ty = + sandboxir::FixedVectorType::getDoubleElementsVectorType(Vec4i16Ty); + EXPECT_TRUE(Vec8i16Ty->getElementType()->isIntegerTy(16)); + EXPECT_EQ(Vec8i16Ty->getElementCount(), ElementCount::getFixed(8)); +} + TEST_F(SandboxTypeTest, FunctionType) { parseIR(C, R"IR( define void @foo() {