Skip to content

Commit

Permalink
[SandboxVec] Legality boilerplate (#108650)
Browse files Browse the repository at this point in the history
This patch adds the basic API for the Legality component of the
vectorizer. It also adds some very basic code in the bottom-up
vectorizer that uses the API.
  • Loading branch information
vporpo authored Sep 18, 2024
1 parent 785624b commit 42c5a30
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//===- Legality.h -----------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Legality checks for the Sandbox Vectorizer.
//

#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_LEGALITY_H
#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_LEGALITY_H

#include "llvm/SandboxIR/SandboxIR.h"

namespace llvm::sandboxir {

class LegalityAnalysis;

enum class LegalityResultID {
Widen, ///> Vectorize by combining scalars to a vector.
};

/// The legality outcome is represented by a class rather than an enum class
/// because in some cases the legality checks are expensive and look for a
/// particular instruction that can be passed along to the vectorizer to avoid
/// repeating the same expensive computation.
class LegalityResult {
protected:
LegalityResultID ID;
/// Only Legality can create LegalityResults.
LegalityResult(LegalityResultID ID) : ID(ID) {}
friend class LegalityAnalysis;

public:
LegalityResultID getSubclassID() const { return ID; }
};

class Widen final : public LegalityResult {
friend class LegalityAnalysis;
Widen() : LegalityResult(LegalityResultID::Widen) {}

public:
static bool classof(const LegalityResult *From) {
return From->getSubclassID() == LegalityResultID::Widen;
}
};

/// Performs the legality analysis and returns a LegalityResult object.
class LegalityAnalysis {
public:
LegalityAnalysis() = default;
LegalityResult canVectorize(ArrayRef<Value *> Bndl) {
// TODO: For now everything is legal.
return Widen();
}
};

} // namespace llvm::sandboxir

#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_LEGALITY_H
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,18 @@
#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_PASSES_BOTTOMUPVEC_H
#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_PASSES_BOTTOMUPVEC_H

#include "llvm/ADT/ArrayRef.h"
#include "llvm/SandboxIR/Pass.h"
#include "llvm/SandboxIR/SandboxIR.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"

namespace llvm::sandboxir {

class BottomUpVec final : public FunctionPass {
bool Change = false;
LegalityAnalysis Legality;
void vectorizeRec(ArrayRef<Value *> Bndl);
void tryVectorize(ArrayRef<Value *> Seeds);

public:
BottomUpVec() : FunctionPass("bottom-up-vec") {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,58 @@
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h"
#include "llvm/ADT/SmallVector.h"

using namespace llvm::sandboxir;

bool BottomUpVec::runOnFunction(Function &F) { return false; }
namespace llvm::sandboxir {
// TODO: This is a temporary function that returns some seeds.
// Replace this with SeedCollector's function when it lands.
static llvm::SmallVector<Value *, 4> collectSeeds(BasicBlock &BB) {
llvm::SmallVector<Value *, 4> Seeds;
for (auto &I : BB)
if (auto *SI = llvm::dyn_cast<StoreInst>(&I))
Seeds.push_back(SI);
return Seeds;
}

static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
unsigned OpIdx) {
SmallVector<Value *, 4> Operands;
for (Value *BndlV : Bndl) {
auto *BndlI = cast<Instruction>(BndlV);
Operands.push_back(BndlI->getOperand(OpIdx));
}
return Operands;
}

} // namespace llvm::sandboxir

void BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
auto LegalityRes = Legality.canVectorize(Bndl);
switch (LegalityRes.getSubclassID()) {
case LegalityResultID::Widen: {
auto *I = cast<Instruction>(Bndl[0]);
for (auto OpIdx : seq<unsigned>(I->getNumOperands())) {
auto OperandBndl = getOperand(Bndl, OpIdx);
vectorizeRec(OperandBndl);
}
break;
}
}
}

void BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) { vectorizeRec(Bndl); }

bool BottomUpVec::runOnFunction(Function &F) {
Change = false;
// TODO: Start from innermost BBs first
for (auto &BB : F) {
// TODO: Replace with proper SeedCollector function.
auto Seeds = collectSeeds(BB);
// TODO: Slice Seeds into smaller chunks.
if (Seeds.size() >= 2)
tryVectorize(Seeds);
}
return Change;
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ set(LLVM_LINK_COMPONENTS

add_llvm_unittest(SandboxVectorizerTests
DependencyGraphTest.cpp
LegalityTest.cpp
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
//===- LegalityTest.cpp ---------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/SandboxIR/SandboxIR.h"
#include "llvm/Support/SourceMgr.h"
#include "gtest/gtest.h"

using namespace llvm;

struct LegalityTest : public testing::Test {
LLVMContext C;
std::unique_ptr<Module> M;

void parseIR(LLVMContext &C, const char *IR) {
SMDiagnostic Err;
M = parseAssemblyString(IR, Err, C);
if (!M)
Err.print("LegalityTest", errs());
}
};

TEST_F(LegalityTest, Legality) {
parseIR(C, R"IR(
define void @foo(ptr %ptr) {
%gep0 = getelementptr float, ptr %ptr, i32 0
%gep1 = getelementptr float, ptr %ptr, i32 1
%ld0 = load float, ptr %gep0
%ld1 = load float, ptr %gep0
store float %ld0, ptr %gep0
store float %ld1, ptr %gep1
ret void
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
auto It = BB->begin();
[[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++);
[[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++);
[[maybe_unused]] auto *Ld0 = cast<sandboxir::LoadInst>(&*It++);
[[maybe_unused]] auto *Ld1 = cast<sandboxir::LoadInst>(&*It++);
auto *St0 = cast<sandboxir::StoreInst>(&*It++);
auto *St1 = cast<sandboxir::StoreInst>(&*It++);

sandboxir::LegalityAnalysis Legality;
auto Result = Legality.canVectorize({St0, St1});
EXPECT_TRUE(isa<sandboxir::Widen>(Result));
}

0 comments on commit 42c5a30

Please sign in to comment.