Skip to content

Commit

Permalink
PrecedenceScheduler implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
fbajraktariTT committed Nov 7, 2024
1 parent 542797b commit 7ace2d9
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 19 deletions.
8 changes: 7 additions & 1 deletion include/ttmlir/Scheduler/PrecedenceScheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace mlir::tt::scheduler {
class PrecedenceScheduler : public Scheduler {
public:
// Constructor taking an MLIR Operation (or a module)
PrecedenceScheduler(func::FuncOp *root) : Scheduler(root) {};
PrecedenceScheduler(func::FuncOp *root);

// Copy constructor
PrecedenceScheduler(const PrecedenceScheduler &scheduler)
Expand All @@ -34,6 +34,12 @@ class PrecedenceScheduler : public Scheduler {
// Map of precedence
llvm::DenseMap<mlir::Operation *, llvm::SmallVector<mlir::Operation *>>
precedence;
// Output op of the function
mlir::Operation *outputOp;

// DFS schedule construction based on a precedence map
llvm::DenseSet<mlir::Operation *> visitedOps;
void constructSchedule(mlir::Operation *op);
};

} // namespace mlir::tt::scheduler
Expand Down
4 changes: 0 additions & 4 deletions include/ttmlir/Scheduler/QueueScheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@ class QueueScheduler : public Scheduler {

// Method to take a snapshot of the scheduler
std::unique_ptr<Scheduler> snapshot() final;

private:
// Method to check if an operation can be scheduled
bool canSchedule(mlir::Operation *op);
};

} // namespace mlir::tt::scheduler
Expand Down
5 changes: 4 additions & 1 deletion include/ttmlir/Scheduler/Scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,17 @@ class Scheduler {
// Method to check if there are unscheduled operations
bool hasUnscheduledOps() const;

// Method to check if an operation can be scheduled
bool canSchedule(mlir::Operation *op);

protected:
// Map of dependencies
llvm::DenseMap<mlir::Operation *, llvm::SmallVector<mlir::Operation *>>
dependencies;

// Sets of unscheduled / schedulable / scheduled operations
llvm::DenseSet<mlir::Operation *> unscheduledOps;
llvm::SmallVector<mlir::Operation *> schedulableOps;
llvm::DenseSet<mlir::Operation *> schedulableOps;
llvm::DenseSet<mlir::Operation *> scheduledOps;

// Operation schedule in order of execution
Expand Down
46 changes: 43 additions & 3 deletions lib/Scheduler/PrecedenceScheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,58 @@

namespace mlir::tt::scheduler {

void PrecedenceScheduler::scheduleOp(mlir::Operation *op) { return; }
PrecedenceScheduler::PrecedenceScheduler(func::FuncOp *root) : Scheduler(root) {
root->walk([&](mlir::Operation *op) {
if (op->hasTrait<mlir::OpTrait::ReturnLike>()) {
outputOp = op;
}
});
}

void PrecedenceScheduler::scheduleOp(mlir::Operation *op) {
unscheduledOps.erase(op);
schedulableOps.erase(op);
scheduledOps.insert(op);

OpResult result = op->getResult(0);
for (mlir::Operation *use : result.getUsers()) {
precedence[use].push_back(op);

// Check the schedulability of the user op after scheduling the current op
//
if (canSchedule(use)) {
schedulableOps.insert(use);
}
}
}

llvm::SmallVector<mlir::Operation *> PrecedenceScheduler::getScheduleableOps() {
return {};
return llvm::SmallVector<mlir::Operation *>(schedulableOps.begin(),
schedulableOps.end());
}

llvm::SmallVector<mlir::Operation *> PrecedenceScheduler::getSchedule() {
return {};
constructSchedule(outputOp);
return schedule;
}

std::unique_ptr<Scheduler> PrecedenceScheduler::snapshot() {
return std::make_unique<PrecedenceScheduler>(*this);
}

void PrecedenceScheduler::constructSchedule(mlir::Operation *op) {
// Schedule all the precedents of the current operation
//
for (mlir::Operation *precedent : precedence[op]) {
if (!visitedOps.count(precedent)) {
constructSchedule(precedent);
}
}

// Schedule the current operation
//
visitedOps.insert(op);
schedule.push_back(op);
}

} // namespace mlir::tt::scheduler
10 changes: 0 additions & 10 deletions lib/Scheduler/QueueScheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,4 @@ std::unique_ptr<Scheduler> QueueScheduler::snapshot() {
return std::make_unique<QueueScheduler>(*this);
}

bool QueueScheduler::canSchedule(mlir::Operation *op) {
for (mlir::Operation *dep : dependencies[op]) {
if (!scheduledOps.count(dep)) {
return false;
}
}

return true;
}

} // namespace mlir::tt::scheduler
17 changes: 17 additions & 0 deletions lib/Scheduler/Scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ Scheduler::Scheduler(func::FuncOp *func) {
}
}
}

// Find the schedulable ops
for (const auto &entry : dependencies) {
if (entry.second.empty()) {
schedulableOps.insert(entry.first);
}
}
}

Scheduler::Scheduler(const Scheduler &scheduler)
Expand All @@ -62,4 +69,14 @@ Scheduler::Scheduler(const Scheduler &scheduler)

bool Scheduler::hasUnscheduledOps() const { return !unscheduledOps.empty(); }

bool Scheduler::canSchedule(mlir::Operation *op) {
for (mlir::Operation *dep : dependencies[op]) {
if (!scheduledOps.count(dep)) {
return false;
}
}

return true;
}

} // namespace mlir::tt::scheduler
23 changes: 23 additions & 0 deletions test/unittests/TestScheduler/TestScheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TTIR/IR/TTIR.h"
#include "ttmlir/Scheduler/PrecedenceScheduler.h"
#include "ttmlir/Scheduler/QueueScheduler.h"

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
Expand Down Expand Up @@ -244,3 +245,25 @@ TEST_F(SchedulerBase, VerifyFork) {
scheduler.scheduleOp(scheduleableOps[0]);
ASSERT_FALSE(scheduler.hasUnscheduledOps());
}

// This tests the precedenceScheduler with a single operation
TEST_F(SchedulerBase, SingleOpPrecedenceScheduler) {
mlir::Value dest = createEmptyTensor();
mlir::Value lhs = func.getBody().getBlocks().front().getArgument(0);
mlir::Value rhs = func.getBody().getBlocks().front().getArgument(1);

mlir::ArrayAttr attrs = builder.getArrayAttr(createOperandConstraints());

// First operation has arg1 and arg2 and %0 as dps operand
ttir::TTIROp op = builder.create<ttir::AddOp>(builder.getUnknownLoc(), lhs,
rhs, dest, attrs);

mlir::tt::scheduler::PrecedenceScheduler scheduler(&func);
ASSERT_TRUE(scheduler.hasUnscheduledOps());
llvm::SmallVector<mlir::Operation *> scheduleableOps =
scheduler.getScheduleableOps();
ASSERT_EQ(scheduleableOps.size(), 1);
scheduler.scheduleOp(scheduleableOps[0]);
ASSERT_FALSE(scheduler.hasUnscheduledOps());
ASSERT_EQ(scheduleableOps[0], op.getOperation());
}

0 comments on commit 7ace2d9

Please sign in to comment.