Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
eedalong committed May 23, 2024
1 parent 6613d88 commit ccb2aaa
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
18 changes: 9 additions & 9 deletions tao_compiler/mlir/disc/transforms/disc_op_schedule.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ class ScheduleGraph {
explicit ScheduleGraph(std::vector<Operation*>& post_order_instructions,
LatencyEstimator* latency_estimator,
AsyncTracker* async_tracker) {
InitilizeGrpahTopology(post_order_instructions, latency_estimator,
InitilizeGraphTopology(post_order_instructions, latency_estimator,
async_tracker);
InitializeGraphAnalysis(latency_estimator, async_tracker);
}
Expand Down Expand Up @@ -497,7 +497,7 @@ class ScheduleGraph {
}
}

void InitilizeGrpahTopology(std::vector<Operation*>& post_order_instructions,
void InitilizeGraphTopology(std::vector<Operation*>& post_order_instructions,
LatencyEstimator* latency_estimator,
AsyncTracker* async_tracker) {
original_order_ = post_order_instructions;
Expand Down Expand Up @@ -957,7 +957,6 @@ struct DiscOpSchedulePass : public DiscOpSchedulePassBase<DiscOpSchedulePass> {
return;
}

bool need_schedule = false;
// Initialization
latency_estimator_ = new LatencyEstimator();
async_tracker_ = new AsyncTracker(scheduler_config_);
Expand All @@ -966,20 +965,21 @@ struct DiscOpSchedulePass : public DiscOpSchedulePassBase<DiscOpSchedulePass> {
for (auto& block : main_func.getBody()) {
for (auto& op : block) {
original_op_sequence.push_back(&op);
need_schedule =
need_schedule || async_tracker_->IsSupportedAsyncDone(&op);
}
}

if (!need_schedule) {
return;
}

scheduler_core_ = new SchedulerCore(latency_estimator_, async_tracker_,
scheduler_config_);
auto scheduled_op_sequence =
scheduler_core_->ScheduleComputation(original_op_sequence);

if (scheduled_op_sequence.size() != original_op_sequence.size()) {
main_func.emitError(
"Schedule Pass error, scheduled_op_sequence should have the same num "
"of ops with original_op_sequence");
signalPassFailure();
return;
}
for (auto& block : main_func.getBody()) {
for (int op_idx = 0; op_idx < scheduled_op_sequence.size(); op_idx++) {
scheduled_op_sequence[op_idx]->moveBefore(&block.front());
Expand Down
6 changes: 4 additions & 2 deletions tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -5712,9 +5712,11 @@ struct DiscLhloLegalizeRootsToParallelLoops
// TODO: We should put even single nodes into a fusion by fusion pass
// Revisit this and walk lmhlo::FusionOp only after the revision done.
func.walk([&](lmhlo::LmhloOp op) {
// Skip the embedded ops in lmhlo.fusion or lmhlo.reduce/scatter
// Skip the embedded ops in lmhlo.fusion or lmhlo.reduce/scatter or
// lmhlo_disc.args_mutation
lmhlo::LmhloOp parent = op->getParentOfType<lmhlo::LmhloOp>();
if (parent && !isa<lmhlo::FusionOp>(op)) {
if (isa<lmhlo_disc::ArgsMutationOp>(op) ||
parent && !isa<lmhlo::FusionOp>(op)) {
return;
}
if (isFusionType<FusionType::kStitch>(op) &&
Expand Down

0 comments on commit ccb2aaa

Please sign in to comment.