Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SandboxIR] Implement the InsertElementInst class #102404

Merged
merged 6 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class Context;
class Function;
class Instruction;
class SelectInst;
class InsertElementInst;
class BranchInst;
class UnaryInstruction;
class LoadInst;
Expand Down Expand Up @@ -235,6 +236,7 @@ class Value {
friend class User; // For getting `Val`.
friend class Use; // For getting `Val`.
friend class SelectInst; // For getting `Val`.
friend class InsertElementInst; // For getting `Val`.
friend class BranchInst; // For getting `Val`.
friend class LoadInst; // For getting `Val`.
friend class StoreInst; // For getting `Val`.
Expand Down Expand Up @@ -631,6 +633,7 @@ class Instruction : public sandboxir::User {
/// returns its topmost LLVM IR instruction.
llvm::Instruction *getTopmostLLVMInstruction() const;
friend class SelectInst; // For getTopmostLLVMInstruction().
friend class InsertElementInst; // For getTopmostLLVMInstruction().
friend class BranchInst; // For getTopmostLLVMInstruction().
friend class LoadInst; // For getTopmostLLVMInstruction().
friend class StoreInst; // For getTopmostLLVMInstruction().
Expand Down Expand Up @@ -753,6 +756,52 @@ class SelectInst : public Instruction {
#endif
};

class InsertElementInst final : public Instruction {
/// Use Context::createInsertElementInst() instead.
InsertElementInst(llvm::Instruction *I, Context &Ctx)
: Instruction(ClassID::InsertElement, Opcode::InsertElement, I, Ctx) {}
friend class Context; // For accessing the constructor in
// create*()
Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
return getOperandUseDefault(OpIdx, Verify);
}
SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
return {cast<llvm::Instruction>(Val)};
}

public:
static Value *create(Value *Vec, Value *NewElt, Value *Idx,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name = "");
static Value *create(Value *Vec, Value *NewElt, Value *Idx,
BasicBlock *InsertAtEnd, Context &Ctx,
const Twine &Name = "");
static bool classof(const Value *From) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also add a static bool isValidOperands() just like in llvm::InsertElementInst which would just return cast<InsertElementInst>(Val)->isValidOperands(Vec, NewElt, Idx).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, by calling llvm::InsertElementInst::isValidOperands(Vec->Val, NewElt->Val, Idx->Val). It's a static method so I can't cast<InsertElementInst>(Val).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also add a test for this? Something like:

EXPECT_EQ(sandboxir::InsertElementInst::isValidOperands(V1, V2, V3), llvm::InsertElementInst::isValidOperands(Ctx.getValue(V1), Ctx.getValue(V2), Ctx.getValue(V3)));

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, following the other comment in the test.

return From->getSubclassID() == ClassID::InsertElement;
}
static bool isValidOperands(const Value *Vec, const Value *NewElt,
const Value *Idx) {
return llvm::InsertElementInst::isValidOperands(Vec->Val, NewElt->Val,
Idx->Val);
}
unsigned getUseOperandNo(const Use &Use) const final {
return getUseOperandNoDefault(Use);
}
unsigned getNumOfIRInstrs() const final { return 1u; }
#ifndef NDEBUG
void verify() const final {
assert(isa<llvm::InsertElementInst>(Val) && "Expected InsertElementInst");
}
friend raw_ostream &operator<<(raw_ostream &OS,
const InsertElementInst &IEI) {
IEI.dump(OS);
return OS;
}
void dump(raw_ostream &OS) const override;
LLVM_DUMP_METHOD void dump() const override;
#endif
};

class BranchInst : public Instruction {
/// Use Context::createBranchInst(). Don't call the constructor directly.
BranchInst(llvm::BranchInst *BI, Context &Ctx)
Expand Down Expand Up @@ -1845,6 +1894,8 @@ class Context {

SelectInst *createSelectInst(llvm::SelectInst *SI);
friend SelectInst; // For createSelectInst()
InsertElementInst *createInsertElementInst(llvm::InsertElementInst *IEI);
friend InsertElementInst; // For createInsertElementInst()
BranchInst *createBranchInst(llvm::BranchInst *I);
friend BranchInst; // For createBranchInst()
LoadInst *createLoadInst(llvm::LoadInst *LI);
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/SandboxIR/SandboxIRValues.def
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ DEF_USER(Constant, Constant)
// clang-format off
// ClassID, Opcode(s), Class
DEF_INSTR(Opaque, OP(Opaque), OpaqueInst)
DEF_INSTR(InsertElement, OP(InsertElement), InsertElementInst)
DEF_INSTR(Select, OP(Select), SelectInst)
DEF_INSTR(Br, OP(Br), BranchInst)
DEF_INSTR(Load, OP(Load), LoadInst)
Expand Down
51 changes: 51 additions & 0 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1397,6 +1397,44 @@ void OpaqueInst::dump() const {
}
#endif // NDEBUG

Value *InsertElementInst::create(Value *Vec, Value *NewElt, Value *Idx,
Instruction *InsertBefore, Context &Ctx,
const Twine &Name) {
auto &Builder = Ctx.getLLVMIRBuilder();
Builder.SetInsertPoint(InsertBefore->getTopmostLLVMInstruction());
llvm::Value *NewV =
Builder.CreateInsertElement(Vec->Val, NewElt->Val, Idx->Val, Name);
if (auto *NewInsert = dyn_cast<llvm::InsertElementInst>(NewV))
return Ctx.createInsertElementInst(NewInsert);
assert(isa<llvm::Constant>(NewV) && "Expected constant");
return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
}

Value *InsertElementInst::create(Value *Vec, Value *NewElt, Value *Idx,
BasicBlock *InsertAtEnd, Context &Ctx,
const Twine &Name) {
auto &Builder = Ctx.getLLVMIRBuilder();
Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
llvm::Value *NewV =
Builder.CreateInsertElement(Vec->Val, NewElt->Val, Idx->Val, Name);
if (auto *NewInsert = dyn_cast<llvm::InsertElementInst>(NewV))
return Ctx.createInsertElementInst(NewInsert);
assert(isa<llvm::Constant>(NewV) && "Expected constant");
return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
}

#ifndef NDEBUG
void InsertElementInst::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
dumpCommonSuffix(OS);
}

void InsertElementInst::dump() const {
dump(dbgs());
dbgs() << "\n";
}
#endif // NDEBUG

Constant *Constant::createInt(Type *Ty, uint64_t V, Context &Ctx,
bool IsSigned) {
llvm::Constant *LLVMC = llvm::ConstantInt::get(Ty, V, IsSigned);
Expand Down Expand Up @@ -1529,6 +1567,12 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
It->second = std::unique_ptr<SelectInst>(new SelectInst(LLVMSel, *this));
return It->second.get();
}
case llvm::Instruction::InsertElement: {
auto *LLVMIns = cast<llvm::InsertElementInst>(LLVMV);
It->second = std::unique_ptr<InsertElementInst>(
new InsertElementInst(LLVMIns, *this));
return It->second.get();
}
case llvm::Instruction::Br: {
auto *LLVMBr = cast<llvm::BranchInst>(LLVMV);
It->second = std::unique_ptr<BranchInst>(new BranchInst(LLVMBr, *this));
Expand Down Expand Up @@ -1626,6 +1670,13 @@ SelectInst *Context::createSelectInst(llvm::SelectInst *SI) {
return cast<SelectInst>(registerValue(std::move(NewPtr)));
}

InsertElementInst *
Context::createInsertElementInst(llvm::InsertElementInst *IEI) {
auto NewPtr =
std::unique_ptr<InsertElementInst>(new InsertElementInst(IEI, *this));
return cast<InsertElementInst>(registerValue(std::move(NewPtr)));
}

BranchInst *Context::createBranchInst(llvm::BranchInst *BI) {
auto NewPtr = std::unique_ptr<BranchInst>(new BranchInst(BI, *this));
return cast<BranchInst>(registerValue(std::move(NewPtr)));
Expand Down
50 changes: 50 additions & 0 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "llvm/SandboxIR/SandboxIR.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instruction.h"
Expand Down Expand Up @@ -630,6 +631,55 @@ define void @foo(i1 %c0, i8 %v0, i8 %v1, i1 %c1) {
}
}

TEST_F(SandboxIRTest, InsertElementInst) {
parseIR(C, R"IR(
define void @foo(i8 %v0, i8 %v1, <2 x i8> %vec) {
%ins0 = insertelement <2 x i8> poison, i8 %v0, i32 0
%ins1 = insertelement <2 x i8> %ins0, i8 %v1, i32 1
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);
auto &F = *Ctx.createFunction(&LLVMF);
auto *Arg0 = F.getArg(0);
auto *Arg1 = F.getArg(1);
auto *ArgVec = F.getArg(2);
auto *BB = &*F.begin();
auto It = BB->begin();
auto *Ins0 = cast<sandboxir::InsertElementInst>(&*It++);
auto *Ins1 = cast<sandboxir::InsertElementInst>(&*It++);
auto *Ret = &*It++;

EXPECT_EQ(Ins0->getOpcode(), sandboxir::Instruction::Opcode::InsertElement);
EXPECT_EQ(Ins0->getOperand(1), Arg0);
EXPECT_EQ(Ins1->getOperand(1), Arg1);
EXPECT_EQ(Ins1->getOperand(0), Ins0);
auto *Poison = Ins0->getOperand(0);
auto *Idx = Ins0->getOperand(2);
auto *NewI1 =
cast<sandboxir::InsertElementInst>(sandboxir::InsertElementInst::create(
Poison, Arg0, Idx, Ret, Ctx, "NewIns1"));
EXPECT_EQ(NewI1->getOperand(0), Poison);
EXPECT_EQ(NewI1->getNextNode(), Ret);

auto *NewI2 =
cast<sandboxir::InsertElementInst>(sandboxir::InsertElementInst::create(
Poison, Arg0, Idx, BB, Ctx, "NewIns2"));
EXPECT_EQ(NewI2->getPrevNode(), Ret);

auto *LLVMArg0 = LLVMF.getArg(0);
auto *LLVMArgVec = LLVMF.getArg(2);
auto *Zero = sandboxir::Constant::createInt(Type::getInt8Ty(C), 0, Ctx);
auto *LLVMZero = llvm::ConstantInt::get(Type::getInt8Ty(C), 0);
EXPECT_EQ(
sandboxir::InsertElementInst::isValidOperands(ArgVec, Arg0, Zero),
llvm::InsertElementInst::isValidOperands(LLVMArgVec, LLVMArg0, LLVMZero));
EXPECT_EQ(
sandboxir::InsertElementInst::isValidOperands(Arg0, ArgVec, Zero),
llvm::InsertElementInst::isValidOperands(LLVMArg0, LLVMArgVec, LLVMZero));
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the isValidOperands() check I would add an extra %vec argument to define void @foo(i8 %v0, i8 %v1, <2 x i8> %vec) auto *ArgVec = F.getArg(ArgIdx++);

Then we would need to get both the sandboxir and the llvm arguments:

unsigned LLVMArgIdx = 0;
auto *LLVMArg0 = LLVMF.getArg(LLVMArgIdx++);
auto *LLVMArg1 = LLVMF.getArg(LLVMArgIdx++);
auto *LLVMArgVec = LLVMF.getArg(LLVMArgIdx++);
...
auto *Arg1 = F.getArg(ArgIdx++);
auto *ArgVec = F.getArg(ArgIdx++);

Then:

auto *Zero = sandboxir::Constant::createInt(Type::getInt8Ty(C), 0, Ctx);
auto *LLVMZero = llvm::Constant::getZero(Type::getInt8Ty(C));

Finally I would add a check like:

EXPECT_EQ(sandboxir::InsertElementInst::isValidOperands(ArgVec, Arg0, Zero), llvm::InsertElementInst::isValidOperands(LLVMArgVec, LLVMArg0, LLVMZero));

To make sure we exercise this even more we can also try swapping the arguments:

EXPECT_EQ(sandboxir::InsertElementInst::isValidOperands(Arg0, ArgVec, Zero), llvm::InsertElementInst::isValidOperands(LLVMArg0, LLVMArgVec, LLVMZero));

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, with some slight variation. In particular, I've gotten rid of the ArgIdx variable (and haven't added an LLVMArgIdx either) because the API allows random access and I don't think using a counter to simulate an iterator gains us anything.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's fine with me. The reason I prefer a variable is because as I am modifying the test I may add a new N'th argument. If I am using hard-coded numbers in F.getArg() then I will have to renumber some of them, but If I am using a counter F.getArg(Idx++) it just works.


TEST_F(SandboxIRTest, BranchInst) {
parseIR(C, R"IR(
define void @foo(i1 %cond0, i1 %cond2) {
Expand Down
Loading