Skip to content

Commit

Permalink
[SandboxIR] Implement the InsertElementInst class (#102404)
Browse files Browse the repository at this point in the history
Heavily based on work by @vporpo.
  • Loading branch information
slackito authored Aug 9, 2024
1 parent a21cf56 commit 66d8735
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 0 deletions.
51 changes: 51 additions & 0 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class Context;
class Function;
class Instruction;
class SelectInst;
class InsertElementInst;
class BranchInst;
class UnaryInstruction;
class LoadInst;
Expand Down Expand Up @@ -235,6 +236,7 @@ class Value {
friend class User; // For getting `Val`.
friend class Use; // For getting `Val`.
friend class SelectInst; // For getting `Val`.
friend class InsertElementInst; // For getting `Val`.
friend class BranchInst; // For getting `Val`.
friend class LoadInst; // For getting `Val`.
friend class StoreInst; // For getting `Val`.
Expand Down Expand Up @@ -631,6 +633,7 @@ class Instruction : public sandboxir::User {
/// returns its topmost LLVM IR instruction.
llvm::Instruction *getTopmostLLVMInstruction() const;
friend class SelectInst; // For getTopmostLLVMInstruction().
friend class InsertElementInst; // For getTopmostLLVMInstruction().
friend class BranchInst; // For getTopmostLLVMInstruction().
friend class LoadInst; // For getTopmostLLVMInstruction().
friend class StoreInst; // For getTopmostLLVMInstruction().
Expand Down Expand Up @@ -753,6 +756,52 @@ class SelectInst : public Instruction {
#endif
};

class InsertElementInst final : public Instruction {
/// Use Context::createInsertElementInst() instead.
InsertElementInst(llvm::Instruction *I, Context &Ctx)
: Instruction(ClassID::InsertElement, Opcode::InsertElement, I, Ctx) {}
friend class Context; // For accessing the constructor in
// create*()
Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
return getOperandUseDefault(OpIdx, Verify);
}
SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
return {cast<llvm::Instruction>(Val)};
}

public:
static Value *create(Value *Vec, Value *NewElt, Value *Idx,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name = "");
static Value *create(Value *Vec, Value *NewElt, Value *Idx,
BasicBlock *InsertAtEnd, Context &Ctx,
const Twine &Name = "");
static bool classof(const Value *From) {
return From->getSubclassID() == ClassID::InsertElement;
}
static bool isValidOperands(const Value *Vec, const Value *NewElt,
const Value *Idx) {
return llvm::InsertElementInst::isValidOperands(Vec->Val, NewElt->Val,
Idx->Val);
}
unsigned getUseOperandNo(const Use &Use) const final {
return getUseOperandNoDefault(Use);
}
unsigned getNumOfIRInstrs() const final { return 1u; }
#ifndef NDEBUG
void verify() const final {
assert(isa<llvm::InsertElementInst>(Val) && "Expected InsertElementInst");
}
friend raw_ostream &operator<<(raw_ostream &OS,
const InsertElementInst &IEI) {
IEI.dump(OS);
return OS;
}
void dump(raw_ostream &OS) const override;
LLVM_DUMP_METHOD void dump() const override;
#endif
};

class BranchInst : public Instruction {
/// Use Context::createBranchInst(). Don't call the constructor directly.
BranchInst(llvm::BranchInst *BI, Context &Ctx)
Expand Down Expand Up @@ -1845,6 +1894,8 @@ class Context {

SelectInst *createSelectInst(llvm::SelectInst *SI);
friend SelectInst; // For createSelectInst()
InsertElementInst *createInsertElementInst(llvm::InsertElementInst *IEI);
friend InsertElementInst; // For createInsertElementInst()
BranchInst *createBranchInst(llvm::BranchInst *I);
friend BranchInst; // For createBranchInst()
LoadInst *createLoadInst(llvm::LoadInst *LI);
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/SandboxIR/SandboxIRValues.def
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ DEF_USER(Constant, Constant)
// clang-format off
// ClassID, Opcode(s), Class
DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
DEF_INSTR(InsertElement, OP(InsertElement), InsertElementInst)
DEF_INSTR(Select, OP(Select), SelectInst)
DEF_INSTR(Br, OP(Br), BranchInst)
DEF_INSTR(Load, OP(Load), LoadInst)
Expand Down
51 changes: 51 additions & 0 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1419,6 +1419,44 @@ void OpaqueInst::dump() const {
}
#endif // NDEBUG

Value *InsertElementInst::create(Value *Vec, Value *NewElt, Value *Idx,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name) {
auto &Builder = Ctx.getLLVMIRBuilder();
Builder.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction());
llvm::Value *NewV =
Builder.CreateInsertElement(Vec->Val, NewElt->Val, Idx->Val, Name);
if (auto *NewInsert = dyn_cast<llvm::InsertElementInst>(NewV))
return Ctx.createInsertElementInst(NewInsert);
assert(isa<llvm::Constant>(NewV) && "Expected constant");
return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
}

Value *InsertElementInst::create(Value *Vec, Value *NewElt, Value *Idx,
BasicBlock *InsertAtEnd, Context &Ctx,
const Twine &Name) {
auto &Builder = Ctx.getLLVMIRBuilder();
Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
llvm::Value *NewV =
Builder.CreateInsertElement(Vec->Val, NewElt->Val, Idx->Val, Name);
if (auto *NewInsert = dyn_cast<llvm::InsertElementInst>(NewV))
return Ctx.createInsertElementInst(NewInsert);
assert(isa<llvm::Constant>(NewV) && "Expected constant");
return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
}

#ifndef NDEBUG
void InsertElementInst::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
dumpCommonSuffix(OS);
}

void InsertElementInst::dump() const {
dump(dbgs());
dbgs() << "\n";
}
#endif // NDEBUG

Constant *Constant::createInt(Type *Ty, uint64_t V, Context &Ctx,
bool IsSigned) {
llvm::Constant *LLVMC = llvm::ConstantInt::get(Ty, V, IsSigned);
Expand Down Expand Up @@ -1551,6 +1589,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
It->second = std::unique_ptr<SelectInst>(new SelectInst(LLVMSel, *this));
return It->second.get();
}
case llvm::Instruction::InsertElement: {
auto *LLVMIns = cast<llvm::InsertElementInst>(LLVMV);
It->second = std::unique_ptr<InsertElementInst>(
new InsertElementInst(LLVMIns, *this));
return It->second.get();
}
case llvm::Instruction::Br: {
auto *LLVMBr = cast<llvm::BranchInst>(LLVMV);
It->second = std::unique_ptr<BranchInst>(new BranchInst(LLVMBr, *this));
Expand Down Expand Up @@ -1648,6 +1692,13 @@ SelectInst *Context::createSelectInst(llvm::SelectInst *SI) {
return cast<SelectInst>(registerValue(std::move(NewPtr)));
}

InsertElementInst *
Context::createInsertElementInst(llvm::InsertElementInst *IEI) {
auto NewPtr =
std::unique_ptr<InsertElementInst>(new InsertElementInst(IEI, *this));
return cast<InsertElementInst>(registerValue(std::move(NewPtr)));
}

BranchInst *Context::createBranchInst(llvm::BranchInst *BI) {
auto NewPtr = std::unique_ptr<BranchInst>(new BranchInst(BI, *this));
return cast<BranchInst>(registerValue(std::move(NewPtr)));
Expand Down
50 changes: 50 additions & 0 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "llvm/SandboxIR/SandboxIR.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instruction.h"
Expand Down Expand Up @@ -630,6 +631,55 @@ define void @foo(i1 %c0, i8 %v0, i8 %v1, i1 %c1) {
}
}

TEST_F(SandboxIRTest, InsertElementInst) {
parseIR(C, R"IR(
define void @foo(i8 %v0, i8 %v1, <2 x i8> %vec) {
%ins0 = insertelement <2 x i8> poison, i8 %v0, i32 0
%ins1 = insertelement <2 x i8> %ins0, i8 %v1, i32 1
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);
auto &F = *Ctx.createFunction(&LLVMF);
auto *Arg0 = F.getArg(0);
auto *Arg1 = F.getArg(1);
auto *ArgVec = F.getArg(2);
auto *BB = &*F.begin();
auto It = BB->begin();
auto *Ins0 = cast<sandboxir::InsertElementInst>(&*It++);
auto *Ins1 = cast<sandboxir::InsertElementInst>(&*It++);
auto *Ret = &*It++;

EXPECT_EQ(Ins0->getOpcode(), sandboxir::Instruction::Opcode::InsertElement);
EXPECT_EQ(Ins0->getOperand(1), Arg0);
EXPECT_EQ(Ins1->getOperand(1), Arg1);
EXPECT_EQ(Ins1->getOperand(0), Ins0);
auto *Poison = Ins0->getOperand(0);
auto *Idx = Ins0->getOperand(2);
auto *NewI1 =
cast<sandboxir::InsertElementInst>(sandboxir::InsertElementInst::create(
Poison, Arg0, Idx, Ret, Ctx, "NewIns1"));
EXPECT_EQ(NewI1->getOperand(0), Poison);
EXPECT_EQ(NewI1->getNextNode(), Ret);

auto *NewI2 =
cast<sandboxir::InsertElementInst>(sandboxir::InsertElementInst::create(
Poison, Arg0, Idx, BB, Ctx, "NewIns2"));
EXPECT_EQ(NewI2->getPrevNode(), Ret);

auto *LLVMArg0 = LLVMF.getArg(0);
auto *LLVMArgVec = LLVMF.getArg(2);
auto *Zero = sandboxir::Constant::createInt(Type::getInt8Ty(C), 0, Ctx);
auto *LLVMZero = llvm::ConstantInt::get(Type::getInt8Ty(C), 0);
EXPECT_EQ(
sandboxir::InsertElementInst::isValidOperands(ArgVec, Arg0, Zero),
llvm::InsertElementInst::isValidOperands(LLVMArgVec, LLVMArg0, LLVMZero));
EXPECT_EQ(
sandboxir::InsertElementInst::isValidOperands(Arg0, ArgVec, Zero),
llvm::InsertElementInst::isValidOperands(LLVMArg0, LLVMArgVec, LLVMZero));
}

TEST_F(SandboxIRTest, BranchInst) {
parseIR(C, R"IR(
define void @foo(i1 %cond0, i1 %cond2) {
Expand Down

0 comments on commit 66d8735

Please sign in to comment.