diff --git a/include/ttmlir/Scheduler/PrecedenceScheduler.h b/include/ttmlir/Scheduler/PrecedenceScheduler.h index ae4174695..c40c3be63 100644 --- a/include/ttmlir/Scheduler/PrecedenceScheduler.h +++ b/include/ttmlir/Scheduler/PrecedenceScheduler.h @@ -12,7 +12,7 @@ namespace mlir::tt::scheduler { class PrecedenceScheduler : public Scheduler { public: // Constructor taking an MLIR Operation (or a module) - PrecedenceScheduler(func::FuncOp *root) : Scheduler(root) {}; + PrecedenceScheduler(func::FuncOp *root); // Copy constructor PrecedenceScheduler(const PrecedenceScheduler &scheduler) @@ -34,6 +34,12 @@ class PrecedenceScheduler : public Scheduler { // Map of precedence llvm::DenseMap> precedence; + // Output op of the function + mlir::Operation *outputOp; + + // DFS schedule construction based on a precedence map + llvm::DenseSet visitedOps; + void constructSchedule(mlir::Operation *op); }; } // namespace mlir::tt::scheduler diff --git a/include/ttmlir/Scheduler/QueueScheduler.h b/include/ttmlir/Scheduler/QueueScheduler.h index 44d351857..d3471d997 100644 --- a/include/ttmlir/Scheduler/QueueScheduler.h +++ b/include/ttmlir/Scheduler/QueueScheduler.h @@ -28,10 +28,6 @@ class QueueScheduler : public Scheduler { // Method to take a snapshot of the scheduler std::unique_ptr snapshot() final; - -private: - // Method to check if an operation can be scheduled - bool canSchedule(mlir::Operation *op); }; } // namespace mlir::tt::scheduler diff --git a/include/ttmlir/Scheduler/Scheduler.h b/include/ttmlir/Scheduler/Scheduler.h index ea88f875a..0fcde8a20 100644 --- a/include/ttmlir/Scheduler/Scheduler.h +++ b/include/ttmlir/Scheduler/Scheduler.h @@ -42,6 +42,9 @@ class Scheduler { // Method to check if there are unscheduled operations bool hasUnscheduledOps() const; + // Method to check if an operation can be scheduled + bool canSchedule(mlir::Operation *op); + protected: // Map of dependencies llvm::DenseMap> @@ -49,7 +52,7 @@ class Scheduler { // Sets of unscheduled / schedulable / scheduled operations llvm::DenseSet unscheduledOps; - llvm::SmallVector schedulableOps; + llvm::DenseSet schedulableOps; llvm::DenseSet scheduledOps; // Operation schedule in order of execution diff --git a/lib/Scheduler/PrecedenceScheduler.cpp b/lib/Scheduler/PrecedenceScheduler.cpp index c18f1cdd3..c161ee7fd 100644 --- a/lib/Scheduler/PrecedenceScheduler.cpp +++ b/lib/Scheduler/PrecedenceScheduler.cpp @@ -6,18 +6,58 @@ namespace mlir::tt::scheduler { -void PrecedenceScheduler::scheduleOp(mlir::Operation *op) { return; } +PrecedenceScheduler::PrecedenceScheduler(func::FuncOp *root) : Scheduler(root) { + root->walk([&](mlir::Operation *op) { + if (op->hasTrait()) { + outputOp = op; + } + }); +} + +void PrecedenceScheduler::scheduleOp(mlir::Operation *op) { + unscheduledOps.erase(op); + schedulableOps.erase(op); + scheduledOps.insert(op); + + OpResult result = op->getResult(0); + for (mlir::Operation *use : result.getUsers()) { + precedence[use].push_back(op); + + // Check the schedulability of the user op after scheduling the current op + // + if (canSchedule(use)) { + schedulableOps.insert(use); + } + } +} llvm::SmallVector PrecedenceScheduler::getScheduleableOps() { - return {}; + return llvm::SmallVector(schedulableOps.begin(), + schedulableOps.end()); } llvm::SmallVector PrecedenceScheduler::getSchedule() { - return {}; + constructSchedule(outputOp); + return schedule; } std::unique_ptr PrecedenceScheduler::snapshot() { return std::make_unique(*this); } +void PrecedenceScheduler::constructSchedule(mlir::Operation *op) { + // Schedule all the precedents of the current operation + // + for (mlir::Operation *precedent : precedence[op]) { + if (!visitedOps.count(precedent)) { + constructSchedule(precedent); + } + } + + // Schedule the current operation + // + visitedOps.insert(op); + schedule.push_back(op); +} + } // namespace mlir::tt::scheduler diff --git a/lib/Scheduler/QueueScheduler.cpp b/lib/Scheduler/QueueScheduler.cpp index ad29e1499..68c0bba5b 100644 --- a/lib/Scheduler/QueueScheduler.cpp +++ b/lib/Scheduler/QueueScheduler.cpp @@ -31,14 +31,4 @@ std::unique_ptr QueueScheduler::snapshot() { return std::make_unique(*this); } -bool QueueScheduler::canSchedule(mlir::Operation *op) { - for (mlir::Operation *dep : dependencies[op]) { - if (!scheduledOps.count(dep)) { - return false; - } - } - - return true; -} - } // namespace mlir::tt::scheduler diff --git a/lib/Scheduler/Scheduler.cpp b/lib/Scheduler/Scheduler.cpp index ec0bad28b..451c2b2e3 100644 --- a/lib/Scheduler/Scheduler.cpp +++ b/lib/Scheduler/Scheduler.cpp @@ -52,6 +52,13 @@ Scheduler::Scheduler(func::FuncOp *func) { } } } + + // Find the schedulable ops + for (const auto &entry : dependencies) { + if (entry.second.empty()) { + schedulableOps.insert(entry.first); + } + } } Scheduler::Scheduler(const Scheduler &scheduler) @@ -62,4 +69,14 @@ Scheduler::Scheduler(const Scheduler &scheduler) bool Scheduler::hasUnscheduledOps() const { return !unscheduledOps.empty(); } +bool Scheduler::canSchedule(mlir::Operation *op) { + for (mlir::Operation *dep : dependencies[op]) { + if (!scheduledOps.count(dep)) { + return false; + } + } + + return true; +} + } // namespace mlir::tt::scheduler diff --git a/test/unittests/TestScheduler/TestScheduler.cpp b/test/unittests/TestScheduler/TestScheduler.cpp index 80a2f4954..29b3c16ad 100644 --- a/test/unittests/TestScheduler/TestScheduler.cpp +++ b/test/unittests/TestScheduler/TestScheduler.cpp @@ -18,6 +18,7 @@ #include "ttmlir/Dialect/TT/IR/TT.h" #include "ttmlir/Dialect/TTIR/IR/TTIR.h" +#include "ttmlir/Scheduler/PrecedenceScheduler.h" #include "ttmlir/Scheduler/QueueScheduler.h" #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" @@ -244,3 +245,25 @@ TEST_F(SchedulerBase, VerifyFork) { scheduler.scheduleOp(scheduleableOps[0]); ASSERT_FALSE(scheduler.hasUnscheduledOps()); } + +// This tests the precedenceScheduler with a single operation +TEST_F(SchedulerBase, SingleOpPrecedenceScheduler) { + mlir::Value dest = createEmptyTensor(); + mlir::Value lhs = func.getBody().getBlocks().front().getArgument(0); + mlir::Value rhs = func.getBody().getBlocks().front().getArgument(1); + + mlir::ArrayAttr attrs = builder.getArrayAttr(createOperandConstraints()); + + // First operation has arg1 and arg2 and %0 as dps operand + ttir::TTIROp op = builder.create(builder.getUnknownLoc(), lhs, + rhs, dest, attrs); + + mlir::tt::scheduler::PrecedenceScheduler scheduler(&func); + ASSERT_TRUE(scheduler.hasUnscheduledOps()); + llvm::SmallVector scheduleableOps = + scheduler.getScheduleableOps(); + ASSERT_EQ(scheduleableOps.size(), 1); + scheduler.scheduleOp(scheduleableOps[0]); + ASSERT_FALSE(scheduler.hasUnscheduledOps()); + ASSERT_EQ(scheduleableOps[0], op.getOperation()); +}