From ccb2aaa4046cb08a0aa7d852847a41780782421c Mon Sep 17 00:00:00 2001 From: eedalong Date: Thu, 23 May 2024 16:09:09 +0800 Subject: [PATCH] minor fixes --- .../mlir/disc/transforms/disc_op_schedule.cc | 18 +++++++++--------- .../transforms/lhlo_legalize_roots_to_loops.cc | 6 ++++-- 2 files changed, 13 insertions(+), 11 deletions(-) mode change 100644 => 100755 tao_compiler/mlir/disc/transforms/disc_op_schedule.cc mode change 100644 => 100755 tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc diff --git a/tao_compiler/mlir/disc/transforms/disc_op_schedule.cc b/tao_compiler/mlir/disc/transforms/disc_op_schedule.cc old mode 100644 new mode 100755 index 36461586229..54c4957f667 --- a/tao_compiler/mlir/disc/transforms/disc_op_schedule.cc +++ b/tao_compiler/mlir/disc/transforms/disc_op_schedule.cc @@ -354,7 +354,7 @@ class ScheduleGraph { explicit ScheduleGraph(std::vector& post_order_instructions, LatencyEstimator* latency_estimator, AsyncTracker* async_tracker) { - InitilizeGrpahTopology(post_order_instructions, latency_estimator, + InitilizeGraphTopology(post_order_instructions, latency_estimator, async_tracker); InitializeGraphAnalysis(latency_estimator, async_tracker); } @@ -497,7 +497,7 @@ class ScheduleGraph { } } - void InitilizeGrpahTopology(std::vector& post_order_instructions, + void InitilizeGraphTopology(std::vector& post_order_instructions, LatencyEstimator* latency_estimator, AsyncTracker* async_tracker) { original_order_ = post_order_instructions; @@ -957,7 +957,6 @@ struct DiscOpSchedulePass : public DiscOpSchedulePassBase { return; } - bool need_schedule = false; // Initialization latency_estimator_ = new LatencyEstimator(); async_tracker_ = new AsyncTracker(scheduler_config_); @@ -966,20 +965,21 @@ struct DiscOpSchedulePass : public DiscOpSchedulePassBase { for (auto& block : main_func.getBody()) { for (auto& op : block) { original_op_sequence.push_back(&op); - need_schedule = - need_schedule || async_tracker_->IsSupportedAsyncDone(&op); } } - if (!need_schedule) { - return; - } - scheduler_core_ = new SchedulerCore(latency_estimator_, async_tracker_, scheduler_config_); auto scheduled_op_sequence = scheduler_core_->ScheduleComputation(original_op_sequence); + if (scheduled_op_sequence.size() != original_op_sequence.size()) { + main_func.emitError( + "Schedule Pass error, scheduled_op_sequence should have the same num " + "of ops with original_op_sequence"); + signalPassFailure(); + return; + } for (auto& block : main_func.getBody()) { for (int op_idx = 0; op_idx < scheduled_op_sequence.size(); op_idx++) { scheduled_op_sequence[op_idx]->moveBefore(&block.front()); diff --git a/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc b/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc old mode 100644 new mode 100755 index 73b1feaaa68..93789480e5c --- a/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc +++ b/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc @@ -5712,9 +5712,11 @@ struct DiscLhloLegalizeRootsToParallelLoops // TODO: We should put even single nodes into a fusion by fusion pass // Revisit this and walk lmhlo::FusionOp only after the revision done. func.walk([&](lmhlo::LmhloOp op) { - // Skip the embedded ops in lmhlo.fusion or lmhlo.reduce/scatter + // Skip the embedded ops in lmhlo.fusion or lmhlo.reduce/scatter or + // lmhlo_disc.args_mutation lmhlo::LmhloOp parent = op->getParentOfType(); - if (parent && !isa(op)) { + if (isa(op) || + parent && !isa(op)) { return; } if (isFusionType(op) &&